API Reference

Top-level

The main entry point is compile(), which returns a Solver bundle containing all JIT-compiled observables.

lax.compile(*, mesh, channels=None, blocks=None, operators=('T+L',), solvers=('spectrum', 'rmatrix', 'smatrix', 'phases'), energies=None, energy_dependent=False, method=None, V_is_complex=False, grid=None, momenta=None, z1z2=None, dps=40, mass_factor_grid=None, dtype=<class 'jax.numpy.float64'>, device=None)

Build a compiled solver bundle for one mesh/channel definition.

Note

lax.compile shadows Python’s built-in compile. Avoid from lax import compile in modules that also use the built-in.

Parameters:
  • mesh – User-facing mesh specification. The chosen mesh family, regularization, scale, and any mesh-specific extras are resolved at compile time.

  • channels – Channel definitions for a single coupled-channel block, baked into the compiled solver structure. Mutually exclusive with blocks; exactly one must be given.

  • blocks – A batch of same-shaped symmetry blocks (independent (J, π) groups, partial waves, …); see DESIGN.md §15.5. Each inner group must have the same length N_c. The compiled solver carries the block set on solver.blocks, the boundary values gain a leading (N_b,) axis, and every observable output gains a corresponding leading block axis. Partial-wave batching is the N_c == 1 case: blocks=[[ChannelSpec(l=0, …)], [ChannelSpec(l=1, …)], …]. Mutually exclusive with channels.

  • operators – Compile-time operator matrices to precompute. "T+L" is injected automatically whenever the requested solver path needs it.

  • solvers – Runtime entry points to expose on the returned Solver. The potential passed to solver.spectrum(V) or solver.rmatrix_direct(V) must be an Interaction. Build one with solver.local_potential(fn)/solver.nonlocal_potential(fn) or the solver.interaction_from_{block,array,funcs} builders.

  • energies – Compile-time energy grid used for boundary-value-dependent observables and aligned-grid workflows.

  • energy_dependent – Whether the caller intends to provide an energy-dependent potential on the compile-time energy grid. Build the potential with energy_dependent=True. solver.spectrum then dispatches the energy axis internally and returns a batched Spectrum; pass that result to solver.phases_grid(spectra) for the aligned-grid observables.

  • method – Explicit solver method. When omitted, the method is chosen from V_is_complex and the active JAX backend.

  • V_is_complex – Whether the potential path is complex-valued.

  • grid – Optional radial grid used to precompute mesh-to-grid transforms. Accessible afterward as solver.grid_r.

  • momenta – Optional momentum grid used to precompute Fourier transforms. Accessible afterward as solver.momenta.

  • z1z2 – Optional pair of charges passed to compile-time boundary-value evaluation.

  • dps – Decimal precision for the mpmath boundary-value calculation.

  • mass_factor_grid – Per-energy (and optionally per-channel) ℏ²/2μ values in MeV·fm². Accepted shapes (all broadcast to the canonical (N_E, N_c) form):

    • None — use each channel’s ChannelSpec.mass_factor uniformly.

    • scalar — the same value for all energies and channels.

    • shape (N_E,) — one value per energy, shared across channels.

    • shape (N_E, N_c) — fully independent per (energy, channel) pair.

    When provided, len(mass_factor_grid) along the first axis must equal len(energies). The grid is used in two places:

    1. Boundary values — wave numbers and Sommerfeld parameters at each (energy, channel) pair use mass_factor_grid[ie, ic].

    2. Aligned-grid direct observables — the Hamiltonian is assembled with the per-energy per-channel mass factor at each grid point.

  • dtype – Floating-point precision for the baked arrays (mesh, operators, boundary values, energy grid, transforms). Default jnp.float64; complex caches use the matching complex dtype (complex64 for float32). x64 itself is enabled globally via jax.config; dtype only selects the precision of the compile-time caches. Runtime kernels compute in the promoted dtype of the baked arrays and the supplied Interaction — build interactions through the solver’s own builders to stay in the requested precision.

  • device – Optional device (or device-platform string such as "cpu"/"gpu") on which to place the compiled solver’s cached arrays via jax.device_put. None keeps JAX’s default placement.
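The four accepted mass_factor_grid forms can be normalised to the canonical (N_E, N_c) layout with a few lines of plain Python. The sketch below is illustrative only: broadcast_mass_factor_grid is a hypothetical helper (not part of lax), it works on nested lists rather than JAX arrays, and it simply encodes the shape rules and the len(energies) check listed above.

```python
from typing import List, Sequence, Union

GridInput = Union[None, float, Sequence[float], Sequence[Sequence[float]]]


def broadcast_mass_factor_grid(
    grid: GridInput,
    n_energies: int,
    channel_mass_factors: Sequence[float],
) -> List[List[float]]:
    """Hypothetical helper: normalise mass_factor_grid to an (N_E, N_c) nested list."""
    n_channels = len(channel_mass_factors)
    if grid is None:
        # None: each channel's ChannelSpec.mass_factor, repeated at every energy.
        return [list(channel_mass_factors) for _ in range(n_energies)]
    if isinstance(grid, (int, float)):
        # Scalar: the same value for every (energy, channel) pair.
        return [[float(grid)] * n_channels for _ in range(n_energies)]
    rows = list(grid)
    if len(rows) != n_energies:
        raise ValueError("first axis of mass_factor_grid must equal len(energies)")
    out = []
    for row in rows:
        if isinstance(row, (int, float)):
            # Shape (N_E,): one value per energy, shared across channels.
            out.append([float(row)] * n_channels)
        else:
            # Shape (N_E, N_c): independent value per (energy, channel) pair.
            values = [float(v) for v in row]
            if len(values) != n_channels:
                raise ValueError("second axis of mass_factor_grid must equal N_c")
            out.append(values)
    return out
```

With this layout, the value used for energy index ie and channel ic is simply grid[ie][ic], which is the mass_factor_grid[ie, ic] lookup described above.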

Returns:

Solver – A pickle-safe solver bundle containing compile-time caches and bound runtime observables.

Return type:

Solver

lax.make_wavefunction_source(solver, channel_index, energy_index)

Build the mesh-space source vector for one incoming channel.

The source term drives the internal Green’s function to produce the scattering wavefunction for a reaction incoming in channel c at energy E_i. Following Descouvemont [2] eq. 27, the source is:

source[c·N : (c+1)·N] = φ_n(a) · H⁻_c(E_i)   (all other blocks zero)

where φ_n(a) are the Lagrange-basis boundary values and H⁻ is the incoming Coulomb/Whittaker function at the channel radius. This is a slice of the compile-time stack built by build_wavefunction_sources() (baked on the solver when a wavefunction entry point was requested, rebuilt on demand otherwise).
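A minimal plain-Python sketch of the eq. 27 layout. It assumes phi_at_a holds the N Lagrange-basis values φ_n(a) and h_minus the complex value H⁻_c(E_i) for the chosen channel; sketch_wavefunction_source is a hypothetical name, and the library itself slices its precomputed stack instead of rebuilding the vector.

```python
def sketch_wavefunction_source(phi_at_a, h_minus, channel_index, n_channels):
    """Hypothetical helper: source[c*N:(c+1)*N] = phi_n(a) * H^-_c(E_i), zeros elsewhere."""
    n = len(phi_at_a)
    if not 0 <= channel_index < n_channels:
        raise IndexError("channel_index out of range")
    # Start from an all-zero complex vector of length N_c * N.
    source = [0j] * (n_channels * n)
    start = channel_index * n
    # Only the incoming channel's block is populated.
    for i, phi in enumerate(phi_at_a):
        source[start + i] = phi * h_minus
    return source
```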

Parameters:
  • solver – Compiled solver bundle. Must have been built with an energy grid (so solver.boundary is not None).

  • channel_index – Index c of the incoming channel (0-based).

  • energy_index – Index into the compile-time energy grid (0-based).

Returns:

jnp.ndarray – Source vector of shape (N_c · N,), where N = solver.mesh.n and N_c = len(solver.channels). For a solver compiled with blocks= (DESIGN.md §15.5) the per-block sources are stacked on a leading block axis — shape (N_b, N_c · N) — matching the input expected by solver.wavefunction_direct in blocks mode.

Raises:

ValueError – If solver.boundary is None (no energy grid was compiled).

Return type:

Array

Examples

>>> import lax, lax.constants as C, jax.numpy as jnp
>>> HBAR2_2MU = C.hbar2_over_2mu(1.008665, 1.008665)
>>> energies  = jnp.linspace(1.0, 10.0, 20)
>>> solver = lax.compile(
...     mesh=lax.MeshSpec("legendre", "x", n=20, scale=8.0),
...     channels=(lax.ChannelSpec(l=0, threshold=0.0, mass_factor=HBAR2_2MU),),
...     solvers=("spectrum", "wavefunction"),
...     energies=energies,
... )
>>> V = solver.nonlocal_potential(lambda r1, r2: jnp.zeros_like(r1))
>>> spec = solver.spectrum(V)
>>> src  = lax.make_wavefunction_source(solver, channel_index=0, energy_index=5)
>>> psi  = solver.wavefunction(spec, energies[5], src)

For the direct (linear-solve) path — no eigendecomposition required — compile with solvers=("rmatrix_direct",) and use:

interaction = solver.interaction_from_block(V[0, 0])  # (M, M) block
psi = solver.wavefunction_direct(interaction, src, energy_index=5)

class lax.MeshSpec(family, regularization, n, scale, extras=<factory>)

Bases: object

User-facing mesh specification passed to lax.compile().

Variables:
  • family (lax.types.MeshFamily) – Mesh family registered in lax.meshes. The current public API supports "legendre" and "laguerre".

  • regularization (lax.types.Regularization) –

    Endpoint regularization used by the chosen family. The currently supported combinations are:

    • Legendre: "x", "x(1-x)", "x^3/2"

    • Laguerre: "x", "modified_x^2"

  • n (int) – Number of mesh basis functions.

  • scale (float) – Physical length scale for the mesh. For finite-interval meshes this is the channel radius; for semi-infinite meshes it is the radial scaling factor described in DESIGN.md.

  • extras (dict[str, object]) – Mesh-specific compile-time options forwarded to the registered mesh builder.

family: MeshFamily
regularization: Regularization
n: int
scale: float
extras: dict[str, object]
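The supported family/regularization combinations can be checked up front. The dictionary and check_mesh_spec below are a hypothetical sketch that only mirrors the combinations listed above; lax performs its own validation when the mesh is resolved at compile time, and extras is not inspected here.

```python
# Hypothetical mirror of the documented family -> regularization combinations.
SUPPORTED_REGULARIZATIONS = {
    "legendre": ("x", "x(1-x)", "x^3/2"),
    "laguerre": ("x", "modified_x^2"),
}


def check_mesh_spec(family: str, regularization: str, n: int, scale: float) -> None:
    """Hypothetical pre-flight check; raises ValueError for unsupported inputs."""
    allowed = SUPPORTED_REGULARIZATIONS.get(family)
    if allowed is None:
        raise ValueError(f"unsupported mesh family: {family!r}")
    if regularization not in allowed:
        raise ValueError(f"{family!r} does not support regularization {regularization!r}")
    if n <= 0:
        raise ValueError("n must be a positive number of basis functions")
    if scale <= 0.0:
        raise ValueError("scale must be a positive length in fm")
```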

class lax.ChannelSpec(l, threshold, mass_factor)

Bases: object

One scattering channel baked into the compiled solver structure.

Variables:
  • l (int) – Orbital angular momentum for the channel.

  • threshold (float) – Channel threshold in MeV. Assembly code converts it to fm^-2 using mass_factor.

  • mass_factor (float | jax.Array) – Conversion factor ℏ²/2μ in MeV·fm². Required: there is no default, since any fixed value would be physically meaningless for an arbitrary nucleus. Use lax.constants.hbar2_over_2mu() to compute it from particle masses in AMU, e.g. lax.constants.hbar2_over_2mu(1.008665, 1.008665) ≈ 41.44 MeV·fm² for the neutron–neutron system (≈ 41.47 MeV·fm² for neutron–proton).

l: int
threshold: float
mass_factor: float | Array
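To make the unit convention concrete: dividing an energy in MeV by mass_factor (ℏ²/2μ in MeV·fm²) gives a squared wave number in fm⁻². The hypothetical helper below applies this to the channel energy E − threshold and treats a channel as open when E ≥ threshold, matching the rule stated for lax.models.open_channel_count(); it illustrates the arithmetic, not the library's assembly code.

```python
import math


def channel_wave_number(energy_mev: float, threshold_mev: float, mass_factor: float):
    """Hypothetical helper: return (k or kappa in fm^-1, is_open)."""
    # (E - threshold) / (hbar^2 / 2 mu) has units of fm^-2.
    k_squared = (energy_mev - threshold_mev) / mass_factor
    if k_squared >= 0.0:
        return math.sqrt(k_squared), True    # open channel: real wave number k
    return math.sqrt(-k_squared), False      # closed channel: decay constant kappa
```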

class lax.Solver(mesh, operators, channels, energies, boundary, transforms, method, mass_factor_grid=None, blocks=None, spectrum=None, rmatrix=None, smatrix=None, phases=None, greens=None, wavefunction=None, wavefunction_grid=None, wavefunction_sources=None, eigh=None, rmatrix_grid=None, smatrix_grid=None, phases_grid=None, rmatrix_direct=None, smatrix_direct=None, phases_direct=None, wavefunction_direct=None, wavefunction_direct_grid=None, interaction_from_block=None, interaction_from_array=None, interaction_from_funcs=None, local_potential=None, nonlocal_potential=None, to_grid_vector=None, from_grid_vector=None, to_grid_matrix=None, fourier=None, double_fourier_transform=None, integrate=None, matrix_element=None)

Bases: object

Compiled solver bundle produced by lax.compile().

Holds all compile-time caches (mesh, operators, boundary values, transform matrices) alongside JIT-compiled, pickle-safe runtime callables. Call print(solver) to see which observables were compiled.

Variables:
  • mesh (lax.types.Mesh) – Compiled mesh data (nodes, weights, radii, boundary values).

  • operators (lax.types.OperatorMatrices) – Precomputed single-channel operator matrices in fm⁻².

  • channels (tuple[lax.types.ChannelSpec, ...]) – Channel definitions baked into the solver at compile time.

  • energies (jax.Array) – Compile-time energy grid in MeV, shape (N_E,).

  • boundary (lax.spectral.types.BoundaryValues | None) – Coulomb/Whittaker boundary values at r = a, or None if no energy grid was supplied.

  • transforms (lax.types.TransformMatrices) – Precomputed radial-grid and momentum-space transform matrices.

  • method (lax.types.Method) – Linear-algebra backend: "eigh", "eig", or "linear_solve".

  • mass_factor_grid (jax.Array | None) – Per-energy (and optionally per-channel) ℏ²/2μ values in MeV·fm², as broadcast by lax.compile() to shape (N_E, N_c), or None when each channel's constant ChannelSpec.mass_factor is used. Stored here so the aligned-grid observables can use the correct μ(E) at each energy point.

  • blocks (tuple[tuple[lax.types.ChannelSpec, ...], ...] | None) – The symmetry-block set passed to lax.compile(blocks=…), or None for a channels-compiled solver (DESIGN.md §15.5). When set, channels holds the template block blocks[0], boundary carries a leading (N_b,) axis, and every observable output gains a leading block axis.

  • Spectral-path observables (present when method is "eigh" or "eig"):

  • spectrum (lax.types.SpectrumKernel | None) – (V) Spectrum — one eigendecomposition per potential.

  • rmatrix (lax.types.RMatrixObservable | None) – (spectrum, E) R(E) — R-matrix at any scalar energy.

  • smatrix (lax.types.SpectrumObservable | None) – (spectrum) S — S-matrix on the compile-time energy grid.

  • phases (lax.types.SpectrumObservable | None) – (spectrum) δ — phase shifts (N_E, N_c) in radians.

  • greens (lax.types.GreenFunctionObservable | None) – (spectrum, E) G(E) — Green’s function; requires 'greens' in solvers=.

  • wavefunction (lax.types.WavefunctionObservable | None) – (spectrum, E, source) ψ_int — internal wavefunction; requires 'wavefunction' in solvers=.

  • wavefunction_grid (lax.types.WavefunctionGridObservable | None) – (spectrum, channel_index=0) ψ — internal wavefunctions at every compile-time grid energy (both evaluation regimes); requires 'wavefunction' in solvers= and an energy grid.

  • wavefunction_sources (jax.Array | None) – Baked Descouvemont eq.-27 source stack of shape (N_E, N_c, M), or (N_b, N_E, N_c, M) in blocks mode; None when no wavefunction entry point was requested.

  • eigh (lax.types.EigenpairAccessor | None) – (spectrum) (ε, U) — raw eigenpairs; raises if eigenvectors were not retained.

  • rmatrix_grid (lax.types.SpectrumGridObservable | None) – (spectra) R — aligned-grid R for energy-dependent workflows.

  • smatrix_grid (lax.types.SpectrumGridObservable | None) – (spectra) S — aligned-grid S.

  • phases_grid (lax.types.SpectrumGridObservable | None) – (spectra) δ — aligned-grid phases.

  • Direct-path observables (present when "rmatrix_direct" is in solvers=):

  • rmatrix_direct (lax.types.DirectRMatrixKernel | None) – (V) R — per-energy linear-solve R-matrix on the compile-time grid.

  • wavefunction_direct_grid (lax.types.WavefunctionDirectGridObservable | None) – (V, channel_index=0) ψ — direct-path wavefunctions at every compile-time grid energy; bound whenever the direct path is active.

  • Transform helpers:

  • to_grid_vector (lax.types.GridVectorTransform | None) – (c) ψ(r) — mesh coefficients to fine radial grid.

  • from_grid_vector (lax.types.FromGridVectorTransform | None) – (ψ_or_fn) c — fine grid values back to mesh coefficients.

  • to_grid_matrix (lax.types.GridMatrixTransform | None) – (V) V(r, r') — mesh kernel to fine radial grid.

  • fourier (lax.types.FourierTransform | None) – (c, channel_index=0) ũ(k) — momentum-space transform.

  • double_fourier_transform (lax.types.DoubleFourierTransform | None) – (V, ...) V(p, p') — double Bessel transform for kernels.

  • integrate (lax.types.Integrator | None) – (c, operator=None) ⟨ψ|O|ψ⟩ — norms and expectation values.

  • matrix_element (lax.types.MatrixElementHelper | None) – (bra, ket, operator=None, *, conjugate) braᵀ·O·ket — two-state bilinear form, batched over block/energy axes; always bound.

Parameters:
  • mesh (Mesh)

  • operators (OperatorMatrices)

  • channels (tuple[ChannelSpec, ...])

  • energies (Array)

  • boundary (BoundaryValues | None)

  • transforms (TransformMatrices)

  • method (Method)

  • mass_factor_grid (Array | None)

  • blocks (tuple[tuple[ChannelSpec, ...], ...] | None)

  • spectrum (SpectrumKernel | None)

  • rmatrix (RMatrixObservable | None)

  • smatrix (SpectrumObservable | None)

  • phases (SpectrumObservable | None)

  • greens (GreenFunctionObservable | None)

  • wavefunction (WavefunctionObservable | None)

  • wavefunction_grid (WavefunctionGridObservable | None)

  • wavefunction_sources (Array | None)

  • eigh (EigenpairAccessor | None)

  • rmatrix_grid (SpectrumGridObservable | None)

  • smatrix_grid (SpectrumGridObservable | None)

  • phases_grid (SpectrumGridObservable | None)

  • rmatrix_direct (DirectRMatrixKernel | None)

  • smatrix_direct (SMatrixDirectObservable | None)

  • phases_direct (PhasesDirectObservable | None)

  • wavefunction_direct (WavefunctionDirectObservable | None)

  • wavefunction_direct_grid (WavefunctionDirectGridObservable | None)

  • interaction_from_block (Callable[[...], Any] | None)

  • interaction_from_array (Callable[[...], Any] | None)

  • interaction_from_funcs (Callable[[...], Any] | None)

  • local_potential (Callable[[...], Any] | None)

  • nonlocal_potential (Callable[[...], Any] | None)

  • to_grid_vector (GridVectorTransform | None)

  • from_grid_vector (FromGridVectorTransform | None)

  • to_grid_matrix (GridMatrixTransform | None)

  • fourier (FourierTransform | None)

  • double_fourier_transform (DoubleFourierTransform | None)

  • integrate (Integrator | None)

  • matrix_element (MatrixElementHelper | None)

mesh: Mesh
operators: OperatorMatrices
channels: tuple[ChannelSpec, ...]
energies: Array
boundary: BoundaryValues | None
transforms: TransformMatrices
method: Method
mass_factor_grid: Array | None = None
blocks: tuple[tuple[ChannelSpec, ...], ...] | None = None
spectrum: SpectrumKernel | None = None
rmatrix: RMatrixObservable | None = None
smatrix: SpectrumObservable | None = None
phases: SpectrumObservable | None = None
greens: GreenFunctionObservable | None = None
wavefunction: WavefunctionObservable | None = None
wavefunction_grid: WavefunctionGridObservable | None = None
wavefunction_sources: Array | None = None
eigh: EigenpairAccessor | None = None
rmatrix_grid: SpectrumGridObservable | None = None
smatrix_grid: SpectrumGridObservable | None = None
phases_grid: SpectrumGridObservable | None = None
rmatrix_direct: DirectRMatrixKernel | None = None
smatrix_direct: SMatrixDirectObservable | None = None
phases_direct: PhasesDirectObservable | None = None
wavefunction_direct: WavefunctionDirectObservable | None = None
wavefunction_direct_grid: WavefunctionDirectGridObservable | None = None
interaction_from_block: Callable[[...], Any] | None = None
interaction_from_array: Callable[[...], Any] | None = None
interaction_from_funcs: Callable[[...], Any] | None = None
local_potential: Callable[[...], Any] | None = None
nonlocal_potential: Callable[[...], Any] | None = None
to_grid_vector: GridVectorTransform | None = None
from_grid_vector: FromGridVectorTransform | None = None
to_grid_matrix: GridMatrixTransform | None = None
fourier: FourierTransform | None = None
double_fourier_transform: DoubleFourierTransform | None = None
integrate: Integrator | None = None
matrix_element: MatrixElementHelper | None = None
property grid_r: Array | None

Radial grid passed to lax.compile(), or None.

property momenta: Array | None

Momentum grid passed to lax.compile(), or None.

lax.constants

Physical constants and mass-factor utilities.

Fundamental physical constants in MeV-fm units.

All values are from the Particle Data Group (PDG) 2022 review unless noted. Lengths are in fm, energies in MeV, momenta in MeV/c.

Example usage:

import lax.constants as C

# ℏ²/2μ for neutron–proton system (MeV·fm²)
HBAR2_2MU = C.hbar2_over_2mu(1.008665, 1.007276)

# ℏ²/2μ for α + ⁴⁰Ca
HBAR2_2MU = C.hbar2_over_2mu(4.001506, 39.96259)

lax.constants.ALPHA: float = 0.007297352568431779

Fine structure constant (dimensionless).

lax.constants.AMU: float = 931.494102

Atomic mass unit in MeV/c² (PDG).

lax.constants.E2: float = 1.4399645472428273

e² in MeV·fm (Coulomb coupling constant), the exact α·ℏc ≈ 1.43996.

This is the physically correct value and the single source of truth for the Coulomb constant — reference it (e.g. from boundary.coulomb and models.optical) rather than hard-coding a literal. Some published benchmark references were generated with the conventional rounded value 1.44; those specific tests override this constant locally (see the legacy_coulomb_constant fixture in tests/conftest.py) rather than forcing the rounded value on user applications.

lax.constants.HBARC: float = 197.3269804

ℏc in MeV·fm.

lax.constants.MASS_E: float = 0.5109989461

Electron mass in MeV/c² (PDG).

lax.constants.MASS_N: float = 939.5654983938299

Neutron mass in MeV/c² (PDG).

lax.constants.MASS_P: float = 938.271653086152

Proton mass in MeV/c² (PDG).

lax.constants.WAVENUMBER_PION: float = 0.7071067811865476

Pion wavenumber mπ/(ℏc) in fm⁻¹. [Thompson & Nuñes, below eq. 4.3.10]

lax.constants.hbar2_over_2mu(m1_amu, m2_amu)

Return ℏ²/2μ in MeV·fm² for a two-body system.

Parameters:

m1_amu, m2_amu – Particle masses in atomic mass units (AMU). Use PDG masses for accuracy: neutron 1.008665, proton 1.007276, alpha 4.001506, etc.

Returns:

float – ℏ²/2μ = (ℏc)² / (2 μc²) in MeV·fm².

Return type:

float

Examples

>>> import lax.constants as C
>>> round(C.hbar2_over_2mu(1.008665, 1.008665), 3)  # n-n
41.442
>>> round(C.hbar2_over_2mu(1.008665, 1.007276), 3)  # n-p
41.47...
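The formula under Returns can be reproduced by hand: convert both masses to MeV/c² with AMU, form the reduced mass, and divide (ℏc)² by 2μc². The function below is a hypothetical re-derivation for illustration, using the constant values listed in this module; call lax.constants.hbar2_over_2mu() in real code.

```python
AMU = 931.494102       # MeV/c^2 per atomic mass unit (value listed in this module)
HBARC = 197.3269804    # MeV*fm (value listed in this module)


def hbar2_over_2mu_sketch(m1_amu: float, m2_amu: float) -> float:
    """Illustrative re-derivation of hbar^2 / 2 mu in MeV*fm^2."""
    m1 = m1_amu * AMU                 # rest energy of particle 1 in MeV
    m2 = m2_amu * AMU                 # rest energy of particle 2 in MeV
    mu = m1 * m2 / (m1 + m2)          # reduced mass mu*c^2 in MeV
    return HBARC**2 / (2.0 * mu)      # (hbar c)^2 / (2 mu c^2)
```

For identical particles μ = m/2, so the result reduces to (ℏc)²/(mc²); for two neutrons that is about 41.44 MeV·fm², slightly below the neutron–proton value of about 41.47 MeV·fm².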

lax.spectral

Mesh-independent observables operating on a Spectrum.

class lax.spectral.Spectrum(eigenvalues, surface_amplitudes, eigenvectors, is_hermitian)

Bases: object

Spectral decomposition of the Bloch-augmented Hamiltonian.

Produced by solver.spectrum(V) and consumed by all downstream observables. One Spectrum supports R-matrix, Green’s function, S-matrix, and phase-shift evaluation at arbitrary energies via spectral sums — no further linear-algebra calls are needed.

Variables:
  • eigenvalues (jax.Array) – Eigenvalues ε_k of H in fm⁻². Shape (M,) where M = N_c · N. Real-valued for Hermitian H (real potential); complex for complex-symmetric H (optical/absorptive potential).

  • surface_amplitudes (jax.Array) – Reduced-width amplitudes γ_kc = (U^T Q)_kc. Shape (M, N_c). Sufficient on their own for R-matrix and S-matrix evaluation.

  • eigenvectors (jax.Array | None) – Full eigenvector matrix U, shape (M, M). None if neither 'greens' nor 'wavefunction' was requested at compile time.

  • is_hermitian (bool) – Static flag. True when H is Hermitian (real V, method='eigh'), routing downstream code to use conjugate-transpose instead of plain transpose in spectral sums.

Parameters:
  • eigenvalues (Array)

  • surface_amplitudes (Array)

  • eigenvectors (Array | None)

  • is_hermitian (bool)

eigenvalues: Array
surface_amplitudes: Array
eigenvectors: Array | None
is_hermitian: bool

class lax.spectral.CoupledChannelParameters(phase_1, phase_2, mixing_angle)

Bases: object

Blatt-Biedenharn eigenphases and mixing angle for a 2×2 S-matrix.

Stores the bar-phase (ε-bar) parameterisation used by Blatt and Biedenharn (1952): the physical 2×2 S-matrix is diagonalised, giving two eigenphases phase_1/phase_2 and one mixing angle mixing_angle (the bar-epsilon parameter ε̄).

Variables:
  • phase_1 (jax.Array) – First eigenphase in radians (eigenvalue with smaller argument).

  • phase_2 (jax.Array) – Second eigenphase in radians (eigenvalue with larger argument).

  • mixing_angle (jax.Array) – Bar-epsilon mixing angle ε̄ in radians. Zero for an uncoupled channel.

phase_1: Array
phase_2: Array
mixing_angle: Array

lax.spectral.rmatrix_from_spectrum(spectrum, energy, channel_radius, mass_factor)

Return the Wigner-Eisenbud R-matrix. [DESIGN.md §10.2]

Parameters:
  • spectrum – Stored eigensystem of the Bloch-augmented Hamiltonian.

  • energy – Physical energy in MeV. The helper converts it to the library’s fm^-2 convention before forming the spectral denominator.

  • channel_radius – Channel radius a used in the surface-amplitude normalization.

  • mass_factor – Conversion factor ℏ²/2μ in MeV·fm².

Returns:

jax.Array – Channel-space R-matrix evaluated at energy.

Return type:

Array
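For orientation, here is a plain-Python sketch of the spectral sum behind this function under assumed conventions: γ_kc = (UᵀQ)_kc as in Spectrum.surface_amplitudes, eigenvalues in fm⁻², and R_cc′(E) = (1/a) Σ_k γ_kc γ_kc′ / (ε_k − E/mass_factor). The 1/a normalisation and both helper names are assumptions; DESIGN.md §10.2 defines the convention lax actually uses. Plain products (no conjugation) are used, which also covers the complex-symmetric case.

```python
def surface_amplitudes(U, Q):
    """gamma[k][c] = sum_n U[n][k] * Q[n][c], i.e. U^T Q (columns of U are eigenvectors)."""
    m = len(U)
    n_c = len(Q[0])
    return [[sum(U[n][k] * Q[n][c] for n in range(m)) for c in range(n_c)] for k in range(m)]


def rmatrix_sketch(eigenvalues, gamma, energy, channel_radius, mass_factor):
    """Assumed form: R = (1/a) * sum_k gamma_kc gamma_kc' / (eps_k - E/mass_factor)."""
    e_fm2 = energy / mass_factor          # MeV -> fm^-2
    n_c = len(gamma[0])
    R = [[0.0] * n_c for _ in range(n_c)]
    for eps_k, g_k in zip(eigenvalues, gamma):
        denom = eps_k - e_fm2             # one pole per eigenvalue
        for c in range(n_c):
            for cp in range(n_c):
                R[c][cp] += g_k[c] * g_k[cp] / denom
    return [[r / channel_radius for r in row] for row in R]
```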

lax.spectral.greens_from_spectrum(spectrum, energy, mass_factor)

Return the resolvent (H - E / mass_factor)^-1. [DESIGN.md §10.2]

Parameters:
  • spectrum – Stored eigensystem of the Bloch-augmented Hamiltonian.

  • energy – Physical energy in MeV.

  • mass_factor – Conversion factor ℏ²/2μ in MeV·fm².

Returns:

jax.Array – Matrix resolvent reconstructed from the stored eigenpairs.

Raises:

ValueError – If the spectrum was compiled without eigenvectors.

Return type:

Array
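A matching sketch of the eigenpair reconstruction G = U diag(1/(ε_k − E/mass_factor)) Uᵀ, shown for a real symmetric H, where Uᵀ is the inverse of U. greens_sketch is a hypothetical name; the accompanying check multiplies by (H − E/mass_factor) and recovers the identity.

```python
def greens_sketch(eigenvalues, U, energy, mass_factor):
    """Hypothetical helper: G[i][j] = sum_k U[i][k] U[j][k] / (eps_k - E/mass_factor)."""
    e_fm2 = energy / mass_factor          # MeV -> fm^-2
    m = len(U)
    return [
        [sum(U[i][k] * U[j][k] / (eigenvalues[k] - e_fm2) for k in range(m)) for j in range(m)]
        for i in range(m)
    ]
```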

lax.spectral.wavefunction_internal_from_spectrum(spectrum, energy, source, mass_factor)

Return the internal wavefunction generated by a source term.

Parameters:
  • spectrum – Stored eigensystem of the Bloch-augmented Hamiltonian.

  • energy – Physical energy in MeV.

  • source – Mesh-space source vector of shape (N_c · N,). Following Descouvemont [2] eq. 27, for a reaction driven by an incoming wave in channel c at energy E_i:

    source[c*N : (c+1)*N] = basis_at_boundary * H_minus[E_i, c]
    # all other blocks are zero
    

    Use lax.make_wavefunction_source() to build this automatically from a compiled Solver.

  • mass_factor – Conversion factor ℏ²/2μ in MeV·fm².

Returns:

jax.Array – Internal wavefunction values in the mesh basis.

Return type:

Array

lax.spectral.smatrix_from_R(R, boundary_at_energy)

Return the full channel-space S-matrix at one energy.

Applies the standard R-matrix matching formula [Descouvemont eqs. 16-17], normalised by channel wave numbers:

S = (H⁻ - R̃ H⁻') / (H⁺ - R̃ H⁺')   where  R̃ = K R K⁻¹
Parameters:
  • R – R-matrix at one energy, shape (N_c, N_c).

  • boundary_at_energy – Boundary values at the same energy. For a multi-energy solver, pass a single energy’s slice (shape (N_c,) per array field).

Returns:

jax.Array – S-matrix, shape (N_c, N_c), complex.

Return type:

Array
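In the single-channel, neutral (η = 0), l = 0 case the matching formula reduces to scalars, which gives a useful sanity check. This sketch assumes H^± = cos ρ ± i sin ρ with ρ = ka and primes meaning d/dρ, so the R-matrix enters multiplied by ρ. That scalar normalisation is an assumption standing in for the library's multichannel K R K⁻¹ form, and the helper names are hypothetical.

```python
import cmath


def smatrix_single_channel(R: float, k: float, a: float) -> complex:
    """Assumed scalar form S = (H- - rho R H-') / (H+ - rho R H+') for eta = 0, l = 0."""
    rho = k * a
    h_plus = cmath.exp(1j * rho)       # H+ = G + iF = cos(rho) + i sin(rho)
    h_minus = cmath.exp(-1j * rho)     # H- = G - iF
    h_plus_p = 1j * h_plus             # dH+/drho
    h_minus_p = -1j * h_minus          # dH-/drho
    return (h_minus - rho * R * h_minus_p) / (h_plus - rho * R * h_plus_p)


def phase_shift(S: complex) -> float:
    """delta = arg(S) / 2, as in phases_from_S for a 1x1 S-matrix."""
    return 0.5 * cmath.phase(S)
```

Two limits check the sketch: a free particle, with u(r) = sin kr and therefore R = tan(ka)/(ka), must give S = 1, and a hard sphere (R = 0) must give δ = −ka.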

lax.spectral.open_channel_smatrix_from_R(R, boundary_at_energy)

Return the physical open-channel S-matrix at one energy.

Parameters:
  • R – Full channel-space R-matrix at one energy.

  • boundary_at_energy – Boundary values for the same energy. Closed channels may be included.

Returns:

jax.Array – Physical S-matrix restricted to the open-channel subspace.

Return type:

Array

lax.spectral.phases_from_S(S)

Return channel phase shifts from the eigenphases of S.

Computes δ_k = ½ arg(λ_k) where λ_k are the eigenvalues of S. For a diagonal (uncoupled) S-matrix these coincide with the standard single-channel phase shifts.

Parameters:

S – S-matrix at one energy, shape (N_c, N_c).

Returns:

jax.Array – Phase shifts in radians, shape (N_c,). For a single-channel solver index with [0] to get a scalar.

Parameters:

S (Array)

Return type:

Array
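For a 2×2 S-matrix the eigenvalues follow from the characteristic polynomial, so δ_k = ½ arg(λ_k) can be sketched without a linear-algebra library. eigenphases_2x2 is a hypothetical helper. It does not reproduce whatever ordering or branch choices lax applies to the eigenphases, so compare sorted values.

```python
import cmath


def eigenphases_2x2(S):
    """Return (arg(lam_a) / 2, arg(lam_b) / 2) for the two eigenvalues of a 2x2 matrix S."""
    (s11, s12), (s21, s22) = S
    trace = s11 + s22
    det = s11 * s22 - s12 * s21
    disc = cmath.sqrt(trace * trace - 4 * det)   # roots of lam^2 - trace*lam + det = 0
    lam_a = (trace + disc) / 2
    lam_b = (trace - disc) / 2
    return 0.5 * cmath.phase(lam_a), 0.5 * cmath.phase(lam_b)
```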

lax.spectral.coupled_channel_parameters_from_S(S)

Return Blatt-Biedenharn eigenphases and mixing angle from a 2×2 S-matrix.

Diagonalises the symmetric 2×2 collision matrix and extracts the bar-phase parameterisation [Blatt & Biedenharn 1952].

Parameters:

S – Symmetric 2×2 S-matrix at one energy.

Returns:

CoupledChannelParameters – Eigenphases and mixing angle.

Raises:

ValueError – If S.shape != (2, 2).

Parameters:

S (Array)

Return type:

CoupledChannelParameters
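The sketch below illustrates the eigenphase form implied by "diagonalises the symmetric 2×2 collision matrix". It builds S = U(ε) diag(e^{2iδ₁}, e^{2iδ₂}) U(ε)ᵀ with a real rotation U(ε), then recovers ε from tan 2ε = 2S₁₂/(S₁₁ − S₂₂). The rotation sign, the eigenphase ordering, and both helper names are assumptions made for illustration, not conventions taken from lax.

```python
import cmath
import math


def build_coupled_smatrix(delta1: float, delta2: float, eps: float):
    """S = U diag(e^{2i d1}, e^{2i d2}) U^T with U = [[cos e, -sin e], [sin e, cos e]]."""
    c, s = math.cos(eps), math.sin(eps)
    e1, e2 = cmath.exp(2j * delta1), cmath.exp(2j * delta2)
    s11 = c * c * e1 + s * s * e2
    s22 = s * s * e1 + c * c * e2
    s12 = c * s * (e1 - e2)            # symmetric off-diagonal element
    return [[s11, s12], [s12, s22]]


def mixing_angle_from_smatrix(S) -> float:
    """Recover eps from tan(2 eps) = 2 S12 / (S11 - S22); assumes |eps| < pi/4, S11 != S22."""
    ratio = 2 * S[0][1] / (S[0][0] - S[1][1])
    return 0.5 * math.atan(ratio.real)
```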

lax.models

Reusable interaction models and preset system parameters.

class lax.models.RotorCoupledOpticalModel(mass_factor, potential_radius, coupling_radius, coulomb_radius, diffuseness, real_depth, imaginary_depth, deformation, multipole, total_angular_momentum, projectile_charge, target_charge, channels)

Bases: object

Parameters for a rotor-coupled optical potential model.

Parameters:
  • mass_factor (float) – Conversion factor ℏ²/2μ in MeV·fm².

  • potential_radius (float) – Woods-Saxon radius used for the diagonal optical potential.

  • coupling_radius (float) – Radius that multiplies the derivative coupling term.

  • coulomb_radius (float) – Radius of the uniformly charged sphere used for the Coulomb term.

  • diffuseness (float) – Woods-Saxon diffuseness in fm.

  • real_depth (float) – Real optical depth in MeV.

  • imaginary_depth (float) – Imaginary optical depth in MeV.

  • deformation (float) – Rotor deformation parameter, typically written β.

  • multipole (int) – Coupling multipole λ.

  • total_angular_momentum (int) – Total coupled angular momentum J.

  • projectile_charge (int) – Projectile charge Z_1.

  • target_charge (int) – Target charge Z_2.

  • channels (tuple[lax.models.optical.RotorChannel, …]) – Channel definitions included in the coupled model.

mass_factor: float
potential_radius: float
coupling_radius: float
coulomb_radius: float
diffuseness: float
real_depth: float
imaginary_depth: float
deformation: float
multipole: int
total_angular_momentum: int
projectile_charge: int
target_charge: int
channels: tuple[RotorChannel, ...]

class lax.models.RotorChannel(orbital_angular_momentum, target_spin, threshold, label=None)

Bases: object

One channel in a rotor-coupled optical model.

Parameters:
  • orbital_angular_momentum (int) – Relative orbital angular momentum L for the channel.

  • target_spin (int) – Rotor (target-state) spin coupled to the projectile.

  • threshold (float) – Channel threshold in MeV.

  • label (str | None) – Optional human-readable label for plots or tables.

Parameters:
  • orbital_angular_momentum (int)

  • target_spin (int)

  • threshold (float)

  • label (str | None)

orbital_angular_momentum: int
target_spin: int
threshold: float
label: str | None = None

lax.models.channels_from_rotor_model(model)

Return ChannelSpec objects for a rotor model.

Parameters:

model – Rotor-coupled optical model definition.

Returns:

tuple[ChannelSpec, …] – Channel layout ready to pass to lax.compile().

Parameters:

model (RotorCoupledOpticalModel)

Return type:

tuple[ChannelSpec, …]

lax.models.rotor_coupled_optical_potential(model, radii, channel_index, coupled_index)

Return one matrix element of a rotor-coupled optical potential in MeV.

Parameters:
  • model – Rotor-coupled optical model definition.

  • radii – Radial grid in fm.

  • channel_index – Bra-channel index.

  • coupled_index – Ket-channel index.

Returns:

jax.Array – Complex local potential evaluated on radii.

Parameters:
Return type:

Array

lax.models.interaction_from_rotor_model(model, solver)

Build an Interaction for a rotor-coupled optical model.

Decomposes the potential into three local terms via the §6.1 term-decomposition pattern and assembles them through interaction_from_funcs():

  • Nuclear diagonal: -complex_depth · WS(r) on the diagonal channels.

  • Coulomb diagonal: Coulomb(r) on the diagonal channels.

  • Derivative coupling: -complex_depth · β · R_c · dWS/dr scaled by the angular coupling matrix A[c, c'] = rotor_coupling_coefficient(c, c').

Parameters:
  • model – Rotor-coupled optical model definition.

  • solver – Compiled solver whose interaction_from_funcs() entry point is used to assemble the potential block.

Returns:

Interaction – Energy-independent assembled potential block ready for solver.spectrum or solver.rmatrix_direct.

Return type:

Interaction

lax.models.open_channel_count(model, energy)

Return the number of channels open at one center-of-mass energy.

Parameters:
  • model – Rotor-coupled optical model definition.

  • energy – Center-of-mass energy in MeV.

Returns:

int – Number of channels with threshold below or equal to energy.

Return type:

int
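The counting rule under Returns amounts to one comparison per channel. The sketch assumes the thresholds are available as plain floats (for a RotorCoupledOpticalModel they live on model.channels[i].threshold); open_channel_count_sketch is a hypothetical name.

```python
def open_channel_count_sketch(thresholds, energy):
    """Count channels whose threshold (MeV) is at or below the c.m. energy (MeV)."""
    return sum(1 for threshold in thresholds if threshold <= energy)
```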

lax.models.first_column_amplitudes_and_phases(smatrix, open_count)

Return amplitudes and half-angles for the first open-channel column.

Parameters:
  • smatrix – Open-channel collision matrix.

  • open_count – Number of open channels represented in the matrix.

Returns:

tuple[np.ndarray, np.ndarray] – Absolute values and half-angles of the first column.

Return type:

tuple[ndarray, ndarray]
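One reading of "absolute values and half-angles of the first column" is |S_i0| and ½ arg S_i0 for each open channel i, mirroring δ = ½ arg S. The sketch below implements that reading with the standard library and returns lists instead of NumPy arrays. The helper name and the use of the principal branch of arg are assumptions.

```python
import cmath


def first_column_sketch(smatrix, open_count):
    """Return ([|S[i][0]|], [arg(S[i][0]) / 2]) for the first open_count rows."""
    column = [smatrix[i][0] for i in range(open_count)]
    amplitudes = [abs(z) for z in column]
    half_angles = [0.5 * cmath.phase(z) for z in column]   # principal branch of arg
    return amplitudes, half_angles
```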

lax.models.woods_saxon_form_factor(radii, radius, diffuseness)

Return the Woods-Saxon form factor.

Parameters:
  • radii – Radial grid in fm.

  • radius – Woods-Saxon radius in fm.

  • diffuseness – Woods-Saxon diffuseness in fm.

Returns:

jax.Array – Dimensionless Woods-Saxon profile.

Return type:

Array

lax.models.woods_saxon_derivative(radii, radius, diffuseness)

Return the positive radial derivative factor used in rotor coupling.

Parameters:
  • radii – Radial grid in fm.

  • radius – Woods-Saxon radius in fm.

  • diffuseness – Woods-Saxon diffuseness in fm.

Returns:

jax.Array – Derivative factor entering the deformation coupling term.

Return type:

Array
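The sketch below assumes the conventional Woods-Saxon shape f(r) = 1/(1 + exp((r − R)/a)) and computes the form factor together with the positive derivative factor −df/dr = f(1 − f)/a. The exact normalisation of lax.models.woods_saxon_derivative is not stated above, so treating it as −df/dr is an assumption, and both helpers work on scalar radii only.

```python
import math


def woods_saxon(r: float, radius: float, diffuseness: float) -> float:
    """f(r) = 1 / (1 + exp((r - R) / a)); dimensionless, between 0 and 1."""
    return 1.0 / (1.0 + math.exp((r - radius) / diffuseness))


def woods_saxon_minus_derivative(r: float, radius: float, diffuseness: float) -> float:
    """-df/dr = f(r) * (1 - f(r)) / a, which is non-negative everywhere."""
    f = woods_saxon(r, radius, diffuseness)
    return f * (1.0 - f) / diffuseness
```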

lax.models.uniform_sphere_coulomb_potential(radii, radius, projectile_charge, target_charge)

Return the uniformly charged-sphere Coulomb potential in MeV.

Parameters:
  • radii – Radial grid in fm.

  • radius – Sphere radius in fm.

  • projectile_charge – Projectile charge Z_1.

  • target_charge – Target charge Z_2.

Returns:

jax.Array – Coulomb potential in MeV.

Return type:

Array
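The standard uniformly-charged-sphere form is V(r) = Z₁Z₂e²(3 − r²/R²)/(2R) inside the sphere and Z₁Z₂e²/r outside. The plain-Python sketch below takes e² from lax.constants.E2 (copied as a literal); the helper name is hypothetical and it handles one radius at a time.

```python
E2 = 1.4399645472428273  # e^2 in MeV*fm, the value of lax.constants.E2


def uniform_sphere_coulomb(r: float, radius: float, z1: int, z2: int) -> float:
    """Coulomb potential (MeV) between charge z1 and a uniform sphere of charge z2."""
    zz_e2 = z1 * z2 * E2
    if r < radius:
        # Inside the sphere: harmonic form, finite at the origin.
        return zz_e2 * (3.0 - (r / radius) ** 2) / (2.0 * radius)
    # Outside the sphere: point-charge Coulomb tail.
    return zz_e2 / r
```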

lax.models.rotor_coupling_coefficient(model, channel_index, coupled_index)

Return the angular coupling coefficient for one channel pair.

Parameters:
  • model – Rotor-coupled optical model definition.

  • channel_index – Bra-channel index.

  • coupled_index – Ket-channel index.

Returns:

float – Angular coupling coefficient multiplying the derivative term.

Return type:

float

lax.models.reid_np_j1_channels()

Return the coupled ³S₁–³D₁ channel pair for J = 1.

Returns:

tuple[ChannelSpec, …] – Two-channel layout for the standard Reid soft-core triplet example.

Return type:

tuple[ChannelSpec, …]

lax.models.interaction_from_reid_np_j1(solver)[source]

Build an Interaction for the coupled Reid soft-core n-p model.

Decomposes the J=1 triplet potential into its three physical terms via the §6.1 term-decomposition pattern and assembles them through interaction_from_funcs():

  • Central: v_central(r) on the diagonal channels.

  • Tensor: v_tensor(r) scaled by [[0, 2√2], [2√2, -2]].

  • Spin-orbit: v_spin_orbit(r) scaled by [[0, 0], [0, -3]].
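The three terms combine as V_cc'(r) = v_central(r) δ_cc' + v_tensor(r) ⟨S_12⟩_cc' + v_spin_orbit(r) ⟨L·S⟩_cc'. The L·S entries follow from [J(J+1) − L(L+1) − S(S+1)]/2, which gives 0 for L = 0 and −3 for L = 2 at J = S = 1. The sketch below builds the (n, 2, 2) block directly with NumPy instead of going through interaction_from_funcs(), whose callback signature is not reproduced here. The name assemble_coupled_block and the radial values are placeholders, not the Reid parameterization.

```python
import numpy as np

# Spin-angle matrices in the (3S1, 3D1) basis, as listed above.
TENSOR = np.array([[0.0, 2.0 * np.sqrt(2.0)], [2.0 * np.sqrt(2.0), -2.0]])
SPIN_ORBIT = np.diag([0.0, -3.0])


def assemble_coupled_block(v_central, v_tensor, v_spin_orbit):
    """Return V(r_i) as an (n, 2, 2) array from three (n,) radial terms in MeV."""
    vc = np.asarray(v_central, dtype=float)[:, None, None]
    vt = np.asarray(v_tensor, dtype=float)[:, None, None]
    vls = np.asarray(v_spin_orbit, dtype=float)[:, None, None]
    return vc * np.eye(2) + vt * TENSOR + vls * SPIN_ORBIT


# Placeholder radial values (not the Reid potential) on a 3-point grid.
vc = np.array([-50.0, -20.0, -5.0])
vt = np.array([-10.0, -4.0, -1.0])
vls = np.array([2.0, 1.0, 0.5])
V = assemble_coupled_block(vc, vt, vls)
```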

Parameters:

solver – Compiled two-channel solver (see reid_np_j1_channels()) whose interaction_from_funcs() entry point is used to assemble the potential block.

Returns:

Interaction – Energy-independent assembled potential block ready for solver.spectrum or solver.rmatrix_direct.

Parameters:

solver (Solver)

Return type:

Interaction

lax.models.reid_soft_core_triplet_components(radii)[source]

Return the Reid soft-core triplet components in MeV.

Parameters:

radii – Radial grid in fm.

Returns:

tuple[jax.Array, jax.Array, jax.Array] – Central, tensor, and spin-orbit terms in MeV.

Parameters:

radii (Array)

Return type:

tuple[Array, Array, Array]

Internal modules

The modules below are part of the public API but are not typically imported directly — they are accessed through compile() and the Solver bundle.

lax.boundary

class lax.boundary.BoundaryValues(H_plus, H_minus, H_plus_p, H_minus_p, is_open, k)[source]

Bases: object

Coulomb and Whittaker boundary values at the channel radius.

Precomputed at compile time with mpmath for every (energy, channel) pair. Open channels use Coulomb Hankel functions; closed channels use Whittaker functions, which decay exponentially with increasing r.

Variables:
  • H_plus (jax.Array) – Outgoing Coulomb Hankel function H⁺ = G + iF at r = a, shape (N_E, N_c), complex.

  • H_minus (jax.Array) – Incoming Coulomb Hankel function H⁻ = G - iF at r = a, shape (N_E, N_c), complex.

  • H_plus_p (jax.Array) – ρ · d/dρ H⁺ evaluated at ρ = ka, shape (N_E, N_c), complex.

  • H_minus_p (jax.Array) – ρ · d/dρ H⁻ evaluated at ρ = ka, shape (N_E, N_c), complex.

  • is_open (jax.Array) – Boolean mask: True for open channels (E > E_threshold), shape (N_E, N_c).

  • k (jax.Array) – Channel wave numbers k_c(E) in fm⁻¹, shape (N_E, N_c).


Notes

For a solver compiled with a symmetry-block set (lax.compile(blocks=…), DESIGN.md §15.5) every field carries a leading (N_b,) axis — shape (N_b, N_E, N_c) — stacked per block at compile time.

H_plus: Array
H_minus: Array
H_plus_p: Array
H_minus_p: Array
is_open: Array
k: Array

lax.boundary.compute_boundary_values(channels, energies, channel_radius, z1z2=None, dps=40, mass_factor_grid=None)[source]

Compute Coulomb and Whittaker boundary values at the channel radius.

Evaluates the Coulomb Hankel functions H^±_L = G_L ± iF_L and their ρ·d/dρ derivatives at ρ = k·a for every (energy, channel) pair, using mpmath with dps decimal digits of precision. Closed channels use the Whittaker function W_{-η, ℓ+1/2}(2|k|a) instead.

This function runs at compile time (pure Python/NumPy) and is never traced by JAX.

Parameters:
  • channels – Channel definitions specifying l, threshold, and mass_factor for each channel.

  • energies – Compile-time energy grid in MeV, shape (N_E,).

  • channel_radius – Channel radius a in fm.

  • z1z2 – Charge product (Z_1, Z_2) for Coulomb scattering. Pass None for neutral particles (η = 0).

  • dps – mpmath decimal precision. The default of 40 provides ample guard digits against cancellation near resonances.

  • mass_factor_grid – Per-energy, per-channel ℏ²/2μ values in MeV·fm², shape (N_E, N_c). When provided, mass_factor_grid[ie, ic] overrides channels[ic].mass_factor in the wave-number and Sommerfeld-parameter computation for energy index ie. Pass None to use the scalar ChannelSpec.mass_factor uniformly.

Returns:

BoundaryValues – Boundary values for all (N_E, N_c) pairs. Open-channel entries use Coulomb Hankel functions; closed-channel entries use Whittaker functions and the is_open mask is False.

Return type:

BoundaryValues
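For η = 0 and ℓ = 0 the Coulomb functions reduce to F_0 = sin ρ and G_0 = cos ρ, so H^±_0 = e^{±iρ}. The sketch below uses that special case to produce open-channel values laid out like the BoundaryValues fields, which is handy for cross-checking the mpmath path. The helper neutral_s_wave_boundary_values is hypothetical and handles open channels only. The mass factor of 41.47 MeV·fm² is approximately ℏ²/2μ for n-p. The Wronskian check H⁻(ρH⁺') − H⁺(ρH⁻') = 2iρ follows from G F' − F G' = 1 and holds for any ℓ and η.

```python
import numpy as np


def neutral_s_wave_boundary_values(energies, threshold, mass_factor, channel_radius):
    """Closed-form eta = 0, l = 0 boundary values for one open channel."""
    e = np.asarray(energies, dtype=float) - threshold
    if np.any(e <= 0.0):
        raise ValueError("sketch handles open channels only (E > threshold)")
    k = np.sqrt(e / mass_factor)       # wave number in fm^-1
    rho = k * channel_radius
    h_plus = np.exp(1j * rho)          # G_0 + i F_0 = cos(rho) + i sin(rho)
    h_minus = np.exp(-1j * rho)        # G_0 - i F_0
    h_plus_p = 1j * rho * h_plus       # rho * dH+/drho
    h_minus_p = -1j * rho * h_minus    # rho * dH-/drho
    return h_plus, h_minus, h_plus_p, h_minus_p, k


energies = np.array([1.0, 5.0, 10.0])  # MeV
hp, hm, hpp, hmp, k = neutral_s_wave_boundary_values(
    energies, threshold=0.0, mass_factor=41.47, channel_radius=10.0
)
rho = k * 10.0
# Wronskian check, valid for any l and eta: H- (rho H+') - H+ (rho H-') = 2 i rho
wronskian_ok = np.allclose(hm * hpp - hp * hmp, 2j * rho)
```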

lax.types

class lax.types.Mesh(family, regularization, n, scale, n_intervals, basis_size_per_interval, nodes, weights, radii, basis_at_boundary, propagation=None)[source]

Bases: object

Concrete mesh data cached inside a compiled solver.

Produced by the mesh registry and embedded in the Solver at compile time. Static fields are baked into the JAX JIT cache key; changing them requires recompilation.

Variables:
  • family (lax.types.MeshFamily) – Mesh family, e.g. "legendre" or "laguerre" (static).

  • regularization (lax.types.Regularization) – Endpoint regularization, e.g. "x" or "x(1-x)" (static).

  • n (int) – Number of basis functions per channel (static).

  • scale (float) – Physical scale: channel radius a in fm for finite-interval meshes, or the Laguerre scaling factor h in fm (static).

  • n_intervals (int) – Number of subintervals for propagated meshes; 1 otherwise (static).

  • basis_size_per_interval (int) – Basis functions per subinterval; equals n when n_intervals == 1 (static).

  • nodes (jax.Array) – Canonical mesh nodes on (0, 1) or (0, ∞), shape (n,).

  • weights (jax.Array) – Gauss quadrature weights λ_i, shape (n,).

  • radii (jax.Array) – Physical radial mesh points r_i = scale · x_i, shape (n,).

  • basis_at_boundary (jax.Array) – Lagrange basis values φ_j(a) at the channel surface, shape (n,). All zeros for semi-infinite (Laguerre) meshes.

  • propagation (lax.types.PropagationMatrices | None) – Subinterval propagation matrices, or None for single-interval meshes.

Parameters:
  • family (MeshFamily)

  • regularization (Regularization)

  • n (int)

  • scale (float)

  • n_intervals (int)

  • basis_size_per_interval (int)

  • nodes (Array)

  • weights (Array)

  • radii (Array)

  • basis_at_boundary (Array)

  • propagation (PropagationMatrices | None)

family: MeshFamily
regularization: Regularization
n: int
scale: float
n_intervals: int
basis_size_per_interval: int
nodes: Array
weights: Array
radii: Array
basis_at_boundary: Array
propagation: PropagationMatrices | None = None
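For a finite-interval Legendre mesh, the nodes, weights and radii fields can be generated from a Gauss-Legendre rule shifted to (0, 1). The sketch assumes weights that sum to 1 and ascending node order. The name legendre_mesh_arrays is hypothetical. A Laguerre mesh would take its nodes from np.polynomial.laguerre.laggauss and scale them by h instead.

```python
import numpy as np


def legendre_mesh_arrays(n, scale):
    t, w = np.polynomial.legendre.leggauss(n)  # Gauss-Legendre on (-1, 1)
    nodes = 0.5 * (t + 1.0)                    # x_i on (0, 1), ascending
    weights = 0.5 * w                          # lambda_i, summing to 1
    radii = scale * nodes                      # r_i = a * x_i in fm
    return nodes, weights, radii


nodes, weights, radii = legendre_mesh_arrays(n=5, scale=10.0)
# An n-point Gauss rule is exact for polynomials up to degree 2n - 1,
# so sum(lambda_i * x_i**9) reproduces the integral of x**9 over (0, 1), i.e. 0.1.
quadrature_x9 = np.sum(weights * nodes**9)
```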

class lax.types.OperatorMatrices(T=None, TpL=None, T_alpha=None, D=None, inv_r=None, inv_r2=None)[source]

Bases: object

Precomputed single-channel operator matrices in fm⁻² units.

All populated fields are (N, N) symmetric real matrices in the Lagrange-mesh basis. Unrequested operators are None.

Variables:
  • T (jax.Array | None) – Kinetic-energy matrix -d²/dr² (Laguerre meshes, no Bloch term).

  • TpL (jax.Array | None) – Bloch-augmented kinetic matrix T + L(B=0). This is the standard operator for R-matrix calculations on finite-interval (Legendre) meshes.

  • T_alpha (jax.Array | None) – Hyperradial kinetic matrix for α-type coordinates (Laguerre meshes with three-body regularization).

  • D (jax.Array | None) – First-derivative matrix d/dr.

  • inv_r (jax.Array | None) – Diagonal 1/r matrix.

  • inv_r2 (jax.Array | None) – Diagonal 1/r² matrix.

T: Array | None = None
TpL: Array | None = None
T_alpha: Array | None = None
D: Array | None = None
inv_r: Array | None = None
inv_r2: Array | None = None
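Together with a diagonal potential V(r_i), TpL and Mesh.basis_at_boundary are the ingredients of a single-channel R-matrix. In units where ℏ²/2μ = 1, R(E) = (1/a) Σ_ij φ_i(a) [(T + L + V − E)^{-1}]_ij φ_j(a). The sketch below builds T + L(B=0) on an x-regularized Legendre mesh from the closed-form matrix elements given in the Lagrange-mesh literature (Baye; Descouvemont and Baye). It then checks the free s-wave against the exact R = tan(ka)/(ka). lax's own construction, node ordering and sign convention may differ, and both function names are hypothetical.

```python
import numpy as np


def legendre_x_tpl(n, a):
    """T + L(B=0) (fm^-2) and phi_j(a) on an x-regularized Legendre mesh on (0, a)."""
    t, _ = np.polynomial.legendre.leggauss(n)
    x = 0.5 * (t + 1.0)  # nodes on (0, 1)
    xi, xj = x[:, None], x[None, :]
    idx = np.arange(n)
    sign = (-1.0) ** (idx[:, None] + idx[None, :])
    with np.errstate(divide="ignore", invalid="ignore"):
        tpl = sign / (a**2 * np.sqrt(xi * xj * (1 - xi) * (1 - xj))) * (
            n * n + n + 1.0
            + (xi + xj - 2.0 * xi * xj) / (xi - xj) ** 2
            - 1.0 / (1.0 - xi)
            - 1.0 / (1.0 - xj)
        )
    # Diagonal elements (the off-diagonal formula is singular at i == j).
    diag = ((4.0 * n * n + 4.0 * n + 3.0) * x * (1 - x) - 6.0 * x + 1.0) / (
        3.0 * a**2 * x**2 * (1 - x) ** 2
    )
    np.fill_diagonal(tpl, diag)
    phi_a = (-1.0) ** (n + idx + 1) / np.sqrt(a * x * (1 - x))  # phi_j(a)
    return tpl, phi_a


def free_s_wave_rmatrix(n, a, k):
    """R(E) for V = 0, l = 0, in units where hbar^2/2mu = 1 (so E = k^2)."""
    tpl, phi_a = legendre_x_tpl(n, a)
    c = tpl - k**2 * np.eye(n)  # C = T + L - E
    return phi_a @ np.linalg.solve(c, phi_a) / a


a, k = 10.0, 0.5
r_mesh = free_s_wave_rmatrix(n=30, a=a, k=k)
r_exact = np.tan(k * a) / (k * a)  # u = sin(kr), so R = u(a) / (a u'(a))
```

At E = 0 the exact solution u = r lies inside the basis, so the mesh result should equal 1 to rounding for any n. That gives a quick consistency test of the matrix elements.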

lax.meshes

lax.meshes.build_mesh(family, regularization, n, scale, operators, **extras)[source]

Dispatch to the concrete mesh builder for the requested mesh kind.

Parameters:
  • family – Mesh family identifier.

  • regularization – Regularization identifier.

  • n – Number of mesh basis functions.

  • scale – Physical scale (channel radius or Laguerre scale) in fm.

  • operators – Set of operator names to precompute, e.g. {"T+L", "1/r"}.

  • **extras – Additional keyword arguments forwarded to the concrete builder (e.g. n_intervals for propagated meshes).

Returns:

tuple[Mesh, OperatorMatrices] – Compiled mesh data and precomputed operator matrices.

Raises:

ValueError – If no builder is registered for (family, regularization).

Parameters:
  • family (MeshFamily)

  • regularization (Regularization)

  • n (int)

  • scale (float)

  • operators (set[str])

  • extras (object)

Return type:

tuple[Mesh, OperatorMatrices]
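The dispatch described above amounts to a lookup keyed on (family, regularization) that forwards **extras to the selected builder. The sketch below shows that pattern with a plain dict. The registry, the register decorator and the toy builder are hypothetical stand-ins, not lax internals.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry keyed on (family, regularization).
_BUILDERS: Dict[Tuple[str, str], Callable[..., object]] = {}


def register(family: str, regularization: str):
    def decorator(fn):
        _BUILDERS[(family, regularization)] = fn
        return fn
    return decorator


def build_mesh_sketch(family, regularization, n, scale, operators, **extras):
    try:
        builder = _BUILDERS[(family, regularization)]
    except KeyError:
        raise ValueError(
            f"no mesh builder registered for ({family!r}, {regularization!r})"
        ) from None
    return builder(n=n, scale=scale, operators=set(operators), **extras)


@register("legendre", "x")
def _toy_legendre_x(*, n, scale, operators, n_intervals=1):
    # Stand-in builder: returns a description instead of real Mesh data.
    return {"n": n, "scale": scale, "operators": operators, "n_intervals": n_intervals}


mesh = build_mesh_sketch("legendre", "x", 20, 10.0, {"T+L"}, n_intervals=4)
```

Keying on the pair keeps the ValueError path explicit, because any unregistered combination fails before a builder is called.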

lax.propagate

lax.propagate.build_legendre_x_propagation(*, basis_size_per_interval, n_intervals, scale)[source]

Return the precomputed matrices used by Descouvemont-style R-matrix propagation.

Divides the internal region [0, a] into n_intervals equal subintervals of width a / n_intervals. For each subinterval the per-interval kinetic matrix and Bloch surface-overlap matrices are built using the shifted Legendre-x formulae from Descouvemont. The resulting PropagationMatrices object is stored inside the Mesh and consumed by _propagated_rmatrix_at_energy at runtime.

Parameters:
  • basis_size_per_interval – Number of Legendre basis functions per subinterval.

  • n_intervals – Number of subintervals to divide the internal region into.

  • scale – Total channel radius a in fm.

Returns:

PropagationMatrices – All precomputed kinetic and boundary-overlap matrices for the propagation recursion.

Parameters:
  • basis_size_per_interval (int)

  • n_intervals (int)

  • scale (float)

Return type:

PropagationMatrices
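The subinterval layout described above can be sketched as follows. The sketch assumes each subinterval carries the same shifted Gauss-Legendre nodes on (0, 1), mapped onto [k·a/n_intervals, (k+1)·a/n_intervals]. The name subinterval_layout is hypothetical, and the per-interval kinetic and Bloch overlap matrices are not reproduced.

```python
import numpy as np


def subinterval_layout(basis_size_per_interval, n_intervals, scale):
    """Subinterval edges (n_intervals + 1,) and mesh radii (n_intervals, m) in fm."""
    width = scale / n_intervals
    edges = np.linspace(0.0, scale, n_intervals + 1)
    t, _ = np.polynomial.legendre.leggauss(basis_size_per_interval)
    x = 0.5 * (t + 1.0)                       # shifted nodes on (0, 1)
    radii = edges[:-1, None] + width * x[None, :]
    return edges, radii


edges, radii = subinterval_layout(basis_size_per_interval=6, n_intervals=4, scale=12.0)
```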