Source code for thermo

import os

import numpy as np

import jax.numpy as jnp
import jax.lax as lax
from jax import grad, vmap, device_put, devices
import interpax

import linx.const as const 
from linx.special_funcs import Li, K1, K2
from linx.P_QED import explicit_P2, explicit_P3, dPdTQED_2,dPdTQED_3

###########################################
#                 Cosmology               #
############################################


def Hubble(rho_tot):
    """
    Hubble expansion rate in s^-1 for a given total energy density.

    Parameters
    ----------
    rho_tot : float or array
        Total energy density in MeV^4.

    Returns
    -------
    float or array
        Hubble rate; the division by const.hbar converts to s^-1.
    """
    # H^2 = 8 pi rho / (3 Mpl^2), evaluated in natural units first.
    H_squared = rho_tot * 8 * jnp.pi / (3 * const.Mpl**2)
    return H_squared**0.5 / const.hbar
def N_eff(rho_tot, rho_g):
    """
    Effective number of relativistic species, Neff.

    Parameters
    ----------
    rho_tot : float or array
        Total energy density of all fluids.
    rho_g : float or array
        Energy density of photons (same units as `rho_tot`).

    Returns
    -------
    float or array
        Dimensionless Neff computed from the non-photon energy density
        relative to the photon energy density.
    """
    prefactor = 8./7. * (11./4.)**(4./3.)
    return prefactor * (rho_tot - rho_g) / rho_g
def nB(a, eta_fac=1.):
    """
    Baryon number density in MeV^3 at scale factor `a`.

    Parameters
    ----------
    a : float or array
        Scale factor of interest.
    eta_fac : float, optional
        Multiplicative rescaling of the central value of the
        baryon-to-photon ratio (const.eta0). Defaults to 1.

    Returns
    -------
    float or array
        Baryon number density in MeV^3.
    """
    # Present-day baryon density (MeV^3), diluted as a^-3.
    n_baryon_today = eta_fac*const.n0CMB*const.eta0
    return n_baryon_today/a**3
#################################################################### # Generic Thermodynamic Variables # ####################################################################
def rho_massless_BE(T, mu, g):
    """
    Energy density of a massless particle with Bose-Einstein statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom (1 for scalar, 2 for vector, etc.).

    Returns
    -------
    float
        Energy density in MeV^4; 0 whenever T is not strictly positive.
    """
    def _rho_positive_T(temp):
        return g * 3 / jnp.pi**2 * Li(4, jnp.exp(mu/temp)) * temp**4

    def _rho_nonpositive_T(temp):
        # Guard branch: avoids using mu/T when T <= 0.
        return 0.

    return lax.cond(T > 0., _rho_positive_T, _rho_nonpositive_T, T)
def n_massless_BE(T, mu, g):
    """
    Number density of a massless particle with Bose-Einstein statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom (1 for scalar, 2 for vector, etc.).

    Returns
    -------
    float
        Number density in MeV^3; 0 whenever T is not strictly positive.
    """
    def _n_positive_T(temp):
        return g / jnp.pi**2 * Li(3, jnp.exp(mu/temp)) * temp**3

    def _n_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _n_positive_T, _n_nonpositive_T, T)
def p_massless_BE(T, mu, g):
    """
    Pressure of a massless particle with Bose-Einstein statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom (1 for scalar, 2 for vector, etc.).

    Returns
    -------
    float
        Pressure in MeV^4, computed as one third of rho_massless_BE.
    """
    energy_density = rho_massless_BE(T, mu, g)
    return energy_density / 3.
def rho_massless_FD(T, mu, g):
    """
    Energy density of a massless particle with Fermi-Dirac statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom (2 for Weyl fermion).

    Returns
    -------
    float
        Energy density in MeV^4; 0 whenever T is not strictly positive.
    """
    def _rho_positive_T(temp):
        return -g * 3 / jnp.pi**2 * Li(4, -jnp.exp(mu/temp)) * temp**4

    def _rho_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _rho_positive_T, _rho_nonpositive_T, T)
# T = jnp.where(jnp.isnan(T**4),0,T) # value = jnp.where(jnp.isnan(g * 45 / jnp.pi**2 * T**4),0, T**4) # epsilonbig = 1e12 # epsilonsmall = 1e-12 # # Compute T^4 safely, avoid potential underflow or NaN by ensuring T is always above epsilon # T_non_negative = jnp.maximum(T, epsilonsmall) # T_safer = lax.cond( # T_non_negative < epsilonbig, # lambda _: T_non_negative, # lambda _: 0.0, # None # ) # T_safe_pow4 = T_safer**4 # return g * 3 / jnp.pi**2 * T_safe_pow4 #* Li(4, -jnp.exp(mu/T))
def n_massless_FD(T, mu, g):
    """
    Number density of a massless particle with Fermi-Dirac statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom (2 for Weyl fermion).

    Returns
    -------
    float
        Number density in MeV^3; 0 whenever T is not strictly positive.
    """
    def _n_positive_T(temp):
        return -g / jnp.pi**2 * Li(3, -jnp.exp(mu/temp)) * temp**3

    def _n_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _n_positive_T, _n_nonpositive_T, T)
def p_massless_FD(T, mu, g):
    """
    Pressure of a massless particle with Fermi-Dirac statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom (2 for Weyl fermion).

    Returns
    -------
    float
        Pressure in MeV^4, computed as one third of rho_massless_FD.
    """
    energy_density = rho_massless_FD(T, mu, g)
    return energy_density / 3.
def rho_massless_MB(T, mu, g):
    """
    Energy density of a massless particle with Maxwell-Boltzmann
    statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom of the species.

    Returns
    -------
    float
        Energy density in MeV^4; 0 whenever T is not strictly positive.
    """
    def _rho_positive_T(temp):
        return g * 3 / jnp.pi**2 * jnp.exp(mu/temp) * temp**4

    def _rho_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _rho_positive_T, _rho_nonpositive_T, T)
def n_massless_MB(T, mu, g):
    """
    Number density of a massless particle with Maxwell-Boltzmann
    statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom of the species.

    Returns
    -------
    float
        Number density in MeV^3; 0 whenever T is not strictly positive.
    """
    def _n_positive_T(temp):
        return g / jnp.pi**2 * jnp.exp(mu/temp) * temp**3

    def _n_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _n_positive_T, _n_nonpositive_T, T)
def p_massless_MB(T, mu, g):
    """
    Pressure of a massless particle with Maxwell-Boltzmann statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    g : float
        Degrees of freedom of the species.

    Returns
    -------
    float
        Pressure in MeV^4, computed as one third of rho_massless_MB.
    """
    energy_density = rho_massless_MB(T, mu, g)
    return energy_density / 3.
# Parameters for series approximation of (massive) thermodynamic integrals. # Series method is much faster than integral computation. N_series_terms=20 # 20 Seems fine for mu < 0.7 T, << 1% error. rel_thres=30.
def rho_massive_BE(T, mu, m, g):
    """
    Series approximation for the energy density of a massive particle
    with Bose-Einstein statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom (1 for scalar, 3 for vector, etc.).

    Returns
    -------
    float
        Energy density in MeV^4. The massless expression is returned
        when T / m > rel_thres or T <= 0.
    """
    def _add_term(k, partial_sum):
        bessel_part = (
            (m / T)**3 / k * K1(k * m / T)
            + 3 * (m / T)**2 / k**2 * K2(k * m / T)
        )
        return partial_sum + jnp.exp(k * mu / T) * bessel_part

    use_massless = (T / m > rel_thres) | (T <= 0.)
    series_sum = lax.fori_loop(1, N_series_terms, _add_term, 0)
    massive_value = g / (2 * jnp.pi**2) * T**4 * series_sum
    return jnp.where(use_massless, rho_massless_BE(T, mu, g), massive_value)
def n_massive_BE(T, mu, m, g):
    """
    Series approximation for the number density of a massive particle
    with Bose-Einstein statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom (1 for scalar, 3 for vector, etc.).

    Returns
    -------
    float
        Number density in MeV^3. The massless expression is returned
        when T / m > rel_thres or T <= 0.
    """
    def _add_term(k, partial_sum):
        term = jnp.exp(k * mu / T) * (m / T)**2 / k * K2(k * m / T)
        return partial_sum + term

    use_massless = (T / m > rel_thres) | (T <= 0.)
    series_sum = lax.fori_loop(1, N_series_terms, _add_term, 0)
    massive_value = g / (2 * jnp.pi**2) * T**3 * series_sum
    return jnp.where(use_massless, n_massless_BE(T, mu, g), massive_value)
def p_massive_BE(T, mu, m, g):
    """
    Series approximation for the pressure of a massive particle with
    Bose-Einstein statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom (1 for scalar, 3 for vector, etc.).

    Returns
    -------
    float
        Pressure in MeV^4. The massless expression is returned when
        T / m > rel_thres or T <= 0.
    """
    def _add_term(k, partial_sum):
        term = jnp.exp(k * mu / T) * 3 * (m / T)**2 / k**2 * K2(k * m / T)
        return partial_sum + term

    use_massless = (T / m > rel_thres) | (T <= 0.)
    series_sum = lax.fori_loop(1, N_series_terms, _add_term, 0)
    massive_value = g / (6 * jnp.pi**2) * T**4 * series_sum
    return jnp.where(use_massless, p_massless_BE(T, mu, g), massive_value)
def rho_massive_FD(T, mu, m, g):
    """
    Series approximation for the energy density of a massive particle
    with Fermi-Dirac statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom (2 for Majorana Fermion, 4 for Dirac
        Fermion, etc.).

    Returns
    -------
    float
        Energy density in MeV^4. The massless expression is returned
        when T / m > rel_thres or T <= 0.
    """
    def _add_term(k, partial_sum):
        # Alternating sign distinguishes the FD series from the BE one.
        sign = (-1)**(k - 1)
        bessel_part = (
            (m / T)**3 / k * K1(k * m / T)
            + 3 * (m / T)**2 / k**2 * K2(k * m / T)
        )
        return partial_sum + sign * jnp.exp(k * mu / T) * bessel_part

    use_massless = (T / m > rel_thres) | (T <= 0.)
    series_sum = lax.fori_loop(1, N_series_terms, _add_term, 0)
    massive_value = g / (2 * jnp.pi**2) * T**4 * series_sum
    return jnp.where(use_massless, rho_massless_FD(T, mu, g), massive_value)
def n_massive_FD(T, mu, m, g):
    """
    Series approximation for the number density of a massive particle
    with Fermi-Dirac statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom (2 for Majorana Fermion, 4 for Dirac
        Fermion, etc.).

    Returns
    -------
    float
        Number density in MeV^3. The massless expression is returned
        when T / m > rel_thres or T <= 0.
    """
    def _add_term(k, partial_sum):
        term = (
            (-1)**(k - 1) * jnp.exp(k * mu / T)
            * (m / T)**2 / k * K2(k * m / T)
        )
        return partial_sum + term

    use_massless = (T / m > rel_thres) | (T <= 0.)
    series_sum = lax.fori_loop(1, N_series_terms, _add_term, 0)
    massive_value = g / (2 * jnp.pi**2) * T**3 * series_sum
    return jnp.where(use_massless, n_massless_FD(T, mu, g), massive_value)
def p_massive_FD(T, mu, m, g):
    """
    Series approximation for the pressure of a massive particle with
    Fermi-Dirac statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom (2 for Majorana Fermion, 4 for Dirac
        Fermion, etc.), the same convention as rho_massive_FD and
        n_massive_FD.

    Returns
    -------
    float
        Pressure in MeV^4. The massless expression is returned when
        T / m > rel_thres or T <= 0.
    """
    def _add_term(k, partial_sum):
        term = (
            (-1)**(k - 1) * jnp.exp(k * mu / T)
            * 3 * (m / T)**2 / k**2 * K2(k * m / T)
        )
        return partial_sum + term

    use_massless = (T / m > rel_thres) | (T <= 0.)
    series_sum = lax.fori_loop(1, N_series_terms, _add_term, 0)
    massive_value = g / (6 * jnp.pi**2) * T**4 * series_sum
    return jnp.where(use_massless, p_massless_FD(T, mu, g), massive_value)
def rho_massive_MB(T, mu, m, g):
    """
    Energy density of a massive particle with Maxwell-Boltzmann
    statistics, evaluated as a single Bessel-function expression
    (no series sum).

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom of the species.

    Returns
    -------
    float
        Energy density in MeV^4; 0 whenever T is not strictly positive.
    """
    def _rho_positive_T(temp):
        return g * m**2 * temp * jnp.exp(mu/temp) / (2 * jnp.pi**2) * (
            m * K1(m / temp) + 3 * temp * K2(m / temp)
        )

    def _rho_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _rho_positive_T, _rho_nonpositive_T, T)
def n_massive_MB(T, mu, m, g):
    """
    Number density of a massive particle with Maxwell-Boltzmann
    statistics, evaluated as a single Bessel-function expression
    (no series sum).

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species in MeV.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom of the species.

    Returns
    -------
    float
        Number density in MeV^3; 0 whenever T is not strictly positive.
    """
    def _n_positive_T(temp):
        return (
            g * m**2 * temp * jnp.exp(mu/temp) / (2 * jnp.pi**2)
            * K2(m / temp)
        )

    def _n_nonpositive_T(temp):
        return 0.

    return lax.cond(T > 0., _n_positive_T, _n_nonpositive_T, T)
def p_massive_MB(T, mu, m, g):
    """
    Pressure of a massive particle with Maxwell-Boltzmann statistics.

    Parameters
    ----------
    T : float
        Temperature of the species in MeV.
    mu : float
        Chemical potential of the species.
    m : float
        Mass of the particle in MeV.
    g : float
        Degrees of freedom of the species.

    Returns
    -------
    float
        Pressure in MeV^4, computed as n_massive_MB times T.
    """
    number_density = n_massive_MB(T, mu, m, g)
    return number_density * T
####################################################################
#               Electromagnetic Sector and Neutrinos               #
####################################################################

file_dir = os.path.dirname(__file__)
# All background tables live in this directory, next to this module.
_background_dir = file_dir+"/data/background/"

# QED corrections (tabulated assuming me = 0.511 MeV). Rows are flipped so
# that T is monotonically increasing, as interpax.interp1d requires.
P_QED_tab = np.flip(np.loadtxt(_background_dir+"QED_P_int.txt"), axis=0)
dPdT_QED_tab = np.flip(
    np.loadtxt(_background_dir+"QED_dP_intdT.txt"), axis=0
)
# The second-derivative table (QED_d2P_intdT2.txt) is not loaded;
# CG: JAX grad obviates this import.

# Effect of standard value of electron mass in scattering matrix elements
# (assume me = 0.511 MeV).
f_nue_scat_tab = np.loadtxt(_background_dir+"nue_scatt.txt")
f_numu_scat_tab = np.loadtxt(_background_dir+"numu_scatt.txt")

# Effect of standard value of electron mass in annihilation matrix elements
# (assume me = 0.511 MeV).
f_nue_ann_tab = np.loadtxt(_background_dir+"nue_ann.txt")
f_numu_ann_tab = np.loadtxt(_background_dir+"numu_ann.txt")

# Scattering coefficients provided by Miguel Escudero, Greg Jackson,
# Stefan Sandner and Mikko Laine, to appear; no assumption that
# me = 0.511 MeV.
f_coeffs = np.loadtxt(_background_dir+"MB_coefficients.txt")

# Move the tables to the first GPU when one is available.
try:
    gpus = devices('gpu')
    _first_gpu = gpus[0]
    P_QED_tab = device_put(P_QED_tab, device=_first_gpu)
    dPdT_QED_tab = device_put(dPdT_QED_tab, device=_first_gpu)
    f_nue_scat_tab = device_put(f_nue_scat_tab, device=_first_gpu)
    f_numu_scat_tab = device_put(f_numu_scat_tab, device=_first_gpu)
    f_nue_ann_tab = device_put(f_nue_ann_tab, device=_first_gpu)
    f_numu_ann_tab = device_put(f_numu_ann_tab, device=_first_gpu)
    f_coeffs = device_put(f_coeffs, device=_first_gpu)
except (RuntimeError, IndexError):
    # No GPU available or no GPU devices found - data stays on CPU
    pass

######################
# Standard EM Sector #
######################
def rho_EM_std(T_g, mu=0, me=const.me, LO=True, NLO=True):
    """
    Total energy density of EM-coupled SM fluids.

    Parameters
    ----------
    T_g : float
        Photon temperature in MeV.
    mu : float, optional
        Accepted for signature consistency only; not used. Defaults to 0.
    me : float, optional
        Electron mass in MeV. Defaults to const.me.
    LO : bool
        Include the leading-order QED correction. Defaults to True.
    NLO : bool
        Include the next-to-leading-order QED correction. Defaults to
        True.

    Returns
    -------
    float
        Units of MeV^4.
    """
    # Explicit QED correction, used when me differs from const.me.
    pressure_explicit = LO*explicit_P2(T_g, me) + NLO*explicit_P3(T_g, me)
    dPdT_explicit = LO*dPdTQED_2(T_g, me) + NLO*dPdTQED_3(T_g, me)
    corr_explicit = -pressure_explicit + T_g*dPdT_explicit

    # Pretabulated QED correction, used for the standard electron mass.
    pressure_tab = interpax.interp1d(
        T_g, P_QED_tab[:,0], LO*P_QED_tab[:,1]+NLO*P_QED_tab[:,2]
    )
    dPdT_tab = interpax.interp1d(
        T_g, dPdT_QED_tab[:,0],
        LO*dPdT_QED_tab[:,1]+NLO*dPdT_QED_tab[:,2]
    )
    corr_tabulated = -pressure_tab + T_g*dPdT_tab

    me_is_nonstandard = jnp.abs(me/const.me - 1) > 1e-8
    corr_QED = jnp.where(me_is_nonstandard, corr_explicit, corr_tabulated)

    photons = rho_massless_BE(T_g, 0., 2)
    electrons = rho_massive_FD(T_g, 0., me, 4)
    return photons + electrons + corr_QED
# Batched version of rho_EM_std: vmap with in_axes=0 maps the function over
# the leading axis of the positional arguments it is called with.
rho_EM_std_v = vmap(rho_EM_std, in_axes=0)
def p_EM_std(T_g, mu=0, me=const.me, LO=True, NLO=True):
    """
    Total pressure of EM-coupled SM fluids.

    Parameters
    ----------
    T_g : float
        Photon temperature in MeV.
    mu : float, optional
        Accepted for signature consistency only; not used. Defaults to 0.
    me : float, optional
        Electron mass in MeV. Defaults to const.me.
    LO : bool
        Include the leading-order QED correction. Defaults to True.
    NLO : bool
        Include the next-to-leading-order QED correction. Defaults to
        True.

    Returns
    -------
    float
        Units of MeV^4.
    """
    # Explicit QED correction, used when me differs from const.me.
    corr_explicit = LO*explicit_P2(T_g, me) + NLO*explicit_P3(T_g, me)

    # Pretabulated QED correction, used for the standard electron mass.
    corr_tabulated = interpax.interp1d(
        T_g, P_QED_tab[:,0], LO*P_QED_tab[:,1] + NLO*P_QED_tab[:,2]
    )

    me_is_nonstandard = jnp.abs(me/const.me - 1) > 1e-8
    corr_QED = jnp.where(me_is_nonstandard, corr_explicit, corr_tabulated)

    photons = p_massless_BE(T_g, 0., 2)
    electrons = p_massive_FD(T_g, 0., me, 4)
    return photons + electrons + corr_QED
# Batched version of p_EM_std: vmap with in_axes=0 maps the function over
# the leading axis of the positional arguments it is called with.
p_EM_std_v = vmap(p_EM_std, in_axes=0)
def rho_plus_p_EM_std(T_g, mu=0, me=const.me, LO=True, NLO=True):
    """
    Sum of energy density and pressure of all EM-coupled SM fluids.

    Parameters
    ----------
    T_g : float
        Photon temperature in MeV.
    mu : float, optional
        Accepted for signature consistency only; not used. Defaults to 0.
    me : float, optional
        Electron mass in MeV. Defaults to const.me.
    LO : bool
        Include the leading-order QED correction. Defaults to True.
    NLO : bool
        Include the next-to-leading-order QED correction. Defaults to
        True.

    Returns
    -------
    float
        Units of MeV^4.
    """
    # Explicit QED correction, used when me differs from const.me.
    corr_explicit = T_g*(
        LO*dPdTQED_2(T_g, me) + NLO*dPdTQED_3(T_g, me)
    )

    # Pretabulated QED correction, used for the standard electron mass.
    corr_tabulated = T_g * interpax.interp1d(
        T_g, dPdT_QED_tab[:,0],
        LO*dPdT_QED_tab[:,1] + NLO*dPdT_QED_tab[:,2]
    )

    me_is_nonstandard = jnp.abs(me/const.me - 1) > 1e-8
    corr_QED = jnp.where(me_is_nonstandard, corr_explicit, corr_tabulated)

    # Photons: rho + p = 4/3 rho for a massless species.
    photons = 4/3 * rho_massless_BE(T_g, 0., 2)
    electron_rho = rho_massive_FD(T_g, 0., me, 4)
    electron_p = p_massive_FD(T_g, 0., me, 4)
    return photons + electron_rho + electron_p + corr_QED
def T_g(rho_g):
    """
    Photon temperature corresponding to a photon energy density.

    Parameters
    ----------
    rho_g : float or array
        Energy density of photons.

    Returns
    -------
    float or array
        Same units as (rho_g)**0.25
    """
    T_to_the_fourth = 30*rho_g/(2*jnp.pi**2)
    return T_to_the_fourth**(1/4)
# Derivative of rho_EM_std with respect to its first argument (T_g),
# obtained by automatic differentiation.
drho_EM_dT_g_std = grad(rho_EM_std, argnums=0)

############################
# Standard Neutrino Sector #
############################
def rho_nue_std(T_nue, mu_nue=0.):
    """
    Total energy density of electron neutrinos.

    Parameters
    ----------
    T_nue : float
        Electron neutrino temperature in MeV.
    mu_nue : float, optional
        Chemical potential of electron neutrinos in MeV. Defaults to 0.

    Returns
    -------
    float
        Units of MeV^4.
    """
    dof = 2
    return rho_massless_FD(T_nue, mu_nue, dof)
def p_nue_std(T_nue, mu_nue=0.):
    """
    Total pressure of electron neutrinos.

    Parameters
    ----------
    T_nue : float
        Electron neutrino temperature in MeV.
    mu_nue : float, optional
        Chemical potential of electron neutrinos in MeV. Defaults to 0.

    Returns
    -------
    float
        Units of MeV^4; one third of the massless FD energy density.
    """
    energy_density = rho_massless_FD(T_nue, mu_nue, 2)
    return energy_density / 3
def n_nue_std(T_nue, mu_nue=0.):
    """
    Total number density of electron neutrinos.

    Parameters
    ----------
    T_nue : float
        Electron neutrino temperature in MeV.
    mu_nue : float, optional
        Chemical potential of electron neutrinos in MeV. Defaults to 0.

    Returns
    -------
    float
        Units of MeV^3.
    """
    dof = 2
    return n_massless_FD(T_nue, mu_nue, dof)
def rho_numt_std(T_numt, mu_numt=0.):
    """
    Total energy density of mu and tau neutrinos.

    Parameters
    ----------
    T_numt : float
        Mu, tau neutrino temperature in MeV.
    mu_numt : float, optional
        Chemical potential of mu, tau neutrinos in MeV. Defaults to 0.

    Returns
    -------
    float
        Units of MeV^4; twice the massless FD energy density with g = 2.
    """
    single_flavour = rho_massless_FD(T_numt, mu_numt, 2)
    return 2 * single_flavour
def p_numt_std(T_numt, mu_numt=0.):
    """
    Total pressure of mu and tau neutrinos.

    Parameters
    ----------
    T_numt : float
        Mu, tau neutrino temperature in MeV.
    mu_numt : float, optional
        Chemical potential of mu, tau neutrinos in MeV. Defaults to 0.

    Returns
    -------
    float
        Units of MeV^4.
    """
    single_flavour_rho = rho_massless_FD(T_numt, mu_numt, 2)
    return 2 * single_flavour_rho / 3
def n_numt_std(T_numt, mu_numt=0.):
    """
    Total number density of mu and tau neutrinos.

    Parameters
    ----------
    T_numt : float
        Mu, tau neutrino temperature in MeV.
    mu_numt : float, optional
        Chemical potential of mu, tau neutrinos in MeV. Defaults to 0.

    Returns
    -------
    float
        Units of MeV^3; twice the massless FD number density with g = 2.
    """
    single_flavour = n_massless_FD(T_numt, mu_numt, 2)
    return 2 * single_flavour
def T_nu(rho_nu):
    """
    Neutrino temperature corresponding to a neutrino energy density.

    Parameters
    ----------
    rho_nu : float or array
        Neutrino energy density in MeV^4.

    Returns
    -------
    float or array
        Units of MeV.
    """
    T_to_the_fourth = 8./7.*30*rho_nu/(2*jnp.pi**2)
    return T_to_the_fourth**(1/4)
# Automatic-differentiation derivatives of the neutrino energy and number
# densities. argnums=0 differentiates with respect to the temperature
# argument, argnums=1 with respect to the chemical potential argument.

# Electron neutrinos.
drho_nue_dT_nue_std = grad(rho_nue_std, argnums=0)
drho_nue_dmu_nue_std = grad(rho_nue_std, argnums=1)
dn_nue_dT_nue_std = grad(n_nue_std, argnums=0)
dn_nue_dmu_nue_std = grad(n_nue_std, argnums=1)

# Mu and tau neutrinos.
drho_numt_dT_numt_std = grad(rho_numt_std, argnums=0)
drho_numt_dmu_numt_std = grad(rho_numt_std, argnums=1)
dn_numt_dT_numt_std = grad(n_numt_std, argnums=0)
dn_numt_dmu_numt_std = grad(n_numt_std, argnums=1)
def collision_terms_std(
    T_g, T_nue, T_numt, me=const.me, mu_nue=0., mu_numt=0.,
    decoupled=False, use_FD=True, collision_me=True
):
    """
    Energy and number density transfer rate between EM and neutrino sector
    relevant for incomplete neutrino decoupling.

    Parameters
    ----------
    T_g : array_like
        Photon temperature in MeV.
    T_nue : array_like
        Electron neutrino temperature in MeV.
    T_numt : array_like
        Mu, tau neutrino temperature in MeV.
    me : float, optional
        Electron mass in MeV. Defaults to const.me.
    mu_nue : float, optional
        Chemical potential of electron neutrinos in MeV. Defaults to 0.
    mu_numt : float, optional
        Chemical potential of mu, tau neutrinos in MeV. Defaults to 0.
    decoupled : bool, optional
        Neutrinos are assumed to be completely decoupled if True; all four
        returned rates are then zero.
    use_FD : bool, optional
        Fermi-Dirac correction factors are applied to the number-density
        and neutrino-neutrino terms if True.
    collision_me : bool, optional
        If True, the EM-neutrino coefficients are interpolated from the
        `f_coeffs` table at me / T; if False they are all set to 1.

    Returns
    -------
    tuple
        (C_rho_nue, C_rho_numu, C_n_nue, C_n_numu) for the energy density
        transfer rate (in MeV^4/s) to nu_e and (nu_mu, nu_tau), followed by
        the number density transfer rate (in MeV^3/s) to nu_e and
        (nu_mu, nu_tau).
    """
    # Correction factors for the number (f_n), annihilation (f_a) and
    # scattering (f_s) rates; all zero when the neutrinos are decoupled.
    f_n, f_a, f_s = lax.cond(
        decoupled,
        lambda _: (0., 0., 0.),
        lambda _: lax.cond(
            use_FD,
            lambda _: (0.852, 0.884, 0.829),
            lambda _: (1., 1., 1.),
            0.
        ),
        0.
    )

    geL = const.geL
    geR = const.geR
    gmuL = const.gmuL
    gmuR = const.gmuR

    def ann_kernel(T_1, mu_1, T_2, mu_2):
        # Temperature/chemical-potential structure of the annihilation term.
        return (
            T_1**9 * jnp.exp(2 * mu_1 / T_1)
            - T_2**9 * jnp.exp(2 * mu_2 / T_2)
        )

    def scat_kernel(T_1, mu_1, T_2, mu_2):
        # Temperature/chemical-potential structure of the scattering term.
        return (
            jnp.exp(2 * mu_1 / T_1) * jnp.exp(2 * mu_2 / T_2)
            * T_1**4 * T_2**4 * (T_1 - T_2)
        )

    def G(T_1, mu_1, T_2, mu_2):
        # Neutrino-neutrino energy exchange; weighted by f_a and f_s.
        return (
            32 * f_a * ann_kernel(T_1, mu_1, T_2, mu_2)
            + 56 * f_s * scat_kernel(T_1, mu_1, T_2, mu_2)
        )

    def G_with_me(T_1, mu_1, T_2, mu_2, gL, gR):
        # EM-neutrino energy exchange for a flavour with couplings (gL, gR).
        # The same expression serves nu_e and nu_mu/nu_tau; only the
        # couplings differ.
        x = me / T_1

        def coeff(column):
            # Column 0 of f_coeffs is the abscissa (me / T_1); values are
            # clamped to the table ends outside its range.
            def from_table(tab):
                return jnp.interp(
                    x, tab[:, 0], tab[:, column],
                    left=tab[0, column], right=tab[-1, column]
                )
            return lax.cond(collision_me, from_table, lambda _: 1., f_coeffs)

        # Columns 1/2 weight annihilation, columns 5/6 weight scattering.
        f_ann_1 = coeff(1)
        f_ann_2 = coeff(2)
        f_scat_1 = coeff(5)
        f_scat_2 = coeff(6)
        # NOTE(review): with collision_me=False all four coefficients are
        # 1, so the gL*gR cross term is kept at unit weight -- confirm
        # that this is the intended massless-electron limit.

        ann = ann_kernel(T_1, mu_1, T_2, mu_2)
        scat = scat_kernel(T_1, mu_1, T_2, mu_2)
        rate = (
            4 * (gL**2 + gR**2) * (32 * f_ann_1 * ann + 56 * f_scat_1 * scat)
            + 4 * gL * gR * (32 * f_ann_2 * ann + 56 * f_scat_2 * scat)
        )
        # f_a and f_s no longer multiply this kernel, so the decoupled
        # switch has to be applied explicitly for the EM-neutrino energy
        # transfer to vanish as documented.
        return jnp.where(decoupled, 0., rate)

    # Units MeV^4 s^-1
    C_rho_nue = const.GF**2 / jnp.pi**5 * (
        G_with_me(T_g, 0., T_nue, mu_nue, geL, geR)
        + 2 * G(T_numt, mu_numt, T_nue, mu_nue)
    ) / const.hbar

    # Units MeV^4 s^-1
    C_rho_numu = const.GF**2 / jnp.pi**5 * (
        G_with_me(T_g, 0., T_numt, mu_numt, gmuL, gmuR)
        - G(T_numt, mu_numt, T_nue, mu_nue)
    ) / const.hbar

    n_nue_term = T_nue**8 * jnp.exp(2 * mu_nue / T_nue)
    n_numt_term = T_numt**8 * jnp.exp(2 * mu_numt / T_numt)

    # Units MeV^3 s^-1
    C_n_nue = 8 * f_n * const.GF**2 / jnp.pi**5 * (
        4 * (geL**2 + geR**2) * (T_g**8 - n_nue_term)
        + 2 * (n_numt_term - n_nue_term)
    ) / const.hbar

    # Units MeV^3 s^-1
    # The EM term uses the mu/tau temperature together with the mu/tau
    # chemical potential, mirroring the nu_e expression above.
    C_n_numu = 8 * f_n * const.GF**2 / jnp.pi**5 * (
        4 * (gmuL**2 + gmuR**2) * (T_g**8 - n_numt_term)
        - (n_numt_term - n_nue_term)
    ) / const.hbar

    return C_rho_nue, C_rho_numu, C_n_nue, C_n_numu