
! Subroutine create_phase_space(nmuons, create_new_distribution, new_file, muon_file, &
!         epsdistr, epsx, epsy, tdistr, tlength, tsigma, pzdistr, pz, pzsigma, twiss2, inf_aperture, scatter_inf_end_us, scatter_inf_end_ds, energy_loss, muons, twiss_ref, mat_inv)
!
! Routine to create distribution of muons.
!
! Modules needed: 
!  use nr
!  use parameters_bmad
!  use muon_mod
!  use muon_interface
!  use materials_mod
!  use random_mod
!
! Input:
!    nmuons -- Integer: number of muons in the distribution
!    create_new_distribution -- Logical: If true create a new distribution. If false, read an existing distribution from a file
!    new_file -- Character: Name of the file with the new distribution. If empty, that is if new_file = '', then do not write the distribution to a file
!    muon_file -- Character: Name of the file of the distribution to read
!    epsdistr --  Character: If 'gaus' then new distribution will be gaussian in transverse phase space
!                            If 'flat' then new distribution will be flat in transverse phase space
!    epsx,epsy  -- Real: Emittance of gaussian distribution
!    tdistr --     Character: If tdistr = 'e821' then time distribution is gaussian with sigma =25ns
!                             If tdistr = 'e989' then time distribution is FNAL W
!                             If tdistr = 'gaus' then time distribution is gaussian with sigma = tsigma with tails cut off beyond +- tlength/2
!    tlength --    Real: Full width of the temporal distribution: for tdistr='gaus' the tails are cut off beyond +- tlength/2; for tdistr='e989' and the default uniform distribution it is the total width
!    tsigma  --    Real: Width of gaussian temporal distribution if the distribution is gaussian (tdistr='gaus')
!    pzdistr --    Character: If pzdistr = 'gaus' then momentum (vec(6)) distribution is gaussian, with width pzsigma and tails cut off beyond +- pz/2
!                             If pzdistr = 'flat' then momentum distribution is flat, with width +- pz/2
!    pz     -- Real: Full width of the momentum (vec(6)) distribution: for pzdistr='gaus' the tails are cut off beyond +- pz/2; for pzdistr='flat' the distribution is uniform on +- pz/2
!    pzsigma -- Real: Width of momentum distribution (vec(6)) if pzdistr = 'gaus'
!    twiss2 -- G2_TWISS_STRUCT: Twiss parameters in the inflector. The distribution will be generated so that in the inflector it will have these twiss parameters
!    inf_aperture -- Character: If inf_aperture = 'e821' then the aperture is defined as the "D" with width 18mm and height 56mm  
!                               If inf_aperture = 'e989' then the aperture is defined as oval with width= 2*(inflector_width) and height 56mm
!    scatter_inf_end_us -- Logical: If true then there will be scattering in the upstream inflector coil. If false, no scattering 
!    scatter_inf_end_ds -- Logical: If true then there will be scattering in the downstream inflector coil. If false, no scattering 
!    energy_loss -- Logical: If true then there will be energy loss along with scattering in the inflector coils. If false, no energy loss
!    muons -- Muon_struct:  Phase space coordinates of the muon distribution generated (or read from a file)
!    twiss_ref -- character: 'center' put twiss beam parameters at inflector midpoint, 'end' put twiss beam parameters at end
!    mat_inv -- Real,optional :: dimension(6,6): Matrix to propagate phase space defined by TWISS2 backwards to the start of the injection line, (branch = 0)  
!    
subroutine create_phase_space(nmuons, create_new_distribution, new_file, muon_file, &
         epsdistr, epsx, epsy, tdistr, tlength, tsigma, pzdistr, pz, pzsigma, twiss2, &
         inf_aperture, scatter_inf_end_us, scatter_inf_end_ds, energy_loss, muons, twiss_ref, mat_inv)
!
! Create (or read) the muon distribution and map it onto the requested Twiss parameters.
!
! Outline:
!   1. If create_new_distribution, generate the 6D distribution at the production target
!      (transverse: epsdistr; time and vec(5): tdistr; momentum vec(6): pzdistr) and, if
!      new_file names a file, write it out.
!   2. If muon_file names a file, read the distribution from it instead.
!   3. Shift each muon's time/vec(5) for the flight from the target (s_target) to the ring.
!   4. If nmuons == 1, return a single muon at the origin with Jx = epsx, Jy = epsy.
!   5. Measure the Twiss parameters of this raw distribution (twiss1), build the linear map
!      twiss1 -> twiss2 (make_transfer) and apply it; for twiss_ref = 'center' also drift
!      through the downstream half of the inflector (length inf_length/2).
!   6. If mat_inv is present, map the distribution back to the start of the injection line.
!      If twiss2 has a non-positive beta, the raw distribution of step 3 is returned instead.
!
! Arguments are documented in the header comment preceding this routine.
!
! Fixes relative to the previous version:
!   * The momentum distribution tested tdistr (not pzdistr) for 'Gaus'/'GAUS'.
!   * sigma_x/sigma_y were SAVEd through DATA-style initialisers, so a 'delta-function'
!     call zeroed them for every later call; they are now set on every entry.
!   * '(a11)' formats truncated the section titles.
!   * muons is deallocated before being allocated, and a missing muon_file is reported
!     instead of being silently created as an empty file.
!   * The unreachable aperture/scattering code after the final RETURN was removed.
!     NOTE(review): inf_aperture, scatter_inf_end_us, scatter_inf_end_ds and energy_loss
!     are therefore unused, exactly as they already were at run time.
!
USE nr
use parameters_bmad
use muon_mod
use muon_interface, dummy => create_phase_space
use random_mod
use materials_mod

IMPLICIT NONE

! ---- arguments ----
integer nmuons
logical create_new_distribution
character*120 new_file, muon_file
character*16 epsdistr, tdistr, pzdistr, inf_aperture, twiss_ref
real(rp) epsx, epsy, tlength, tsigma, pz, pzsigma
type (g2twiss_struct) twiss2
logical scatter_inf_end_us, scatter_inf_end_ds, energy_loss
type (muon_struct), allocatable :: muons(:)
real(rp), optional :: mat_inv(6,6)

! ---- constants ----
real(rp), parameter :: s_target = -1000._rp                  ! position of the target (m), CDR value
real(rp), parameter :: flat_half_width = 1.73421323168796_rp ! empirical flat-distribution scale factor
real(rp), parameter :: e821_tsigma = 25.e-9_rp               ! RMS of the E821 time distribution (s)

! ---- locals ----
type (averages_struct) averages
type (g2twiss_struct) twiss1, twiss_inf
type (muon_struct), allocatable :: muons_raw(:)
real(rp) sigma_x, sigma_xp, sigma_y, sigma_yp
real(rp) T(6,6), vec(6)
real(rp) betagamma
integer i, tot, lun, open_status

call initializeMaterials()
call AssignMaterialPointer(material_ptr,'Al')

! Release any storage the caller left attached so that ALLOCATE cannot fail.
if (allocated(muons)) deallocate(muons)
allocate(muons(nmuons))

! Default RMS spot sizes at the target (m); assigned on every call (see "Fixes" above).
sigma_x  = 0.001_rp
sigma_y  = 0.001_rp
sigma_xp = 0
sigma_yp = 0

if (create_new_distribution) then
  !================================================
  ! CREATE 6D PHASE-SPACE DISTRIBUTION (AT TARGET)
  !================================================

  ! Assume for simplicity that transverse positions and momenta are uncorrelated at target, i.e. alpha_{x,y}=0
  sigma_xp = epsx/sigma_x
  sigma_yp = epsy/sigma_y

  if (epsdistr == 'gaus') then
    ! Gaussian distribution
    call ran_gauss(muons(:)%gas)
    muons(:)%coord%vec(1) = muons(:)%gas * sigma_x
    call ran_gauss(muons(:)%gas)
    muons(:)%coord%vec(2) = muons(:)%gas * sigma_xp
    call ran_gauss(muons(:)%gas)
    muons(:)%coord%vec(3) = muons(:)%gas * sigma_y
    call ran_gauss(muons(:)%gas)
    muons(:)%coord%vec(4) = muons(:)%gas * sigma_yp
  elseif (epsdistr == 'flat') then
    ! Uniform distribution. The width of the flat distribution is chosen so as
    ! to recover the user's input emittances numerically.
    call ran_uniform(muons(:)%flat)
    muons(:)%coord%vec(1) = (muons(:)%flat - 0.5_rp) * 2 * flat_half_width * sigma_x
    call ran_uniform(muons(:)%flat)
    muons(:)%coord%vec(2) = (muons(:)%flat - 0.5_rp) * 2 * flat_half_width * sigma_xp
    call ran_uniform(muons(:)%flat)
    muons(:)%coord%vec(3) = (muons(:)%flat - 0.5_rp) * 2 * flat_half_width * sigma_y
    call ran_uniform(muons(:)%flat)
    muons(:)%coord%vec(4) = (muons(:)%flat - 0.5_rp) * 2 * flat_half_width * sigma_yp
  else
    ! Delta-function: every muon on axis
    sigma_x  = 0
    sigma_y  = 0
    sigma_xp = 0
    sigma_yp = 0
    muons(:)%coord%vec(1) = 0
    muons(:)%coord%vec(2) = 0
    muons(:)%coord%vec(3) = 0
    muons(:)%coord%vec(4) = 0
  endif

  ! Generate a time and longitudinal position at the target (s = s_target).
  print *,'s_target/betaMagic/c_light =',s_target/betaMagic/c_light
  if (tdistr == 'e989' .or. tdistr == 'E989') then
    ! FNAL "W" time profile of full width tlength
    do i = 1, nmuons
      muons(i)%coord%t = fnalw(0.0_rp,tlength)
    enddo
    ! Subtract the time it takes to get from the target to the ring
    muons(:)%coord%t = muons(:)%coord%t + s_target/betaMagic/c_light
    ! Convert to a longitudinal coordinate
    muons(:)%coord%vec(5) = betaMagic*c_light*muons(:)%coord%t
  elseif (tdistr == 'e821' .or. tdistr == 'E821') then
    ! Gaussian with sigma = 25 ns
    call ran_gauss(muons(:)%gas)
    muons(:)%coord%t = muons(:)%gas * e821_tsigma
    muons(:)%coord%t = muons(:)%coord%t + s_target/betaMagic/c_light
    muons(:)%coord%vec(5) = betaMagic*c_light*muons(:)%coord%t
  elseif (tdistr == 'gaus' .or. tdistr == 'Gaus' .or. tdistr == 'GAUS') then
    do i = 1, nmuons
      ! Gaussian of width tsigma, resampled until |t| <= tlength/2 ("cut the tails off")
      call ran_gauss(muons(i)%gas)
      muons(i)%coord%t = muons(i)%gas * tsigma
      do while (abs(muons(i)%coord%t) > tlength/2)
        call ran_gauss(muons(i)%gas)
        muons(i)%coord%t = muons(i)%gas * tsigma
      enddo
      muons(i)%coord%t = muons(i)%coord%t + s_target/betaMagic/c_light
      muons(i)%coord%vec(5) = betaMagic*c_light*muons(i)%coord%t
    enddo
  else
    ! Uniform distribution of full width tlength (default)
    call ran_uniform(muons(:)%flat)
    muons(:)%coord%t = s_target / betaMagic / c_light + (muons(:)%flat - 0.5_rp) * tlength
    muons(:)%coord%vec(5) = betaMagic*c_light*muons(:)%coord%t
  endif

  ! Generate a momentum offset vec(6) at the target, uncorrelated with position.
  if (pzdistr == 'gaus' .or. pzdistr == 'Gaus' .or. pzdistr == 'GAUS') then
    do i = 1, nmuons
      ! Gaussian of width pzsigma, resampled until |vec(6)| <= pz/2
      call ran_gauss(muons(i)%gas)
      muons(i)%coord%vec(6) = muons(i)%gas * pzsigma
      do while (abs(muons(i)%coord%vec(6)) > pz/2)
        call ran_gauss(muons(i)%gas)
        muons(i)%coord%vec(6) = muons(i)%gas * pzsigma
      enddo
    enddo
  else
    ! Uniform distribution of full width pz (default)
    call ran_uniform(muons(:)%flat)
    muons(:)%coord%vec(6) = (muons(:)%flat - 0.5_rp) * pz
  endif

  ! Optionally save the generated distribution (disabled by 'none' or a name of <= 2 characters).
  print '(3a,i10)', 'new_file =' ,new_file, ' len(new_file) =', len(trim(new_file))
  if (len(trim(new_file)) > 2 .and. index(new_file,'none') == 0) then
    muons(:)%coord%state = alive$
    call write_phase_space(nmuons, muons, new_file, tot)
  endif
endif  ! create_new_distribution

! Optionally read the distribution from muon_file (disabled by 'none' or a name of <= 2 characters).
if (len(trim(muon_file)) > 2 .and. index(muon_file,'none') == 0) then
  lun = lunget()
  open(unit = lun, file = trim(muon_file), status = 'old', action = 'read', iostat = open_status)
  if (open_status /= 0) then
    print '(a,a)', ' Cannot open muon file ', trim(muon_file)
    stop
  endif
  print '(a,a)',' Read muons from ',trim(muon_file)
  if (trim(muon_file) == "VDstop_DS_436_12000.dat") then
    call read_Vmuons(lun, nmuons, muons, s_target/betaMagic/c_light, "Volodya's distribution", tot)
  elseif (trim(muon_file) == "particles_endm4m5_100.txt") then   !Kim 8Apr2016
    call read_Dmuons(lun, nmuons, muons, s_target/betaMagic/c_light, "Diktys's distribution", tot) !Kim 8Apr2016
    if (tdistr == 'e989' .or. tdistr == 'E989') then
      ! Superimpose a FNAL "W" time profile on the times read from the file
      do i = 1, nmuons
        muons(i)%coord%t = fnalw(0.0_rp,tlength) + muons(i)%coord%t
      enddo
    endif
    ! Subtract the time it takes to get from the target to the ring
    muons(:)%coord%t = muons(:)%coord%t + s_target/betaMagic/c_light
  else
    call read_phase_space(lun, nmuons, muons, 'Distribution off target', tot)
  endif
  close(unit=lun)
  if (tot < nmuons) then
    print '(a,a,i10,a)',' The number of muons in file ',trim(muon_file), tot,' is less than the number required '
    stop
  endif
endif

! Propagate the longitudinal coordinates from the target to the ring with relativistic
! kinematics: betagamma = (1+vec(6))*pMagic/mmu and cosh(asinh(betagamma))**2 = gamma**2.
do i = 1, nmuons
  betagamma = (1 + muons(i)%coord%vec(6))*pMagic/mmu
  muons(i)%coord%t = muons(i)%coord%t - s_target/(1 + muons(i)%coord%vec(6)/cosh(asinh(betagamma))**2)/betaMagic/c_light
  muons(i)%coord%vec(5) = tanh(asinh(betagamma)) * c_light * muons(i)%coord%t
enddo

muons(:)%coord%state = alive$

! A single muon is placed at the origin of phase space.
if (nmuons == 1) then
  muons(1)%Jx        = epsx
  muons(1)%Jy        = epsy
  muons(1)%coord%vec = 0
  muons(1)%coord%s   = 0
  muons(1)%coord%t   = 0
  return
endif

! Twiss parameters of the raw distribution (at the target, or as read from a file)
call compute_emittance_beta(nmuons,muons, twiss1, epsx,epsy,averages)
call compute_beam_params(muons, 1, 'RawDistributionAtStart')
twiss1%phix = 0
twiss1%phiy = 0
muons_raw = muons   ! kept in case twiss2 turns out to be unusable (see below)

WRITE(*,'(a)') 'AT TARGET or raw distribution:'
WRITE(*,'(5(a15,es12.5))') 'epsx =', epsx, 'sigma_x =', sqrt(averages%x2_average), &
     'sigma_px =', sqrt(averages%px2_average), 'sigma_xpx =', averages%xpx_average
WRITE(*,'(4(a15,es12.5))') 'epsy =', epsy, 'sigma_y =', sqrt(averages%y2_average), &
     'sigma_py =', sqrt(averages%py2_average), 'sigma_ypy =', averages%ypy_average
WRITE(*,'(3(a15,es12.5))') 'sigma_pz =', averages%deltapz,'xpz_average =',averages%xpz_average, &
     'pxpz_average =',averages%pxpz_average
WRITE(*,*)
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'beta_x', 'alpha_x', twiss1%betax,  twiss1%alphax
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'alpha_x','gamma_x', twiss1%alphax, twiss1%gammax
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'eta_x','eta_px', twiss1%etax, twiss1%etapx
WRITE(*,*)
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'beta_y', 'alpha_y', twiss1%betay,  twiss1%alphay
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'alpha_y','gamma_y', twiss1%alphay, twiss1%gammay

WRITE(*,'(a)') 'AT INFLECTOR EXIT:'
WRITE(*,*)
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'beta_x', 'alpha_x', twiss2%betax,  twiss2%alphax
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'alpha_x','gamma_x', twiss2%alphax, twiss2%gammax
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'eta_x','eta_px', twiss2%etax, twiss2%etapx
WRITE(*,*)
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'beta_y', 'alpha_y', twiss2%betay,  twiss2%alphay
WRITE(*,'(7x,2a9,7x,es12.5,1x,es12.5)') 'alpha_y','gamma_y', twiss2%alphay, twiss2%gammay

! Linear map from the Twiss parameters of the raw distribution to twiss2 (the values
! specified in the input file, at the inflector midpoint or end according to twiss_ref).
call make_transfer(twiss1,twiss2,T)

do i = 1, nmuons
  muons(i)%Jx = epsx
  muons(i)%Jy = epsy
  vec = muons(i)%coord%vec
  muons(i)%coord%vec = matmul(T,vec)
enddo

! Drift from the twiss2 reference point to the inflector exit (the downstream half of the
! inflector is treated as a drift): zero length for 'end', inf_length/2 for 'center'.
T = 0
do i = 1, 6
  T(i,i) = 1
enddo
if (index(twiss_ref,'end') /= 0) then
  ! T(1,2) and T(3,4) stay 0 (dlr 20141009: previously inf_length/2)
  print '(a,1x,a)', ' Twiss_ref = ',twiss_ref
elseif (index(twiss_ref,'center') /= 0) then
  T(1,2) = inf_length/2
  T(3,4) = inf_length/2
  print '(a,1x,a)', ' Twiss_ref = ',twiss_ref
else
  print '(a)',' You must specify twiss parameters reference point, (twiss_ref) in index file'
  stop
endif

do i = 1, nmuons
  muons(i)%coord%vec = matmul(T,muons(i)%coord%vec)
enddo

call compute_beam_params(muons, 1, 'InfExitNoScatter')
! Distribution at the inflector exit assuming no scattering or apertures
call write_phase_space(nmuons, muons, 'InflectorExitNoScatter', tot)

! Propagate the twiss parameters through the same drift
call prop_phase_space(T, twiss2, twiss_inf)

WRITE(*,*)
WRITE(*,*) 'AT INFLECTOR MIDPOINT:'
WRITE(*,'(4(a15,es12.5))') 'betax =', twiss2%betax, 'alphax =', twiss2%alphax, 'etax =', twiss2%etax, 'etapx =', twiss2%etapx
WRITE(*,'(4(a15,es12.5))') 'betay =', twiss2%betay, 'alphay =', twiss2%alphay, 'etay =', twiss2%etay, 'etapy =', twiss2%etapy
WRITE(*,*)
WRITE(*,*) 'AT INFLECTOR EXIT:'
WRITE(*,'(4(a15,es12.5))') 'betax =', twiss_inf%betax, 'alphax =', twiss_inf%alphax, &
     'etax =', twiss_inf%etax, 'etapx =', twiss_inf%etapx
WRITE(*,'(4(a15,es12.5))') 'betay =', twiss_inf%betay, 'alphay =', twiss_inf%alphay, &
     'etay =', twiss_inf%etay, 'etapy =', twiss_inf%etapy

! If mat_inv is given, use it to propagate the distribution back to the start of the injection line.
if (present(mat_inv)) then
  do i = 1, nmuons
    muons(i)%coord%vec = matmul(mat_inv,muons(i)%coord%vec)
  enddo
endif

! A non-positive beta in twiss2 means "use the distribution as is" (e.g. Diktys' distribution).
if (twiss2%betax <= 0 .or. twiss2%betay <= 0) then
  muons = muons_raw
endif

! Distribution propagated to the inflector exit and then back with mat_inv to the start of the injection line
call write_phase_space(nmuons,muons,'DistStartInjLine', tot)
call compute_beam_params(muons,1,'DistStartInjLine')
return
END SUBROUTINE create_phase_space


subroutine makeGR(beta,alpha,phi,G,Ginv,R)
!
! Build, for one transverse plane, the 2x2 Twiss normalisation matrix G, its
! inverse Ginv and the rotation matrix R(phi):
!   G    = [[ 1/sqrt(beta),      0           ], [ -alpha/sqrt(beta), -sqrt(beta)   ]]
!   Ginv = [[ sqrt(beta),        0           ], [ -alpha/sqrt(beta), -1/sqrt(beta) ]]
!   R    = [[ cos(phi),          sin(phi)    ], [ -sin(phi),          cos(phi)     ]]
! beta must be positive (sqrt(beta) is taken).
!
  use precision_def
  implicit none
  real(rp) beta, alpha, phi, G(2,2), Ginv(2,2), R(2,2)
  real(rp) sqrt_beta, cos_phi, sin_phi

  sqrt_beta = sqrt(beta)
  cos_phi   = cos(phi)
  sin_phi   = sin(phi)

  ! Column-major fill: (1,1), (2,1), (1,2), (2,2)
  G    = reshape([1/sqrt_beta, -alpha/sqrt_beta, 0._rp, -sqrt_beta  ], [2,2])
  Ginv = reshape([sqrt_beta,   -alpha/sqrt_beta, 0._rp, -1/sqrt_beta], [2,2])
  R    = reshape([cos_phi, -sin_phi, sin_phi, cos_phi], [2,2])

  return
end

subroutine make_m(eta1,eta0,etap1,etap0,MM,m)
!
! Build the 2x2 block that fills the (x, x') rows x (vec(5), vec(6)) columns of a
! 6x6 transfer matrix, so that a linear map whose transverse block is MM carries
! the dispersion (eta0, etap0) into (eta1, etap1):
!     [eta1, etap1]^T = MM [eta0, etap0]^T + D   =>   D = [eta1, etap1]^T - MM [eta0, etap0]^T
! D is returned in column 2 (the vec(6) column); column 1 (vec(5)) is zero.
!
! Fix: the previous version returned [eta1, etap1] + MM [eta0, etap0], which gives
! the requested output dispersion only when the input dispersion is zero.
!
use precision_def

implicit none

real(rp) eta1, eta0, etap1, etap0
real(rp) MM(2,2), m(2,2)

m = 0
m(:,2) = [eta1, etap1] - matmul(MM, [eta0, etap0])

return
end


subroutine make_transfer(twiss0,twiss1,T)
!
! Build the 6x6 linear map T that takes a beam described by twiss0 to one
! described by twiss1.  Each transverse 2x2 block is
!     Ginv(twiss1) * R(phi1 - phi0) * G(twiss0)        (see makeGR)
! and the dispersion columns T(1:2,5:6), T(3:4,5:6) come from make_m.
! All other entries are those of the 6x6 identity.
!
use precision_def
use muon_mod
implicit none

real(rp) T(6,6)
type (g2twiss_struct) twiss0, twiss1
real(rp) g_in(2,2), g_in_inv(2,2), g_out(2,2), g_out_inv(2,2), rot_mat(2,2)
real(rp) lin_blk(2,2), disp_blk(2,2)
integer k

T = 0
do k = 1, 6
  T(k,k) = 1
end do

! Horizontal plane: normalise with twiss0, rotate, de-normalise with twiss1
call makeGR(twiss0%betax, twiss0%alphax, twiss1%phix-twiss0%phix, g_in,  g_in_inv,  rot_mat)
call makeGR(twiss1%betax, twiss1%alphax, twiss1%phix-twiss0%phix, g_out, g_out_inv, rot_mat)
T(1:2,1:2) = matmul(g_out_inv, matmul(rot_mat, g_in))

! Vertical plane
call makeGR(twiss0%betay, twiss0%alphay, twiss1%phiy - twiss0%phiy, g_in,  g_in_inv,  rot_mat)
call makeGR(twiss1%betay, twiss1%alphay, twiss1%phiy - twiss0%phiy, g_out, g_out_inv, rot_mat)
T(3:4,3:4) = matmul(g_out_inv, matmul(rot_mat, g_in))

! Dispersion columns
lin_blk = T(1:2,1:2)
call make_m(twiss1%etax, twiss0%etax, twiss1%etapx, twiss0%etapx, lin_blk, disp_blk)
T(1:2,5:6) = disp_blk

lin_blk = T(3:4,3:4)
call make_m(twiss1%etay, twiss0%etay, twiss1%etapy, twiss0%etapy, lin_blk, disp_blk)
T(3:4,5:6) = disp_blk

return
end

subroutine write_phase_space(nmuons, muons, string, tot)
!
! Write every live muon (coord%state == alive$) among muons(1:nmuons) to the file
! trim(string)//'_phase_space.dat'.  The file starts with a '!' header line,
! followed by one line per muon: its index, Jx, Jy, coord%vec(1:6) and coord%t
! (format i10,1x,9es12.4, the layout read back by read_phase_space).
! tot returns the number of muons written.
!
use muon_mod
implicit none
integer  nmuons
integer i, tot
integer lun
character*(*) string
type (muon_struct), allocatable :: muons(:)
character(len=:), allocatable :: dat_file

dat_file = trim(string)//'_phase_space.dat'

lun = lunget()
open(unit = lun, file = dat_file)
write(lun,'(a,a10,1x,9a12)') '!', 'muon', 'Jx', 'Jy', 'x', 'xp', 'y', 'yp', 'vec(5)', 'pz', 't'

tot = 0
do i = 1, nmuons
  if (muons(i)%coord%state /= alive$) cycle
  write(lun,'(i10,1x,9es12.4)') i, muons(i)%Jx, muons(i)%Jy, muons(i)%coord%vec(1:6), muons(i)%coord%t
  tot = tot + 1
end do

close(unit = lun)

print '(/,a,1x,i10,1x,a,1x,a)', 'Phase space vector for ', tot, 'particles written to:', dat_file
print '(6a)', ' Use <plotting_script/hist_any_phase_space_file.gnu> to plot'
return
end

subroutine read_phase_space(unit, nmuons, muons, string, tot)
!
! Read up to nmuons phase-space records from the already-open unit, in the layout
! produced by write_phase_space: i, Jx, Jy, coord%vec(1:6), coord%t
! (format i10,1x,9es12.4).  The leading integer i selects which element of muons
! is filled.  Lines with '!' in their first three columns are comments.
!
!   unit   -- open, readable unit
!   nmuons -- maximum number of records to read
!   muons  -- destination array (must already be allocated)
!   string -- description (not used)
!   tot    -- output: number of records actually read (0..nmuons)
!
! Reading stops at end of file (or on a read error on the line read itself);
! the caller compares tot with nmuons.
!
! Fixes: tot used to start at 1 and therefore reported one record more than was
! read, hiding a one-record shortfall from the caller.  Records whose index lies
! outside the bounds of muons are now skipped instead of being written out of
! bounds.
!
use muon_mod
implicit none
integer unit, nmuons
integer i, tot
character*(*) string
type (muon_struct), allocatable :: muons(:)
character*120 line
integer read_status

tot = 0
do while (tot < nmuons)
  read(unit, '(a)', iostat=read_status) line
  if (read_status /= 0) exit
  if (index(line(1:3),'!') /= 0) cycle

  ! Validate the record index before using it as a subscript
  read(line, '(i10)', iostat=read_status) i
  if (read_status /= 0 .or. i < lbound(muons,1) .or. i > ubound(muons,1)) then
    print '(a,a)', ' read_phase_space: skipping record with bad index: ', trim(line)
    cycle
  endif

  read(line,'(i10,1x,9es12.4)') i, muons(i)%Jx, muons(i)%Jy, muons(i)%coord%vec(1:6), muons(i)%coord%t
  tot = tot + 1
end do

print '(/,a,1x,i10,1x,a,1x,i10)','Phase space vector for ',tot,' particles read from unit ',unit
return
end

subroutine read_Vmuons(unit, nmuons, muons,toff, string, tot)
!
! Read up to nmuons muons from Volodya's distribution file on the already-open
! unit.  Lines with '!' in their first three columns are comments.  Each data
! line holds nine free-format numbers:
!   vec(1:5) -- stored in coord%vec(1:5) after division by 1000
!               (presumably mm/mrad -> m/rad; TODO confirm)
!   vec(6:8) -- momentum components; coord%vec(6) = (|p| - p_ref)/p_ref with
!               p_ref = 3094.35 (presumably MeV/c)
!   vec(9)   -- time, multiplied by 1e-9 (ns -> s)
! After reading, the times are shifted so that their mean is zero.
! toff and string are not used (the "+ toff" shift was disabled in the original).
! tot returns the number of muons read.
!
! Fixes: the time centring was skipped when the file ended early, nothing
! guarded against tot = 0 (division by zero), and the reference momentum and the
! 1e-9 factor were single-precision literals.
!
use muon_mod
implicit none
integer unit, nmuons
integer i, tot
character*(*) string
type (muon_struct), allocatable :: muons(:)
real(rp) toff
real(rp), parameter :: p_ref_mev = 3094.35_rp
character*256 line
real(rp) vec(1:9), time_sum, ptot
integer read_status

time_sum = 0
tot = 0
do while (tot < nmuons)
  read(unit, '(a)', iostat=read_status) line
  if (read_status /= 0) exit
  if (index(line(1:3),'!') /= 0) cycle
  read(line,*) vec(1:9)

  i = tot + 1
  muons(i)%coord%vec(1:5) = vec(1:5)/1000
  muons(i)%coord%t = vec(9) * 1.e-9_rp
  ptot = sqrt(vec(6)**2 + vec(7)**2 + vec(8)**2)
  muons(i)%coord%vec(6) = (ptot - p_ref_mev)/p_ref_mev
  time_sum = time_sum + vec(9)
  tot = i
end do

if (tot > 0) then
  print *,' time/tot = ', time_sum/tot
  ! Centre the bunch in time (toff is deliberately not added)
  muons(1:tot)%coord%t = muons(1:tot)%coord%t - time_sum/tot * 1.e-9_rp
endif

print '(/,a,1x,i10,1x,a,1x,i10)','Phase space vector for ',tot,' particles read from Volodyas file unit ',unit
return
end

subroutine read_Dmuons(unit, nmuons, muons,toff, string, tot)
!
! Read up to nmuons muons from Diktys' M4M5 simulation output (6 Apr 2016) on
! the already-open unit.  Lines with '#' in their first three columns are
! comments.  Each data line holds seven free-format numbers:
!   vec(1:3) -- x, y, z; coord%vec(1), vec(3), vec(5) = value/1000
!               (presumably mm -> m; TODO confirm)
!   vec(4:6) -- px, py, pz; coord%vec(2) = px/p_ref, coord%vec(4) = py/p_ref,
!               coord%vec(6) = (|p| - p_ref)/p_ref with p_ref = 3094.35 (presumably MeV/c)
!   vec(7)   -- time, multiplied by 1e-9 (ns -> s); records with vec(7) > 9889 are skipped
! After reading, the times are shifted so that their mean is zero, and the
! first min(100, tot) muons are dumped to 'Diktys_phase_space.dat'.
! toff and string are not used.  tot returns the number of muons read.
!
! Fixes: a skipped (late) record used to advance the array index, leaving an
! uninitialised muon and eventually writing past the end of muons; the dump
! file was never closed and always wrote 100 entries regardless of how many
! had been read; a short file skipped the time centring.
!
use muon_mod
implicit none
integer unit, nmuons
integer tot
character*(*) string
type (muon_struct), allocatable :: muons(:)
real(rp) toff
real(rp), parameter :: p_ref_mev = 3094.35_rp
real(rp), parameter :: t_cut_ns = 9889
character*256 line
real(rp) vec(1:7), time_sum, ptot
integer i, ii, lun, read_status

time_sum = 0
tot = 0
do while (tot < nmuons)
  read(unit, '(a)', iostat=read_status) line
  if (read_status /= 0) exit
  if (index(line(1:3),'#') /= 0) cycle
  read(line,*) vec(1:7)
  if (vec(7) > t_cut_ns) cycle

  ! Index of the next muon is assigned only once the record is accepted
  i = tot + 1
  muons(i)%coord%vec(1) = vec(1)/1000
  muons(i)%coord%vec(2) = vec(4)/p_ref_mev
  muons(i)%coord%vec(3) = vec(2)/1000
  muons(i)%coord%vec(4) = vec(5)/p_ref_mev
  muons(i)%coord%vec(5) = vec(3)/1000
  muons(i)%coord%t      = vec(7) * 1.e-9_rp
  ptot = sqrt(vec(4)**2 + vec(5)**2 + vec(6)**2)
  muons(i)%coord%vec(6) = (ptot - p_ref_mev)/p_ref_mev
  time_sum = time_sum + vec(7)
  tot = i
end do

if (tot > 0) then
  print *,' time/tot = ', time_sum/tot
  ! Centre the bunch in time (toff is deliberately not added)
  muons(1:tot)%coord%t = muons(1:tot)%coord%t - time_sum/tot * 1.e-9_rp
endif

! Dump the first few muons for inspection
lun = lunget()
open(unit = lun, file='Diktys_phase_space.dat')
write(lun,'(a,a10,1x,7a12)')'!', 'muon', 'x','xp','y','yp','vec(5)','pz','t'
do ii = 1, min(100, tot)
  write(lun,'(i10,1x,7es12.4)') ii, muons(ii)%coord%vec(1:6), muons(ii)%coord%t
end do
close(unit = lun)

print '(/,a,1x,i10,1x,a,1x,i10)','Phase space vector for ',tot,' particles read from Diktys file unit ',unit
return
end


subroutine inflector_end_scatter(nmuons, muons)
!
! Add a random multiple-scattering kick to the transverse angles coord%vec(2)
! and coord%vec(4) of muons(1:nmuons).  For each muon the RMS angle is
!   theta0 = ((13.6/momentumunit_MeVperc)/magicmomentum) * sqrt(S)
! where S sums, over the Al, Cu, Nb (half of the NbTi thickness) and Ti (other
! half) layers, the Highland-style term L*(1 + 0.038*ln L)**2 with
! L = f*thickness/radlength and f = 1 + sqrt(vec(2)**2 + vec(4)**2).
! Two independent gaussian deviates (ran_gauss) scale the x and y kicks.
! No other coordinate is modified.
!
  use muon_mod
  use parameters_bmad
  use nr
  implicit none
  integer nmuons
  type (muon_struct), allocatable :: muons(:)
  integer k
  real(rp) path_factor, x0_sum, theta0, kick_x, kick_y

  do k = 1, nmuons
    associate (v => muons(k)%coord%vec)
      ! Longer path through the layers for muons with a transverse angle
      path_factor = (sqrt(v(2)**2 + v(4)**2) + 1.)

      x0_sum = highland_term(path_factor*thickness_al/radlength_al) + &
               highland_term(path_factor*thickness_cu/radlength_cu) + &
               highland_term(path_factor*thickness_nbti*0.5/radlength_nb) + &
               highland_term(path_factor*thickness_nbti*0.5/radlength_ti)

      theta0 = ((13.6/momentumunit_MeVperc)/magicmomentum) * sqrt(x0_sum)

      call ran_gauss(kick_x)
      call ran_gauss(kick_y)
      v(2) = v(2) + theta0*kick_x
      v(4) = v(4) + theta0*kick_y
    end associate
  end do
  return

contains

  ! One layer's contribution, ratio*(1 + 0.038*ln(ratio))**2, for ratio = path/radiation length.
  function highland_term(ratio) result(term)
    real(rp), intent(in) :: ratio
    real(rp) term
    term = ratio*(1+0.038*log(ratio))**2
  end function highland_term

end subroutine inflector_end_scatter
! approximation: leave tangential momentum unchanged, no energy loss

subroutine prop_phase_space(T, twiss1, twiss2)
!
! Propagate Twiss parameters and dispersion from twiss1 to twiss2 through the
! 6x6 linear map T.  For each transverse plane with 2x2 block M (assumed to have
! det M = 1, so its inverse is the adjugate):
!   N1 = [[gamma1, alpha1], [alpha1, beta1]]       (inverse beam-ellipse matrix)
!   N2 = M^{-T} N1 M^{-1}  ->  gamma2 = N2(1,1), alpha2 = N2(1,2), beta2 = N2(2,2)
!   [eta2, etap2] = M [eta1, etap1] + T(rows of M, 6)
! twiss2%phix and twiss2%phiy are not assigned.
!
! Fixes: the vertical dispersion was never computed (the horizontal formulas
! were repeated instead), and gammax/gammay were not updated.  The dispersion
! now also includes the T(:,6) source term, which is zero for a pure drift.
!
use parameters_bmad
use muon_mod
use muon_interface

implicit none

type (g2twiss_struct) twiss1, twiss2
real(rp) T(6,6)

call prop_phase_space_plane(T(1:2,1:2), T(1:2,6), twiss1%betax, twiss1%alphax, twiss1%etax, twiss1%etapx, &
                            twiss2%betax, twiss2%alphax, twiss2%gammax, twiss2%etax, twiss2%etapx)
call prop_phase_space_plane(T(3:4,3:4), T(3:4,6), twiss1%betay, twiss1%alphay, twiss1%etay, twiss1%etapy, &
                            twiss2%betay, twiss2%alphay, twiss2%gammay, twiss2%etay, twiss2%etapy)
return

contains

  ! Propagate one plane's (beta, alpha, eta, eta') through the 2x2 block M with dispersion source d.
  subroutine prop_phase_space_plane(M, d, beta1, alpha1, eta1, etap1, beta2, alpha2, gamma2, eta2, etap2)
    real(rp), intent(in)  :: M(2,2), d(2), beta1, alpha1, eta1, etap1
    real(rp), intent(out) :: beta2, alpha2, gamma2, eta2, etap2
    real(rp) N(2,2), M_inv(2,2), A(2,2)

    N(1,1:2) = [(1+alpha1**2)/beta1, alpha1]
    N(2,1:2) = [alpha1, beta1]

    M_inv(1,1:2) = [M(2,2), -M(1,2)]
    M_inv(2,1:2) = [-M(2,1), M(1,1)]
    A = matmul(transpose(M_inv), matmul(N, M_inv))

    gamma2 = A(1,1)
    alpha2 = A(1,2)
    beta2  = A(2,2)

    eta2  = M(1,1)*eta1 + M(1,2)*etap1 + d(1)
    etap2 = M(2,1)*eta1 + M(2,2)*etap1 + d(2)
  end subroutine prop_phase_space_plane

end subroutine prop_phase_space


SUBROUTINE E989InflectorAperture(x,y,within,us,ds)
  ! Test whether (x, y) lies strictly inside the E989 elliptical inflector
  ! aperture: semi-axes inflector_width (x) and inflector_height (y), centred at
  ! x = 0.009.  When us is true the point is additionally shifted by 1.7*tilt in x.
  ! within is set to .true. when the point is inside and is left UNCHANGED
  ! otherwise, so the caller must initialise it.  ds is accepted but not used.
  USE parameters_bmad
  IMPLICIT NONE
  REAL(rp), INTENT(IN) :: x,y
  logical us,ds
  LOGICAL, INTENT(INOUT) :: within
  REAL(rp) :: semi_x, semi_y, x_shift, x_rel

  semi_x = inflector_width
  semi_y = inflector_height   ! same as E821

  ! Position relative to the aperture centre (the inner aperture does not move)
  x_rel = x - 0.009
  IF (us) THEN
    x_shift = 1.7 *tilt
    x_rel = x_rel + x_shift
  END IF

  IF (x_rel**2/semi_x**2 + y**2/semi_y**2 < 1.) within = .true.

  RETURN
END SUBROUTINE E989InflectorAperture


SUBROUTINE E821InflectorAperture(x,y,within,us,ds)
  ! E821 "D"-shaped aperture: the box |x| <= 0.009, |y| <= 0.028, with the region
  ! x > 0.002 further limited by |y| <= slope*x + intercept, the line through
  ! (0.002, 0.028) and (0.009, 0.016).  When us is true x is shifted by 1.7*tilt
  ! first.  within is always assigned.  ds is accepted but not used.
  USE parameters_bmad
  IMPLICIT NONE
  REAL(rp), INTENT(IN) :: x,y
  LOGICAL, INTENT(INOUT) :: within
  logical  us,ds
  REAL(rp) :: slope, intercept
  real(rp) x_shift, x_rel
  logical outside_box, outside_slant

  slope = (0.016-0.028)/(0.009-0.002)
  intercept = 0.028 - slope*0.002

  x_rel = x
  if (us) then
    x_shift = 1.7 *tilt
    x_rel = x_rel + x_shift
  endif

  outside_box   = abs(x_rel) > 0.009 .or. abs(y) > 0.028
  outside_slant = x_rel > 0.002 .and. abs(y) > (slope*x_rel + intercept)
  within = .not. (outside_box .or. outside_slant)

  RETURN
END SUBROUTINE E821InflectorAperture


SUBROUTINE RectangularInflectorAperture(x,y,within)
  ! Rectangular aperture: within is .true. exactly when |x| < inflector_width
  ! and |y| < inflector_height (both used as half-widths, strict inequalities).
  USE parameters_bmad
  IMPLICIT NONE
  REAL(rp), INTENT(IN) :: x,y
  LOGICAL, INTENT(INOUT) :: within

  within = .false.
  IF (abs(x) < inflector_width) THEN
    IF (abs(y) < inflector_height) within = .true.
  END IF

  RETURN
END SUBROUTINE RectangularInflectorAperture


FUNCTION fnalw(center,width)
  ! Draw one random sample from the tabulated FNAL "W" profile.
  ! fnalw_integral(1:nbins) (from parameters_bmad) is treated as the cumulative
  ! distribution on bins 1..nbins (assumed non-decreasing - TODO confirm).  A
  ! uniform deviate is located in the table, converted to a fractional bin
  ! index by linear interpolation, and mapped to center + (i/nbins - 0.5)*width.
  !
  ! Fix: if the deviate fell outside [fnalw_integral(1), fnalw_integral(nbins)),
  ! no bin matched and the function returned an undefined value.  The fractional
  ! index now defaults to the nearest table end (1 below the table, nbins above).
  USE parameters_bmad ! fnalw_integral table
  IMPLICIT NONE

  REAL(rp)             :: fnalw                       ! quantity to determine
  REAL(rp), INTENT(IN) :: center, width               ! center and width of distribution [a.u.]
  INTEGER,  PARAMETER  :: nbins=100                   ! number of bins in fnalw_integral table
  REAL(rp) rand, di, frac_low, frac_high, i_interp    ! helper variables
  INTEGER  i                                          ! iterator

  ! Generate a random number in [0,1)
  call ran_uniform(rand)

  ! Fallback when the deviate lies outside the tabulated range
  IF (rand < fnalw_integral(1)) THEN
    i_interp = 1
  ELSE
    i_interp = nbins
  ENDIF

  ! Find the random number in the fnalw_integral table (the strict inequality
  ! guarantees di > 0 in the matching bin)
  DO i=1, nbins-1
    IF (fnalw_integral(i)<=rand .and. rand<fnalw_integral(i+1)) THEN
      di = fnalw_integral(i+1) - fnalw_integral(i)        ! values in the fnalw_integral table (y-axis)
      frac_low  = ( fnalw_integral(i+1) - rand )/di       ! fraction of lower value to include (linear interpolation)
      frac_high = ( rand -  fnalw_integral(i)  )/di       ! fraction of upper value to include (linear interpolation)
      i_interp  = frac_low*i + frac_high*(i+1)            ! corresponding fractional index (x-axis)
      EXIT
    ENDIF
  ENDDO

  fnalw = center + (i_interp/nbins - 0.5) * width        ! map to "center" and "width"

RETURN
END FUNCTION fnalw

!************************************************************
! subroutine inflector_scatter(nmuons, muons, us, ds, energy_loss)
!  scatter in inflector ends
!********************************************

subroutine inflector_scatter(nmuons, muons, us, ds, energy_loss)
!
! Apply inflector-end scattering to muons(1:nmuons) by calling inflector_scatter1
! on each muon's coordinates: one pass when us is true (upstream end) and one
! pass when ds is true (downstream end).  us and ds are passed through to
! inflector_scatter1 unchanged.
!
! NOTE(review): the energy_loss argument is ignored; the module variable eloss
! decides whether energy loss is applied, as in the previous version.
! Fix: the previous version assigned eloss to the dummy energy_loss, silently
! overwriting the caller's variable (undefined behaviour if the caller passed a
! constant).  A local copy is used instead.
!
use parameters_bmad
use muon_mod
use muon_interface, dummy =>inflector_scatter

use materials_mod

IMPLICIT NONE

type (muon_struct), allocatable :: muons(:)
integer i,nmuons
logical us, ds, energy_loss
logical apply_eloss

apply_eloss = eloss

! Scattering and energy loss at the upstream inflector end
IF (us) THEN
  DO i=1,nmuons
   call inflector_scatter1(muons(i)%coord,us,ds,apply_eloss)
  ENDDO
ENDIF

! Scattering and energy loss at the downstream end of the inflector
IF (ds) THEN
  DO i=1,nmuons
   call inflector_scatter1(muons(i)%coord,us,ds,apply_eloss)
  ENDDO
ENDIF

return
end subroutine inflector_scatter

!*********************************************
!***********************************************
!***********************************************

subroutine check_inflector_aperture(nmuons,muons, inf_aperture, lost_at_inflector)
!
! Mark as lost$ every live muon in muons(1:nmuons) whose (x, y) = (vec(1), vec(3))
! falls outside the inflector aperture selected by inf_aperture:
!   'e989'/'E989'         -> E989InflectorAperture
!   'e821'/'E821'         -> E821InflectorAperture
!   'rect'/'Rect'/'RECT'  -> RectangularInflectorAperture
!   anything else         -> every muon is accepted
! The E989/E821 checks are always called with us = ds = .true.
! lost_at_inflector is incremented (never reset) for each newly lost muon.
!
use parameters_bmad
use muon_mod
use muon_interface, dummy => check_inflector_aperture

use materials_mod

IMPLICIT NONE

type (muon_struct), allocatable :: muons(:)
integer i,nmuons, lost_at_inflector
logical withinInflectorAperture, us,ds
character*16 inf_aperture
real(rp) vec(6)

us = .true.
ds = .true.

do i = 1, nmuons
  vec = muons(i)%coord%vec
  withinInflectorAperture = .false.   ! the E989 check only ever sets it to .true.

  select case (inf_aperture)
  case ('e989', 'E989')
    call E989InflectorAperture(vec(1), vec(3), withinInflectorAperture, us, ds)
  case ('e821', 'E821')
    call E821InflectorAperture(vec(1), vec(3), withinInflectorAperture, us, ds)
  case ('rect', 'Rect', 'RECT')
    call RectangularInflectorAperture(vec(1), vec(3), withinInflectorAperture)
  case default
    withinInflectorAperture = .true.
  end select

  ! Only a live muon outside the aperture is killed and counted
  if (withinInflectorAperture) cycle
  if (muons(i)%coord%state /= alive$) cycle
  muons(i)%coord%state = lost$
  lost_at_inflector = lost_at_inflector + 1
end do

return
end subroutine check_inflector_aperture

subroutine compute_emittance_beta(nmuons,muons, twiss1, epsx,epsy,averages)
!+
! Computes the second moments of a muon distribution, and the
! dispersion-corrected emittances and Twiss parameters derived from them.
!
! Moments are taken about zero; the centroid is not subtracted. Every one of
! the first nmuons entries of muons is included, whatever its %state.
!
! With pz = vec(6), the dispersion is eta = <x pz>/<pz^2> (likewise for px,
! y and py). The betatron sigma matrix is then, e.g.,
!   sigma11 = <x^2>  - <x pz>^2/<pz^2>
!   sigma12 = <x px> - <x pz><px pz>/<pz^2>
! If <pz^2> is zero (mono-energetic beam), the dispersion is set to zero and
! no correction is made, instead of dividing by zero.
!
! Input:
!   nmuons   -- Integer: number of entries of muons to use (assumed > 0).
!   muons(:) -- muon_struct: the distribution; only %coord%vec is read.
! Output:
!   twiss1     -- g2twiss_struct: beta/alpha/gamma for x and y, plus
!                 etax, etapx, etay, etapy.
!   epsx, epsy -- Real: dispersion-corrected emittances, sqrt(det sigma).
!   averages   -- averages_struct: the raw second moments <x^2>, <x px>,
!                 <x pz>, ... (each divided by nmuons exactly once).
!-
 use bmad
 use muon_mod
 implicit none
 type (muon_struct), allocatable :: muons(:)
 type (averages_struct) averages
 type (g2twiss_struct) twiss1
 real(rp) epsx, epsy
 integer nmuons

 ! Plain declarations: a /0/ initializer would give these implicit SAVE.
 real(rp) x2_average, y2_average, px2_average, py2_average, xpx_average, ypy_average
 real(rp) pz2_average, xpz_average, pxpz_average, ypz_average, pypz_average
 real(rp) sigma(6,6)
 integer i

 x2_average   = 0.0_rp
 y2_average   = 0.0_rp
 px2_average  = 0.0_rp
 py2_average  = 0.0_rp
 xpx_average  = 0.0_rp
 ypy_average  = 0.0_rp
 pz2_average  = 0.0_rp
 xpz_average  = 0.0_rp
 pxpz_average = 0.0_rp
 ypz_average  = 0.0_rp
 pypz_average = 0.0_rp

 do i = 1, nmuons
   x2_average   = x2_average   + muons(i)%coord%vec(1)**2
   y2_average   = y2_average   + muons(i)%coord%vec(3)**2
   px2_average  = px2_average  + muons(i)%coord%vec(2)**2
   py2_average  = py2_average  + muons(i)%coord%vec(4)**2
   xpx_average  = xpx_average  + muons(i)%coord%vec(1)*muons(i)%coord%vec(2)
   ypy_average  = ypy_average  + muons(i)%coord%vec(3)*muons(i)%coord%vec(4)
   pz2_average  = pz2_average  + muons(i)%coord%vec(6)**2
   xpz_average  = xpz_average  + muons(i)%coord%vec(1)*muons(i)%coord%vec(6)
   pxpz_average = pxpz_average + muons(i)%coord%vec(2)*muons(i)%coord%vec(6)
   ypz_average  = ypz_average  + muons(i)%coord%vec(3)*muons(i)%coord%vec(6)
   pypz_average = pypz_average + muons(i)%coord%vec(4)*muons(i)%coord%vec(6)
 end do

 x2_average   = x2_average  /nmuons
 y2_average   = y2_average  /nmuons
 px2_average  = px2_average /nmuons
 py2_average  = py2_average /nmuons
 xpx_average  = xpx_average /nmuons
 ypy_average  = ypy_average /nmuons
 pz2_average  = pz2_average /nmuons
 xpz_average  = xpz_average /nmuons
 pxpz_average = pxpz_average/nmuons
 ypz_average  = ypz_average /nmuons
 pypz_average = pypz_average/nmuons

 ! Dispersion from the once-normalized moments, so the numerator and the
 ! denominator carry the same 1/nmuons factor.
 if (pz2_average > 0.0_rp) then
   twiss1%etax  = xpz_average /pz2_average
   twiss1%etapx = pxpz_average/pz2_average
   twiss1%etay  = ypz_average /pz2_average
   twiss1%etapy = pypz_average/pz2_average
 else
   twiss1%etax  = 0.0_rp
   twiss1%etapx = 0.0_rp
   twiss1%etay  = 0.0_rp
   twiss1%etapy = 0.0_rp
 endif

 ! Betatron sigma matrix: subtract the dispersive (pz-correlated) part.
 ! For example, eta_x*<x pz> = <x pz>^2/<pz^2>.
 sigma(1,1) = x2_average  - twiss1%etax *xpz_average
 sigma(1,2) = xpx_average - twiss1%etax *pxpz_average
 sigma(2,2) = px2_average - twiss1%etapx*pxpz_average

 sigma(3,3) = y2_average  - twiss1%etay *ypz_average
 sigma(3,4) = ypy_average - twiss1%etay *pypz_average
 sigma(4,4) = py2_average - twiss1%etapy*pypz_average

 epsx = sqrt(sigma(1,1)*sigma(2,2) - sigma(1,2)**2)
 epsy = sqrt(sigma(3,3)*sigma(4,4) - sigma(3,4)**2)

 ! Store the raw moments; these are already averages, so divide no further.
 averages%x2_average   = x2_average
 averages%y2_average   = y2_average
 averages%px2_average  = px2_average
 averages%py2_average  = py2_average
 averages%xpx_average  = xpx_average
 averages%ypy_average  = ypy_average
 averages%pz2_average  = pz2_average
 averages%xpz_average  = xpz_average
 averages%pxpz_average = pxpz_average
 averages%ypz_average  = ypz_average
 averages%pypz_average = pypz_average

 ! Twiss parameters from sigma = eps * [[beta, -alpha], [-alpha, gamma]].
 twiss1%betax  =  sigma(1,1)/epsx
 twiss1%alphax = -sigma(1,2)/epsx
 twiss1%gammax =  sigma(2,2)/epsx
 twiss1%betay  =  sigma(3,3)/epsy
 twiss1%alphay = -sigma(3,4)/epsy
 twiss1%gammay =  sigma(4,4)/epsy

 return
end subroutine compute_emittance_beta

   subroutine inflector_scatter1(orb,us,ds,energy_loss)
   !+
   ! Passes one particle through the material stack at an inflector end.
   !
   ! The material is applied only if (us .and. us_scatter) .or. (ds .and. ds_scatter).
   ! us_scatter and ds_scatter are switches supplied by the used modules.
   !
   ! The stack is a 1.5 mm Al flange window, a 3.3 mm inflector coil and a
   ! 1.5 mm Al mandrel window. Multiple scattering is applied through all three
   ! layers first. Energy loss, if requested, is then applied for the same
   ! three layers.
   !
   ! Input:
   !   us, ds      -- Logical: particle is at the upstream / downstream end.
   !   energy_loss -- Logical: if true, also apply energy loss.
   ! Input/output:
   !   orb -- coord_struct: particle coordinates, updated in place.
   !-
   use bmad
   use materials_mod
   use parameters_bmad
   implicit none
   type (coord_struct) orb
   logical us, ds, energy_loss
   REAL(rp) surface_normal(3)/0.,0.,1./
   ! Layer thicknesses (m)
   real(rp), parameter :: inf_window_thickness = 0.0015_rp  ! each Al window (flange, mandrel)
   real(rp), parameter :: inf_coil_thickness   = 0.0033_rp  ! kapton-wrapped, Al-stabilized NbTi/Cu wires

   if ((us .and. us_scatter) .or. (ds .and. ds_scatter)) then
     call scatter(Al,                inf_window_thickness, orb, surface_normal) ! flange window
     call scatter(InflectorCoilE821, inf_coil_thickness,   orb, surface_normal) ! coil
     call scatter(Al,                inf_window_thickness, orb, surface_normal) ! mandrel window
     ! Honor the caller's flag. Previously the module switch eloss was read here
     ! and the energy_loss argument was ignored.
     if (energy_loss) then
       call energyLoss(Al,                inf_window_thickness, orb) ! flange window
       call energyLoss(InflectorCoilE821, inf_coil_thickness,   orb) ! coil
       call energyLoss(Al,                inf_window_thickness, orb) ! mandrel window
     endif
   endif
   return
   end subroutine inflector_scatter1
