"""Source code for reactions (nuclear reaction rate definitions)."""

import os

import numpy as np

import jax
import jax.numpy as jnp
import equinox as eqx

# Absolute directory containing this module; used as the base for locating
# the bundled rate tables.  Resolved to an absolute path at import time so a
# later change of the working directory cannot break the data-file lookup.
file_dir = os.path.dirname(os.path.abspath(__file__))

class Reaction(eqx.Module):
    """
    Nuclear reaction.

    Attributes
    ----------
    name : str
        Name of the reaction.
    in_states : tuple
        Particles on the LHS of the reaction.  Convention is 0:n, 1:p, 2:d,
        3:t, 4:He3, 5:a, 6:Li7, 7:Be7, 8:He6, 9:Li6, 10:Li8, 11:B8.
    out_states : tuple
        Particles on the RHS of the reaction.
    frwrd_symmetry_fac : float
        Symmetry factor associated with forward direction.  Computed as the
        product of 1/multiplicity over the distinct species in `in_states`.
    bkwrd_symmetry_fac : float
        Symmetry factor associated with backward direction.  Computed the
        same way from `out_states`.
    alpha : float
        Coefficient for relation between forward and backward reaction.
        Dimensionful.
    beta : float
        Coefficient, same as above, dimensionless.
    gamma : float
        Coefficient, same as above, dimensionless.
    T9_vec : array or None
        (T/10^9 K), abscissa for reaction rate parameter with spline fit.
        None when the rate is given by `frwrd_rate_param_func`.
    mu_median_vec : array or None
        Median reaction rate parameter with spline fit.  None when the rate
        is given by `frwrd_rate_param_func`.
    expsigma_vec : array or None
        Exponential uncertainty of reaction rate parameter with spline fit.
        None when the rate is given by `frwrd_rate_param_func`.
    interp_type : str
        Interpolation for spline fit.  Either 'linear' or 'log'.
    frwrd_rate_param_func : callable
        Forward rate parameter function, if no spline fit.

    Notes
    -----
    The rate functions are either <sigma v> or <sigma v^2> divided by
    (1 amu)^(N_in-1) for each reaction, units (cm^3/s/g or cm^6/s/g^2).
    """

    name: str
    in_states: tuple
    out_states: tuple
    frwrd_symmetry_fac: float
    bkwrd_symmetry_fac: float
    alpha: float
    beta: float
    gamma: float
    T9_vec: list
    mu_median_vec: list
    expsigma_vec: list
    interp_type: str
    frwrd_rate_param_func: callable

    def __init__(
        self, name, in_states, out_states, alpha, beta, gamma,
        spline_data=None, frwrd_rate_param_func=None, interp_type=None
    ):
        """
        Parameters
        ----------
        name : str
            Name of the reaction.
        in_states : tuple
            Particles on the LHS of the reaction.  Convention is 0:n, 1:p,
            2:d, 3:t, 4:He3, 5:a, 6:Li7, 7:Be7, 8:He6, 9:Li6, 10:Li8,
            11:B8.
        out_states : tuple
            Particles on the RHS of the reaction.
        alpha : float
            Coefficient for relation between forward and backward reaction.
            NOTE(review): the class-level attribute docs call this
            dimensionful while this docstring previously said
            dimensionless -- confirm which is intended.
        beta : float
            Coefficient, same as above, dimensionless.
        gamma : float
            Coefficient, same as above, dimensionless.
        spline_data : string, optional
            If provided, reads data from 'data/nuclear_rates/'+spline_data
            in the code.  Otherwise, frwrd_rate_param_func must be
            specified.  Takes precedence over `frwrd_rate_param_func` when
            both are given.
        frwrd_rate_param_func : callable, optional
            If provided, is a function that returns the forward rate
            parameter.  Takes two arguments, `T` for EM temperature in K
            and `p` for rescaling of the rate.
        interp_type : str, optional
            Interpolation used with `spline_data`, either 'linear' or
            'log'.  Not used when the rate comes from
            `frwrd_rate_param_func`.

        Raises
        ------
        TypeError
            If neither `spline_data` nor `frwrd_rate_param_func` is given.

        Notes
        -----
        For `spline_data`, data file should contain three columns: first
        column `T9` gives the EM temperature in units of 1e9 K, second
        column `mu` gives the central (median) rate, and third column
        `expsigma`, with `log(expsigma)` giving the uncertainty in log of
        the rate.  In other words, we take
        log(<sigma v>) = log(mu) + p*log(expsigma), where p follows a
        Gaussian distribution, or equivalently
        <sigma v> = mu * exp(p*log(expsigma)).

        The rates in `spline_data` or `frwrd_rate_param_func` are
        <sigma v> or <sigma v^2> divided by (1 amu)^(N_in-1) for each
        reaction, units (cm^3/s/g or cm^6/s/g^2).
        """
        self.name = name
        self.in_states = in_states
        self.out_states = out_states

        # One factor of 1/multiplicity per distinct species on each side.
        multiplicity_in = jnp.array(
            [self.in_states.count(i) for i in set(self.in_states)]
        )
        self.frwrd_symmetry_fac = jnp.prod(1. / multiplicity_in)

        multiplicity_out = jnp.array(
            [self.out_states.count(i) for i in set(self.out_states)]
        )
        self.bkwrd_symmetry_fac = jnp.prod(1. / multiplicity_out)

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.interp_type = interp_type

        # Exactly one rate source is kept: tabulated data or a callable.
        T9_vec = mu_median_vec = expsigma_vec = None
        rate_func = None

        if spline_data:
            T9_vec, mu_median_vec, expsigma_vec = np.loadtxt(
                file_dir+'/data/nuclear_rates/'+spline_data, unpack=True
            )
            try:
                gpu = jax.devices('gpu')[0]
                T9_vec = jax.device_put(T9_vec, device=gpu)
                mu_median_vec = jax.device_put(mu_median_vec, device=gpu)
                expsigma_vec = jax.device_put(expsigma_vec, device=gpu)
            except (RuntimeError, IndexError):
                # No GPU available or no GPU devices found - data stays on
                # CPU.
                pass
        elif frwrd_rate_param_func is not None:
            rate_func = frwrd_rate_param_func
        else:
            raise TypeError('Must include spline data points or analytic '
                            'function for rates.')

        self.T9_vec = T9_vec
        self.mu_median_vec = mu_median_vec
        self.expsigma_vec = expsigma_vec
        self.frwrd_rate_param_func = rate_func

    @eqx.filter_jit
    def frwrd_rate_param(self, T, p):
        """
        Forward rate parameter.

        Parameters
        ----------
        T : float
            Temperature in K.
        p : float
            Rescaling parameter for expsigma.

        Returns
        -------
        float
            Forward rate parameter.  For tabulated rates this is zero
            outside the tabulated temperature range, for both 'linear' and
            'log' interpolation.

        Raises
        ------
        ValueError
            If the reaction uses tabulated data and `interp_type` is
            neither 'linear' nor 'log'.

        Notes
        -----
        We take log(<sigma v>) = log(mu) + p*log(expsigma), where p follows
        a Gaussian distribution, or equivalently
        <sigma v> = mu * exp(p*log(expsigma)).

        The rate here is either <sigma v> or <sigma v^2> divided by
        (1 amu)^(N_in-1) for each reaction, units (cm^3/s/g or
        cm^6/s/g^2).
        """
        if self.T9_vec is None:
            return self.frwrd_rate_param_func(T, p)

        T9 = T*1e-9
        rate_vec = self.mu_median_vec * jnp.exp(
            p * jnp.log(self.expsigma_vec)
        )

        if self.interp_type == 'linear':
            return jnp.interp(
                T9, self.T9_vec, rate_vec, left=0., right=0.
            )

        if self.interp_type == 'log':
            # Interpolation happens in log-space, so the out-of-range fill
            # value must be log(0) = -inf to give a zero rate after exp().
            # A fill of 0. here would give a rate of exp(0) = 1 outside the
            # table, inconsistent with the 'linear' branch.
            log_rate = jnp.interp(
                jnp.log(T9), jnp.log(self.T9_vec), jnp.log(rate_vec),
                left=-jnp.inf, right=-jnp.inf
            )
            return jnp.exp(log_rate)

        # Fail loudly instead of silently returning None.
        raise ValueError(
            f"Unknown interp_type {self.interp_type!r} for reaction "
            f"{self.name!r}; expected 'linear' or 'log'."
        )

    @eqx.filter_jit
    def bkwrd_rate_param(self, T, p):
        """
        Backward rate parameter.

        Evaluated as alpha * T9**beta * exp(gamma/T9) times the forward
        rate parameter at the same `T` and `p`, with T9 = T * 1e-9.

        Parameters
        ----------
        T : float
            Temperature in K.
        p : float
            Rescaling parameter for expsigma.

        Returns
        -------
        float

        Notes
        -----
        We take log(<sigma v>) = log(mu) + p*log(expsigma), where p follows
        a Gaussian distribution, or equivalently
        <sigma v> = mu * exp(p*log(expsigma)).

        The rate here is either <sigma v> or <sigma v^2> divided by
        (1 amu)^(N_in-1) for each reaction, units (cm^3/s/g or
        cm^6/s/g^2).
        """
        T9 = T*1e-9
        detailed_balance = self.alpha * T9**self.beta * jnp.exp(
            self.gamma / T9
        )
        return detailed_balance * self.frwrd_rate_param(T, p)