import jax
import jax.numpy as jnp
from jax import vmap
import equinox as eqx
from diffrax import diffeqsolve, ODETerm, Tsit5, PIDController, SaveAt, Event
import linx.thermo as thermo
import linx.const as const
rho_massless_BE_v = vmap(
thermo.rho_massless_BE, in_axes=(0, None, None)
)
rho_massless_FD_v = vmap(
thermo.rho_massless_FD, in_axes=(0, None, None)
)
[docs]
class BackgroundModel(eqx.Module):
"""Background model.
Attributes
----------
decoupled : bool, optional
Whether neutrinos are always decoupled. Default is `False`.
use_FD : bool, optional
Whether to use Fermi-Dirac statistics for neutrinos, or a Maxwell-Boltzmann distribution. Default is `True`.
collision_me : bool, optional
Finite electron mass correction in energy transfer collision terms. Default is `True`.
LO : bool, optional
Whether to use leading order QED correction. Default is `True`.
NLO : bool, optional
Whether to use next-to-leading order QED correction. Default is True.
throw : bool, optional
Whether to raise exceptions on solver failure. Default is `True`.
Set to `False` for parameter scans where some combinations may fail.
"""
decoupled : bool
use_FD : bool
collision_me : bool
LO : bool
NLO : bool
max_steps : int
throw : bool
def __init__(self, decoupled=False, use_FD=True, collision_me=True, LO=True, NLO = True, throw=True, max_steps=512):
"""
Initialize the BackgroundModel with thermodynamic options.
Parameters
----------
decoupled : bool, optional
If True, neutrinos are always decoupled. Default is False.
use_FD : bool, optional
If True, use Fermi-Dirac statistics for neutrinos. Default is True.
collision_me : bool, optional
If True, include finite electron mass corrections in collision terms.
Default is True.
LO : bool, optional
If True, include leading order QED corrections. Default is True.
NLO : bool, optional
If True, include next-to-leading order QED corrections. Default is True.
throw : bool, optional
If True, raise exceptions on solver failure. Default is True.
Set to False for parameter scans where some combinations may fail.
"""
self.decoupled = decoupled
self.use_FD = use_FD
self.collision_me = collision_me
self.LO = LO
self.NLO = NLO
self.max_steps = max_steps
self.throw = throw
[docs]
@eqx.filter_jit
def __call__(
self, Delt_Neff_init, T_start=const.T_start,
T_end=const.T_end, me=const.me, rtol=1e-8, atol=1e-10,
solver=Tsit5(),
):
""" Calculate thermodynamics given an initial :math:`\\Delta N_\\mathrm{eff}`.
Parameters
----------
Delt_Neff_init : float
Initial :math:`\\Delta N_\\mathrm{eff}`. Can be positive or negative.
T_EM_init : float, optional
Initial EM (and neutrino) temperature. Default is `const.T_start`.
T_EM_end : float, optional
Final EM temperature to terminate integration at. Default is
`const.T_end`.
me : float, optional
Electron mass in MeV. Defaults to `const.me`.
rtol : float, optional
Relative tolerance of the abundance solver. Default is `1e-8`.
atol : float, optional
Absolute tolerance of the abundance solver. Default is `1e-10`.
max_steps : int, optional
Maximum number of steps taken by the solver. Default is `4096`.
Increasing this slows down the code, while decreasing this could
mean that the solver cannot complete the solution.
solver : Diffrax ODE solver
The Diffrax ODE solver to use. A stiff solver is recommended.
Default is the Tsitouras' 5/4 solver.
Returns
-------
t_vec : array_like
Times in s at which thermodynamics are saved.
a_vec : array_like
Scale factor at each point in time.
rho_g_vec : array_like
Energy density of photons in MeV^4 at each point in time.
rho_nu_vec : array_like
Energy density of one species of neutrinos in MeV^4 at each point in time.
rho_extra_vec : array_like
Energy density in MeV^4 of extra species at each point in time.
"""
print('`\\ /´ |||| |||| ||||| |||| |||| ||||')
print(' /\\_______/\\ |||| |||| ||||||| |||| |||| ||||')
print(' ) __` ´__ ( |||| |||| |||| |||| |||| |||||||')
print('/ `-|_|-´ \\ |||| |||| |||| |||| ||| ||||||| ')
print('/ (_x_) \\ |||||||||| |||| |||| ||||||| |||| ||||')
print(' ) `-´ ( |||||||||| |||| |||| |||||| |||| ||||')
print(' ')
print('Compiling thermodynamics model...')
lna_init = 0.
T_EM_init = T_start
T_nu_init = T_EM_init
rho_extra_init = (7/8) * (4/11)**(4/3) * (
thermo.rho_massless_BE(T_EM_init, 0., 2)
) * Delt_Neff_init
Y0 = (lna_init, T_EM_init, T_nu_init)
# use parametric form to estimate correct start
# time given T_start, assuming T ~ t^(-1/2) and
# initial g_* is 10.75
t0 = (1.5/T_start * 10.75**(-1./4))**2
def T_EM_check(t, y, args, **kwargs):
return y[1] < T_end
sol = diffeqsolve(
ODETerm(self.dY), solver, args=(lna_init, rho_extra_init, me),
t0 = t0, t1=jnp.inf, dt0=None, y0=Y0,
saveat=SaveAt(steps=True), event=Event(T_EM_check),
stepsize_controller=PIDController(
rtol=rtol, atol=atol
),
max_steps=self.max_steps,
throw=self.throw
)
a_vec = jnp.exp(sol.ys[0])
rho_g_vec = rho_massless_BE_v(sol.ys[1], 0., 2)
rho_nu_vec = rho_massless_FD_v(sol.ys[2], 0., 2)
# These vectors always have max_steps entries so that jit and grad
# work but the solver stops before hitting max_steps. Find the last
# legitimate step made by the solver, when T_g drops below T_end.
last_step_ind = jnp.max(
jnp.argwhere(
sol.ys[1] < T_end,
size=self.max_steps
)[:,0]
)
# Set every step after this in all vectors to be identical,
# so that there is no effect on interpolation.
t_vec = jnp.where(sol.ts == jnp.inf, sol.ts[last_step_ind], sol.ts)
a_vec = jnp.where(a_vec == jnp.inf, a_vec[last_step_ind], a_vec)
rho_g_vec = jnp.where(
rho_g_vec == jnp.inf, rho_g_vec[last_step_ind], rho_g_vec
)
rho_nu_vec = jnp.where(
rho_nu_vec == jnp.inf, rho_nu_vec[last_step_ind], rho_nu_vec
)
# Rescale a so that the present day CMB temperature is correct.
final_a = const.T0CMB / (
sol.ys[1][last_step_ind]
)
a_vec *= final_a / a_vec[-1]
# Trivial relation between rho_extra and a.
rho_extra_vec = rho_extra_init * (a_vec[0]**4 / a_vec**4)
P_extra_vec = rho_extra_vec / 3
T_g_vec = thermo.T_g(rho_g_vec)
rho_tot_vec = (
thermo.rho_EM_std_v(T_g_vec) + 3 * rho_nu_vec + rho_extra_vec
)
Neff_vec = thermo.N_eff(rho_tot_vec, rho_g_vec)
return (
t_vec, a_vec, rho_g_vec, rho_nu_vec,
rho_extra_vec, P_extra_vec, Neff_vec
)
[docs]
@eqx.filter_jit
def dY(self, t, Y, args):
""" Differential equation for background quantities.
Parameters
----------
t : float
Time in s.
Y : tuple of floats
The values of :math:`\\log a`, :math:`T_\\gamma` and :math:`T_\\nu`.
args : tuple of floats
The initial value of :math:`\\log a`, and the initial energy density in MeV^4 of the extra inert relativistic species.
Returns
-------
The time derivative of the quantities specified in `Y`.
"""
lna, T_g, T_nu = Y
lna_init, rho_extra_init, me = args
rho_EM = thermo.rho_EM_std(T_g, me=me, LO=self.LO, NLO=self.NLO)
rho_plus_p_EM = thermo.rho_plus_p_EM_std(T_g, me=me, LO=self.LO, NLO=self.NLO)
drho_EM_dT_g = thermo.drho_EM_dT_g_std(T_g, me=me, LO=self.LO, NLO=self.NLO)
rho_nu = 3*thermo.rho_nue_std(T_nu)
rho_plus_p_nu = (4/3) * rho_nu
drho_nu_dT_nu = 3*thermo.drho_nue_dT_nue_std(T_nu)
rho_extra = rho_extra_init * jnp.exp(lna_init)**4 / jnp.exp(lna)**4
H = thermo.Hubble(rho_EM + rho_nu + rho_extra)
C_rho_nue, C_rho_numu, _, _ = thermo.collision_terms_std(
T_g, T_nu, T_nu, me=me, decoupled=self.decoupled, use_FD=self.use_FD, collision_me=self.collision_me
)
drho_EM_dt = -3 * H * rho_plus_p_EM - C_rho_nue - 2*C_rho_numu
drho_nu_dt = -3 * H * rho_plus_p_nu + C_rho_nue + 2*C_rho_numu
dT_g_dt = drho_EM_dt / drho_EM_dT_g
dT_nu_dt = drho_nu_dt / drho_nu_dT_nu
return H, dT_g_dt, dT_nu_dt