import sys
sys.path.append('..')
import jax.numpy as jnp
import equinox as eqx
from diffrax import diffeqsolve, ODETerm, Tsit5, Kvaerno3, PIDController, SaveAt
import interpax
import linx.nuclear as nucl
import linx.const as const
from linx.const import ma, me, mn, mp
import linx.weak_rates as wr
import linx.thermo as thermo
from linx.thermo import rho_EM_std_v, p_EM_std_v, nB
from linx.special_funcs import zeta_3
from linx.tau_n_vary_me import tau_n_fac_vary_me
[docs]
class AbundanceModel(eqx.Module):
"""
Abundance model and BBN abundance prediction.
Attributes
----------
nuclear_net : NuclearRates
Nuclear network to be used for BBN prediction.
weak_rates : WeakRates
Weak rates for neutron-proton interconversion.
species_dict : dict
Dictionary of species considered in LINX.
species_Z : list
Number of protons in each species.
species_N : list
Number of neutrons in each species.
species_A : list
Atomic mass number of each species.
species_excess_mass : list
Excess mass (mass - A*amu) of each species.
species_spin : list
Spin of each species.
species_binding_energy : list
Binding energy of each species.
throw : bool
Whether to raise exceptions on solver failure.
"""
nuclear_net : nucl.NuclearRates
weak_rates : wr.WeakRates
species_dict : dict
species_Z : list
species_N : list
species_A : list
species_excess_mass : dict
species_spin : list
species_binding_energy : list
throw : bool
def __init__(self, nuclear_net, weak_rates=wr.WeakRates(), throw=True):
"""
Initialize the AbundanceModel with nuclear and weak rate networks.
Parameters
----------
nuclear_net : NuclearRates
Nuclear reaction network to be used for BBN calculations. This
defines which nuclear reactions are included in the evolution.
weak_rates : WeakRates, optional
Weak interaction rates for neutron-proton interconversion.
Defaults to standard WeakRates instance.
throw : bool, optional
If True, raise exceptions on solver failure. Default is True.
Set to False for parameter scans where some combinations may fail.
Notes
-----
This constructor initializes the species dictionary, atomic properties
(Z, N, A), masses, spins, binding energies, and excess masses for all
nuclear species considered in the LINX package (n, p, d, t, He3, He4,
Li6, Li7, Li8, Be7, He6, B8).
"""
self.nuclear_net = nuclear_net
self.weak_rates = weak_rates
self.throw = throw
self.species_dict = {
0:'n', 1:'p', 2:'d', 3:'t', 4:'He3', 5:'a', 6:'Li7', 7:'Be7',
8: 'Li6', 9: 'He6', 10: 'Li8', 11:'B8'
}
self.species_Z = jnp.array([0, 1, 1, 1, 2, 2, 3, 4, 3, 2, 3, 5])
self.species_N = jnp.array([1, 0, 1, 2, 1, 2, 4, 3, 3, 4, 5, 3])
self.species_A = self.species_Z + self.species_N
# in MeV
self.species_excess_mass = jnp.array([
8071.3171, 7288.9706, 13135.722, 14949.81,
14931.218, 2424.9156, 14907.105, 15769.,
14086.8789, 17592.10, 20945.80, 22921.6
]) * 1e-3
self.species_spin = jnp.array([
1./2., 1./2., 1., 1./2., 1./2., 0.,
3./2., 3./2., 1., 0., 2., 2.
])
# in MeV
self.species_binding_energy = (
self.species_N * self.species_excess_mass[0]
+ self.species_Z * self.species_excess_mass[1]
- self.species_excess_mass
)
# in MeV
# requires recompilation for each me--moved to YNSE method
# self.species_mass = (
# self.species_A * ma + self.species_excess_mass - self.species_Z * me
# )
[docs]
@eqx.filter_jit
def __call__(
self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec,
a_vec=None, t_vec=None,
eta_fac=jnp.asarray(1.), tau_n_fac = jnp.asarray(1.),
nuclear_rates_q=None, me = const.me,
Y_i=None, T_start=None, T_end=None, sampling_nTOp=150,
rtol=1e-6, atol=1e-9, solver=Kvaerno3(),
max_steps=4096,
save_history=False
):
"""
Calculate BBN abundance.
Parameters
----------
rho_g_vec : array
Energy density of photons in MeV^4.
rho_nu_vec : array
Energy density of a single neutrino species in MeV^4
(all neutrinos assumed to have the same temperature).
rho_NP_vec : array
Energy density of all new physics particles in MeV^4.
P_NP_vec : array
Pressure of all new physics particles in MeV^4.
a_vec : array, optional
Scale factor. If `None`, will be computed in function.
t_vec : array, optional
Time in seconds. If `None`, will be computed in function.
eta_fac : float, optional
Rescaling factor for baryon-to-photon ratio, 1 for fiducial value
in `const.eta0` (or `const.Omegabh2`).
tau_n_fac : float, optional
Rescaling factor for neutron decay lifetime, 1 for fiducial value
in `const.tau_n`.
nuclear_rates_q : array, optional
q ~ N(0,1) specifies the nuclear rate in its log-normal
distribution. If not specified, will be taken to be `q = 0`.
me : float, optional
Electron mass in MeV. Defaults to `const.me`.
Y_i : tuple of float, optional
Initial abundances :math:`n_i/n_b` for species. Length must be equal to
`self.nuclear_net.max_i_species`. Must specify `T_start` and `T_end` if not `None`.
T_start : float
Temperature in MeV to start integration. Must specify `Y_i` and `T_end` if not `None`, otherwise `const.T_start` used.
T_end : float
Temperature in MeV to end integration.
sampling_nTOp : int
Number of points to subdivide (`T_end`, `T_start`) for neutron-proton interconversion rate interpolation table.
rtol : float, optional
Relative tolerance of the abundance solver. Default is `1e-4`.
atol : float, optional
Absolute tolerance of the abundance solver. Default is `1e-9`.
max_steps : int, optional
Maximum number of steps taken by the solver. Default is `4096`.
Increasing this slows down the code, while decreasing this could
mean that the solver cannot complete the solution.
solver : Diffrax ODE solver
The Diffrax ODE solver to use. A stiff solver is recommended.
Default is the 3rd order Kvaerno solver.
save_history : bool
If `True`, full solution is returned with temperature and time
abscissa.
Returns
-------
tuple of array or array
If `save_history` is set to `True`, a tuple containing an array of EM temperatures, an array of times, and a Diffrax `Solution` instance, which can be called as a function of time. Otherwise, returns yields of all species considered in `self.nuclear_net`.
"""
print('Compiling abundance model...')
if Y_i is not None:
if T_start is None:
raise TypeError('Specifying Y_i requires specifying a T_start')
if T_start is not None:
if Y_i is None:
raise TypeError('Specifying T_start requires specifying Y_i')
if nuclear_rates_q is None:
nuclear_rates_q = jnp.array(
[0. for _ in self.nuclear_net.reactions]
)
if t_vec is None:
t_vec = self.get_t(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec)
if a_vec is None:
a_vec = self.get_a(rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec)
if T_start is None:
T_start = const.T_start
if T_end is None:
T_end = const.T_end
# check if the user has varied me, and adjust the neutron lifetime if so
diff = jnp.abs(me - const.me)/const.me
tau_n_fac = jnp.where(diff > 1e-5, tau_n_fac_vary_me(me), 1.) * tau_n_fac
# These are in MeV
T_g_vec = thermo.T_g(rho_g_vec)
T_nu_vec = thermo.T_nu(rho_nu_vec)
# Sort by log(T_g) to handle non-monotonic temperature evolution
# (e.g., in reheating scenarios). interpax.interp1d requires
# monotonically increasing x coordinates.
log_T_g_vec = jnp.log(T_g_vec)
sort_idx = jnp.argsort(log_T_g_vec)
log_T_g_sorted = log_T_g_vec[sort_idx]
log_a_sorted = jnp.log(a_vec)[sort_idx]
log_t_sorted = jnp.log(t_vec)[sort_idx]
a_start = jnp.exp(
interpax.interp1d(
jnp.log(T_start),
log_T_g_sorted,
log_a_sorted,
method='linear',
extrap=True
)
)
t_start = jnp.exp(
interpax.interp1d(
jnp.log(T_start),
log_T_g_sorted,
log_t_sorted,
method='linear',
extrap=True
)
)
t_end = jnp.exp(
interpax.interp1d(
jnp.log(T_end),
log_T_g_sorted,
log_t_sorted,
method='linear',
extrap=True
)
)
##################################
# Weak Rates #
##################################
T_interval_nTOp, nTOp_frwrd, nTOp_bkwrd = self.weak_rates(
jnp.array([T_g_vec, T_nu_vec]),
T_start=T_start, T_end=T_end, sampling_nTOp=sampling_nTOp, me=me
)
##################################
# Initialization of Abundances #
##################################
if Y_i is None:
# If not provided, initialized to const.T_start.
# Neutron and proton yields, based on rates.
Yn_i = nTOp_bkwrd[0] / (nTOp_frwrd[0] + nTOp_bkwrd[0])
Yp_i = 1. - Yn_i
# Other elements start at statistical equilibrium.
n_CMB_start = thermo.n_massless_BE(T_start, 0., 2.)
eta_T_start = nB(a_start, eta_fac=eta_fac) / n_CMB_start
Y_YNSE = self.YNSE(Yn_i, Yp_i, const.T_start, eta_T_start, me)
Y_others_i = Y_YNSE[2:self.nuclear_net.max_i_species]
Y_i = (Yn_i, Yp_i) + tuple(Y_others_i)
if save_history:
saveat = SaveAt(dense=True)
else:
# Default SaveAt
saveat = SaveAt(t1=True)
sol = diffeqsolve(
ODETerm(self.Y_prime), solver,
t0=t_start, t1=t_end, dt0=None, y0=Y_i,
args=(
a_vec, t_vec, T_g_vec, T_interval_nTOp, nTOp_frwrd,
nTOp_bkwrd, eta_fac, tau_n_fac, nuclear_rates_q
),
saveat=saveat,
stepsize_controller=PIDController(
rtol=rtol, atol=atol,
),
max_steps=max_steps,
throw=self.throw
)
if save_history:
return sol
else:
Y_f = jnp.array(sol.ys).flatten()
return Y_f
[docs]
@eqx.filter_jit
def get_t(self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec):
"""
Time elapsed.
Parameters
----------
rho_g_vec : array
Energy density of photons.
rho_nu_vec : array
Energy density of one species of neutrinos (assumed identical for all species).
rho_NP_vec : array
Energy density of all new physics fluids.
P_NP_vec : array
Pressure of all new physics fluids.
Returns
-------
array
Array of times in seconds corresponding to physical parameters
above. Initial time taken to be 1 / (2H).
"""
T_g_vec = thermo.T_g(rho_g_vec)
rho_tot_vec = (
rho_EM_std_v(T_g_vec) + 3 * rho_nu_vec + rho_NP_vec
)
P_tot_vec = p_EM_std_v(T_g_vec) + 3 * (rho_nu_vec/3) + P_NP_vec
# Sort by log(rho_tot) to handle non-monotonic energy density evolution
# (e.g., in reheating scenarios)
log_rho_tot_vec = jnp.log(rho_tot_vec)
sort_idx_rho = jnp.argsort(log_rho_tot_vec)
log_rho_sorted = log_rho_tot_vec[sort_idx_rho]
log_P_sorted = jnp.log(P_tot_vec)[sort_idx_rho]
def P_tot(rho_tot):
return jnp.exp(
interpax.interp1d(
jnp.log(rho_tot),
log_rho_sorted,
log_P_sorted,
method='linear',
extrap=True
)
)
def dt_prime(rho_tot, t, args):
return 1. / (
-3. * thermo.Hubble(rho_tot) *(rho_tot + P_tot(rho_tot))
)
rho_tot_init = rho_tot_vec[0]
rho_tot_fin = rho_tot_vec[-1]
sol_t = diffeqsolve(
ODETerm(dt_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=1. / (2 * thermo.Hubble(rho_tot_init)),
dt0=None, max_steps=4096,
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
throw=self.throw
)
return sol_t.ys
[docs]
@eqx.filter_jit
def get_a(self, rho_g_vec, rho_nu_vec, rho_NP_vec, P_NP_vec):
"""
Scale factor.
Parameters
----------
rho_g_vec : array
Energy density of photons.
rho_nu_vec : array
Energy density of one species of neutrinos (assumed identical for
all species).
rho_NP_vec : array
Energy density of all new physics fluids.
P_NP_vec : array
Pressure of all new physics fluids.
Returns
-------
array
Array of scale factors corresponding to physical parameters above.
Notes
-----
The final entry `a[-1]` is given by `const.T0CMB / T_gamma[-1]`, where `const.T0CMB` is the CMB temperature measured today, and `T_gamma[-1]` is the temperature of photons in the last entry of `rho_g_vec`. In other words, we assume no subsequent entropy dump in the electromagnetic sector.
"""
T_g_vec = thermo.T_g(rho_g_vec)
rho_tot_vec = (
rho_EM_std_v(T_g_vec) + 3 * rho_nu_vec + rho_NP_vec
)
P_tot_vec = p_EM_std_v(T_g_vec) + 3 * (rho_nu_vec/3) + P_NP_vec
# Sort by log(rho_tot) to handle non-monotonic energy density evolution
# (e.g., in reheating scenarios)
log_rho_tot_vec = jnp.log(rho_tot_vec)
sort_idx_rho = jnp.argsort(log_rho_tot_vec)
log_rho_sorted = log_rho_tot_vec[sort_idx_rho]
log_P_sorted = jnp.log(P_tot_vec)[sort_idx_rho]
def P_tot(rho_tot):
return jnp.exp(
interpax.interp1d(
jnp.log(rho_tot),
log_rho_sorted,
log_P_sorted,
method='linear',
extrap=True
)
)
def dlna_prime(rho_tot, t, args):
return 1. / (-3. * (rho_tot + P_tot(rho_tot)))
rho_tot_init = rho_tot_vec[0]
rho_tot_fin = rho_tot_vec[-1]
# a_0 = 1 arbitrarily, will rescale later.
sol_lna = diffeqsolve(
ODETerm(dlna_prime), Tsit5(),
t0=rho_tot_init, t1=rho_tot_fin,
y0=0., dt0=None, max_steps=4096,
saveat=SaveAt(ts=rho_tot_vec),
stepsize_controller=PIDController(rtol=1e-8, atol=1e-10),
throw=self.throw
)
a_fin = const.T0CMB / T_g_vec[-1]
a_vec = jnp.exp(sol_lna.ys)
# Rescale so that the last a is given by T_g_vec[-1] / TCMB today.
a_vec = a_vec / a_vec[-1] * a_fin
return a_vec
[docs]
@eqx.filter_jit
def Y_prime(self, t, Y, args):
"""Returns :math:`dY_i/dt` for this abundance model.
Parameters
----------
t : float
Time at which to evaluate :math:`dY_i/dt`.
Y : array
Array of abundances for evaluating :math:`dY_i/dt`.
args : tuple of arrays
Other relevant information for evaluating the derivative. These are respectively, 0) an array of scale factors; 1) an array of times; 2) an array of EM sector temperatures; 3) an array representing the abscissa of EM sector temperatures for evaluating weak rates; 4) an array of n -> p rates to interpolate over; 5) an array of p -> n rates to interpolate over; 6) the rescaling factor for baryon-to-photon ratio `eta_fac`; 7) the rescaling factor for neutron decay lifetime `tau_n_fac` and 8) the array rescaling nuclear rates, `nuclear_rates_q`.
Returns
-------
array
:math:`dY_i/dt` at the given time and at the present abundance levels.
"""
a_vec_in = args[0]
t_vec_in = args[1]
T_g_vec_in = args[2]
T_interval_in = args[3]
nTOp_frwrd_vec_in = args[4]
nTOp_bkwrd_vec_in = args[5]
eta_fac = args[6]
tau_n_fac = args[7]
nuclear_rates_q = args[8]
a_in = a_vec_in[0]
a_fin = a_vec_in[-1]
a = jnp.interp(
t, t_vec_in, a_vec_in,
left=a_in,right=a_fin
)
# Baryon density rescaled by eta_fac.
n0B = eta_fac*const.n0CMB*const.eta0
rho0BmaOvermB = ma * n0B
# number density times amu.
rhoBBN = rho0BmaOvermB * const.MeV4_to_gcmm3/a**3
T_t = jnp.interp(
t, t_vec_in, T_g_vec_in, left=T_g_vec_in[0],right=T_g_vec_in[-1]
)
dY = self.nuclear_net(
Y, T_t, rhoBBN, T_interval_in, nTOp_frwrd_vec_in,
nTOp_bkwrd_vec_in, tau_n_fac=tau_n_fac,
nuclear_rates_q=nuclear_rates_q
)
return dY
[docs]
def YNSE(self, Yn, Yp, T, eta, me=const.me):
"""
Nuclear statistical equilibrium yields for all species.
Parameters
----------
Yn : float
The yield :math:`n_n / n_b` of free neutrons.
Yp : float
The yield :math:`n_p / n_b` of free protons.
T : float
The temperature of the baryons in MeV.
eta : float
The baryon-to-photon ratio.
me: float, optional
Electron mass in MeV. Defaults to const.me
Returns
-------
array
Yields for all species considered in LINX (13 of them).
"""
species_mass = (
self.species_A * ma + self.species_excess_mass - self.species_Z * me
)
A32Overmn = (
species_mass / (
mn**(self.species_A - self.species_Z)
* mp**self.species_Z
)
)**(3/2)
return (
(2 * self.species_spin + 1) * zeta_3**(self.species_A-1)
* jnp.pi**((1-self.species_A)/2) * 2**((3*self.species_A-5)/2)
* A32Overmn * T**(3/2*(self.species_A-1))
* eta**(self.species_A-1) * Yp**self.species_Z
* Yn**(self.species_A-self.species_Z)
* jnp.exp(self.species_binding_energy / T)
)