from functools import partial
import jax
from jax import lax
from jax import numpy as jnp
from jax.scipy.special import zeta, gammaln, i0, i1, bernoulli
from jax.scipy.special import gamma as jax_gamma
# Euler–Mascheroni constant γ (used in the small-argument series of K0).
euler_gamma = 0.57721566490153286061
# Apéry's constant ζ(3).
zeta_3 = 1.202056903159594
# Truncation bound for the polylogarithm (Li) series: the series loops run
# over j = 1 .. L-1 and the zeta-series arrays have length L.
L = 60
# Table of Bernoulli numbers B_n produced by jax.scipy.special.bernoulli(L);
# indexed by n in Bernoulli() and Riemann_zeta().
bernoulli_ary = bernoulli(L)
[docs]
def comb(N, k):
    """
    Binomial coefficient "N choose k", evaluated through log-gamma.

    Parameters
    ----------
    N : int or array
        Size of the set being drawn from.
    k : int or array
        Size of the chosen subset.

    Returns
    -------
    float or array
        exp(lnΓ(N+1) - lnΓ(k+1) - lnΓ(N-k+1)). The value is floating
        point, so integer inputs yield a result that is only
        approximately integral.
    """
    # Work in log space so that the factorials never appear explicitly.
    log_binomial = gammaln(N + 1) - gammaln(k + 1)
    log_binomial = log_binomial - gammaln(N - k + 1)
    return jnp.exp(log_binomial)
[docs]
def Bernoulli(n, x):
    """
    The Bernoulli polynomial B_n(x), n < 60.

    Evaluated from the explicit sum

        B_n(x) = sum_{i=0}^{n} C(n, i) * B_i * x**(n - i),

    with the Bernoulli numbers B_i read from ``bernoulli_ary``.

    Parameters
    ----------
    n : int
        Order of the Bernoulli polynomial (must be < 60, limited by the
        length of ``bernoulli_ary``).
    x : float, complex or array
        Argument of the polynomial.

    Returns
    -------
    float, complex or array
        Value of B_n(x), with the same shape as ``x``.

    Notes
    -----
    See Wikipedia article on Bernoulli polynomials for definition.
    """
    x = jnp.asarray(x)
    # The fori_loop carry must keep one shape and dtype for the whole loop.
    # Starting from a bare scalar 0 breaks this for array-valued x, so the
    # accumulator is created with the shape of x and the dtype the summed
    # terms will have.
    total = jnp.zeros(x.shape, dtype=jnp.result_type(x, bernoulli_ary))

    def add_term(i, partial_sum):
        return partial_sum + bernoulli_ary[i] * comb(n, i) * x ** (n - i)

    return lax.fori_loop(0, n + 1, add_term, total)
[docs]
def gamma(x):
    """
    Gamma function using the Lanczos approximation.

    Parameters
    ----------
    x : float or complex
        Argument of the gamma function (scalar).

    Returns
    -------
    float or complex
        Value of Γ(x).

    Notes
    -----
    Uses the Lanczos approximation (g = 7, nine coefficients) for
    Re(x) >= 0.5 and the reflection formula

        Γ(x) = π / (sin(π x) Γ(1 - x))

    for Re(x) < 0.5. Supports complex arguments. Imaginary parts of
    magnitude <= 1e-6 in the Lanczos result are treated as round-off and
    zeroed.
    """
    lanczos_g = 7
    lanczos_coeffs = jnp.array([
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.32342877765313,
        -176.61502916214059,
        12.507343278686905,
        -0.13857109526572012,
        9.9843695780195716e-6,
        1.5056327351493116e-7,
    ])
    imag_tol = 1e-06

    def lanczos(z):
        z_shift = z - 1
        denom = z_shift + jnp.arange(1, len(lanczos_coeffs))
        series = lanczos_coeffs[0] + jnp.sum(lanczos_coeffs[1:] / denom)
        t = z_shift + lanczos_g + 0.5
        value = (
            jnp.sqrt(2 * jnp.pi) * t ** (z_shift + 0.5) * jnp.exp(-t) * series
        )
        # Drop a negligible imaginary part left over from round-off.
        return jnp.where(
            jnp.abs(jnp.imag(value)) <= imag_tol, jnp.real(value), value
        )

    def reflect(z):
        # Γ(z) Γ(1 - z) = π / sin(π z): Γ(1 - z) belongs in the
        # denominator. Without the parentheses the expression would
        # multiply by Γ(1 - z) instead of dividing by it.
        return jnp.pi / (jnp.sin(jnp.pi * z) * lanczos(1. - z))

    # Branch on the real part so that complex arguments (including plain
    # Python complex numbers, which do not support "<") are handled.
    return lax.cond(jnp.real(x) < 0.5, reflect, lanczos, x)
[docs]
def Riemann_zeta(n):
    """
    Riemann zeta function ζ(n), extended to non-positive integers.

    Parameters
    ----------
    n : int or array
        Argument of the zeta function (must be > -60).

    Returns
    -------
    float or array
        Value of ζ(n).

    Notes
    -----
    For n > 0 the value comes from ``zeta(n, 1)``. For negative even n the
    result is 0. For the remaining n <= 0 it is computed from the
    Bernoulli-number table as (-1)**(-n) * B_{-n+1} / (-n + 1).
    """
    m = -n
    is_positive = n > 0
    is_negative_even = (n < 0) & (m % 2 == 0)
    # Every candidate value is evaluated; the masks below pick the one that
    # applies for each n.
    from_bernoulli = (-1.) ** m * bernoulli_ary[m + 1] / (m + 1)
    non_positive_value = jnp.where(is_negative_even, 0., from_bernoulli)
    return jnp.where(is_positive, zeta(n, 1), non_positive_value)
[docs]
@partial(jax.jit, static_argnums=(0,))
def Li(n, z):
    r"""
    Polylogarithm of order n and argument z.

    Parameters
    ----------
    n : int
        Order of the polylogarithm (static argument).
    z : float or complex
        Argument of the polylogarithm (scalar).

    Returns
    -------
    float
        Value of Li_n(z). The intermediate and reciprocal branches return
        the real part of their result.

    Notes
    -----
    Uses different series expansions depending on \|z\|:

    - \|z\| ≤ 0.5: direct series
    - 0.5 < \|z\| < 2: intermediate series with harmonic terms
    - \|z\| ≥ 2: reciprocal series with Bernoulli polynomials

    All three expansions are evaluated and the applicable one is selected
    with ``jnp.where``. Each expansion is fed the placeholder argument 3
    whenever z lies outside its own range.
    """
    float_dtype = jnp.result_type(float)

    def _power_sum(w):
        # sum_{j=1}^{L-1} w**j / j**n. j is converted to floating point
        # before being raised to the n-th power: as an int32 loop index,
        # j**n wraps around for larger n (e.g. 32**7 wraps to 0) and
        # corrupts the sum.
        def add_term(j, val):
            j_float = lax.convert_element_type(j, float_dtype)
            return val + w**j / j_float**n

        return lax.fori_loop(1, L, add_term, 0)

    def _Li_z_small(z):
        return _power_sum(z)

    def _Li_z_intermed(z):
        log_z = jnp.log(z + 0j)
        # Coefficients ζ(n - m) for m = 0 .. L-1. The m = n-1 entry would
        # be ζ(1), which diverges; it is set to 0 here and replaced by
        # harmonic_term below.
        # Oddly enough, the fastest way to do this.
        zeta_ary = jnp.concatenate((
            jnp.array([Riemann_zeta(n - m) for m in jnp.arange(n-1)]),
            jnp.array([0.]),
            jnp.array([Riemann_zeta(n - m) for m in jnp.arange(n, L)])
        ))
        log_powers = jnp.concatenate((
            jnp.array([1., log_z]),
            log_z ** jnp.arange(2, L)
        ))
        zeta_series_term = jnp.sum(
            zeta_ary * log_powers / jax_gamma(jnp.arange(L) + 1.)
        )
        H_n_m_1 = jnp.sum(1. / jnp.arange(1, n))
        at_one = jnp.isclose(z - 1., 0)
        # log(-log(z)) is singular at z = 1; evaluate it at a harmless
        # stand-in there and mask the whole term to 0.
        safe_log_z = jnp.log(jnp.where(at_one, 2., z) + 0j)
        harmonic_term = jnp.where(
            at_one, 0.,
            log_z ** (n - 1) / jax_gamma(n) * (
                H_n_m_1 - jnp.log(-safe_log_z + 0j)
            )
        )
        return jnp.real(zeta_series_term + harmonic_term)

    def _Li_z_large(z):
        # Inversion formula: Li_n(z) is expressed through Li_n(1/z) and
        # the Bernoulli polynomial B_n(log(z) / (2 pi i)).
        recip_Li = _power_sum(1 / z)
        B_n = Bernoulli(n, jnp.log(z + 0j) / (2 * jnp.pi * 1j))
        return jnp.real(
            - (-1)**n * recip_Li
            - (2*jnp.pi*1j)**n / jax_gamma(n + 1) * B_n
        )

    abs_z = jnp.abs(z)
    small_range = abs_z <= 0.5
    intermed_range = (0.5 < abs_z) & (abs_z < 2)
    # ">=" so that |z| == 2 is covered by the reciprocal series; with a
    # strict ">" that boundary matched no range and the result was 0.
    large_range = abs_z >= 2

    small_val = _Li_z_small(jnp.where(small_range, z, 3.))
    intermed_val = _Li_z_intermed(jnp.where(intermed_range, z, 3.))
    large_val = _Li_z_large(jnp.where(large_range, z, 3.))

    return jnp.where(
        small_range, small_val, jnp.where(
            intermed_range, intermed_val, jnp.where(large_range, large_val, 0.)
        )
    )
[docs]
def K0(z):
    """
    Modified Bessel function of the second kind, order 0.

    Parameters
    ----------
    z : float or array
        Argument of the Bessel function.

    Returns
    -------
    float or array
        Value of K_0(z).

    Notes
    -----
    Two approximations are used, selected with ``lax.cond`` on ``z < 9``:
    a 30-term series for z < 9 and a 10-term expansion for z ≥ 9, which
    returns 0 once z ≥ 600. Algorithm from Zhang and Jin.
    """
    def _series_below_9(z):
        # n = 30 sufficient for abstol ~ 1e-12 and reltol ~ 1e-8
        orders = jnp.arange(1., 31)
        harmonic_numbers = jnp.cumsum(1. / jnp.arange(1, 31))
        z_capped = jnp.where(z < 600., z, 600.)
        log_part = -(jnp.log(z / 2.) + euler_gamma) * i0(z_capped)
        series_terms = (
            harmonic_numbers * (z / 2.) ** (2. * orders)
            / jax_gamma(orders + 1) ** 2.
        )
        return log_part + jnp.sum(series_terms)

    def _expansion_from_9(z):
        # n = 10 sufficient for abstol ~ 1e-12 and reltol ~ 1e-8
        orders = jnp.arange(1., 11)
        odd = 2. * orders - 1.
        even = 2. * orders
        coeff_products = jnp.cumprod(-odd / even * odd ** 2.)
        z_capped = jnp.where(z < 600., z, 600.)
        prefactor = 1. / 2. / z / i0(z_capped)
        correction = 1. + jnp.sum(
            (-1) ** orders * coeff_products / (2. * z) ** (2. * orders)
        )
        value = prefactor * correction
        return jnp.where(z < 600, value, 0.)

    return lax.cond(z < 9, _series_below_9, _expansion_from_9, z)
[docs]
def K1(z):
    """
    Modified Bessel function of the second kind, order 1.

    Parameters
    ----------
    z : float or array
        Argument of the Bessel function.

    Returns
    -------
    float or array
        Value of K_1(z); 0 is returned for z ≥ 600.

    Notes
    -----
    Obtained from K_0 and the first-kind functions I_0, I_1 through
    K_1(z) = (1/z - I_1(z) K_0(z)) / I_0(z).
    """
    below_cutoff = z < 600.
    # Evaluate at a capped argument so the unused branch of the final
    # jnp.where is still computed at an in-range value.
    z_capped = jnp.where(below_cutoff, z, 600.)
    value = (1 / z_capped - i1(z_capped) * K0(z_capped)) / i0(z_capped)
    return jnp.where(below_cutoff, value, 0.)
[docs]
def K2(z):
    """
    Modified Bessel function of the second kind, order 2.

    Parameters
    ----------
    z : float or array
        Argument of the Bessel function.

    Returns
    -------
    float or array
        Value of K_2(z).

    Notes
    -----
    Built from the lower orders with the recurrence
    K_2(z) = K_0(z) + (2/z) K_1(z).
    """
    order_zero = K0(z)
    order_one = K1(z)
    return order_zero + 2 / z * order_one