
jax_kalman_filter

Module implementing a JAX-based Kalman filter.

JaxKalmanFilter

JaxKalmanFilter(data: dict, use_gw: bool = True)

A class to implement the linear Kalman filter on scalar inputs using JAX.

Parameters:

  • df_psr

    DataFrame containing pulsar information, including:
      - dim_M: integer, number of design parameters for that pulsar
      - F0: pulsar spin frequency

  • observations

    Dictionary containing 'toas', 'residuals', and 'errors' arrays from the data loader

  • Peps

    The uncertainty matrix for the epsilon states

  • hd_correlation_matrix

    Precomputed Hellings-Downs correlation matrix

  • pulsar_design_matrices

    Design matrices for each pulsar

  • use_gw (bool, default: True) –

    If True, include GW terms in the measurement equation.
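For reference, a Hellings-Downs matrix like `hd_correlation_matrix` could be built from pulsar sky positions roughly as follows. This is a minimal sketch: the helper name `hellings_downs_matrix`, the (N, 3) unit-vector input, and the normalization (diagonal equal to 1, following the common convention that adds a pulsar-term 1/2 to the autocorrelation) are assumptions, not necessarily what this module precomputes.

```python
import jax.numpy as jnp


def hellings_downs_matrix(unit_vectors: jnp.ndarray) -> jnp.ndarray:
    """Hellings-Downs correlation matrix for pulsars at the given sky positions.

    Hypothetical helper: `unit_vectors` is an (N, 3) array of unit vectors
    pointing at each pulsar. Off-diagonal entries follow the HD curve; the
    diagonal is 1 (assumed convention: pulsar term of 1/2 included).
    """
    # Cosine of the angular separation between every pair of pulsars
    cos_zeta = jnp.clip(unit_vectors @ unit_vectors.T, -1.0, 1.0)
    x = (1.0 - cos_zeta) / 2.0
    # x * log(x) -> 0 as x -> 0; guard the log so the diagonal stays finite
    x_log_x = jnp.where(x > 0, x * jnp.log(jnp.where(x > 0, x, 1.0)), 0.0)
    gamma = 1.5 * x_log_x - 0.25 * x + 0.5
    # Autocorrelation: add the pulsar-term contribution of 1/2 on the diagonal
    n = unit_vectors.shape[0]
    return gamma + 0.5 * jnp.eye(n)


# Three mutually orthogonal pulsars (90 degree separations)
hd = hellings_downs_matrix(jnp.eye(3))
```

For 90 degree separations the off-diagonal entries equal 1.5 * 0.5 * ln(0.5) - 0.125 + 0.5, roughly -0.145.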

get_likelihood

get_likelihood(θ)

Run the Kalman filter over all observations and return the log-likelihood for parameters θ.
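A minimal sketch of how a Kalman filter turns a sequence of observations into a log-likelihood: each step predicts, forms the innovation and its variance S, updates, and adds the Gaussian log-density of the innovation. It assumes a constant time step (so F and Q are fixed), a scalar measurement y = H x + noise with variance R per observation, and uses `jax.lax.scan` for the loop. `kalman_log_likelihood` and its arguments are hypothetical and do not match the signature of `get_likelihood`.

```python
import jax
import jax.numpy as jnp


def kalman_log_likelihood(F, Q, H, x0, P0, ys, Rs):
    """Log-likelihood of scalar observations under a linear Gaussian model.

    Hypothetical helper: F and Q are (n, n); H is an (n,) measurement row;
    ys and Rs are (T,) arrays of observations and their noise variances.
    Assumes a constant time step, so F and Q do not change between steps.
    """

    def step(carry, obs):
        x, P, ll = carry
        y, R = obs
        # Predict
        x = F @ x
        P = F @ P @ F.T + Q
        # Innovation and its scalar variance S = H P H^T + R
        innovation = y - H @ x
        S = H @ P @ H + R
        # Update with Kalman gain K = P H^T / S
        K = P @ H / S
        x = x + K * innovation
        P = P - jnp.outer(K, H @ P)
        # Accumulate the Gaussian log-density of the innovation
        ll = ll - 0.5 * (jnp.log(2.0 * jnp.pi * S) + innovation**2 / S)
        return (x, P, ll), None

    init = (x0, P0, jnp.asarray(0.0, dtype=x0.dtype))
    (_, _, ll), _ = jax.lax.scan(step, init, (ys, Rs))
    return ll
```

With irregularly spaced TOAs, F and Q would instead be recomputed from each step's dt inside `step`, for example from matrices like those returned by F_matrix and Q_matrix.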

F_matrix

F_matrix(dt: float, γa: float, γp: float) -> tuple[np.ndarray, np.ndarray]

Return the state-transition matrices for time step dt.

Parameters:

  • dt (float) –

    Time step

  • γa (float) –

    GW damping rate

  • γp (float) –

    Spin noise damping rate

Returns

tuple: (F_gw, F_spin) matrices
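The docstring does not spell out the form of F_gw and F_spin. Assuming, hypothetically, that every GW and spin state is an independent Ornstein-Uhlenbeck process dx = -γ x dt + σ dW, the exact discretization over a step dt is a diagonal transition matrix with entries exp(-γ dt). The sketch below builds both blocks under that assumption; `F_matrix_sketch` and its explicit `gw_size`/`spin_size` arguments are illustrative and differ from the documented `F_matrix` signature.

```python
import jax.numpy as jnp


def ou_transition(dt, gamma, size):
    """Transition matrix for `size` independent Ornstein-Uhlenbeck states.

    Assumes each state obeys dx = -gamma * x dt + sigma dW, whose exact
    discretization over dt decays the state mean by exp(-gamma * dt).
    """
    return jnp.exp(-gamma * dt) * jnp.eye(size)


def F_matrix_sketch(dt, gamma_a, gamma_p, gw_size, spin_size):
    # Hypothetical: block sizes are passed explicitly here for illustration
    F_gw = ou_transition(dt, gamma_a, gw_size)
    F_spin = ou_transition(dt, gamma_p, spin_size)
    return F_gw, F_spin
```

At dt = 0 both blocks reduce to the identity, and a damping rate of ln 2 halves the state over a unit step.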

Q_matrix

Q_matrix(dt: float, γa: float, γp: float, σa2: float, σp2: float) -> tuple[np.ndarray, np.ndarray]

Return the process-noise covariance matrices for time step dt.

Parameters:

  • dt (float) –

    Time step

  • γa (float) –

    GW damping rate

  • γp (float) –

    Spin noise damping rate

  • σa2 (float) –

    GW noise amplitude squared

  • σp2 (float) –

    Spin noise amplitude squared

Returns

tuple: (Q_gw, Q_spin) matrices
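Under the same hypothetical Ornstein-Uhlenbeck assumption (dx = -γ x dt + σ dW, with σa2 and σp2 read as σ²), the variance accumulated over a step dt is σ²/(2γ) · (1 - exp(-2γ dt)). Whether σa2 and σp2 enter the real `Q_matrix` exactly this way is an assumption; `Q_matrix_sketch` and its size arguments are illustrative only. The formula requires γ > 0 and tends to σ² dt as γ → 0.

```python
import jax.numpy as jnp


def ou_process_noise(dt, gamma, sigma2, size):
    """Process-noise covariance for `size` independent Ornstein-Uhlenbeck states.

    Assumes dx = -gamma * x dt + sigma dW with sigma2 = sigma**2 and gamma > 0;
    the variance added over a step dt is sigma2 / (2 gamma) * (1 - exp(-2 gamma dt)).
    """
    # -expm1(-2 gamma dt) equals 1 - exp(-2 gamma dt) but stays accurate for small steps
    var = sigma2 * (-jnp.expm1(-2.0 * gamma * dt)) / (2.0 * gamma)
    return var * jnp.eye(size)


def Q_matrix_sketch(dt, gamma_a, gamma_p, sigma_a2, sigma_p2, gw_size, spin_size):
    # Hypothetical: block sizes are passed explicitly here for illustration
    Q_gw = ou_process_noise(dt, gamma_a, sigma_a2, gw_size)
    Q_spin = ou_process_noise(dt, gamma_p, sigma_p2, spin_size)
    return Q_gw, Q_spin
```

For long steps the variance saturates at the stationary value σ²/(2γ), which is why the noise added per step does not grow without bound.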

get_logger

get_logger()

Get the centralized logger instance.

compute_predicted_state

compute_predicted_state(F_list, x, gw_size, spin_size)

Compute the predicted state vector by applying transition matrices to state blocks.

Parameters:

  • F_list

    Tuple of (F_gw, F_spin) transition matrices for GW and spin components

  • x

    Current state vector containing GW, spin and timing components

  • gw_size

    Size of gravitational wave state block

  • spin_size

    Size of spin state block

Returns

jax.Array: Predicted state vector with same structure as input, computed by:
    - Applying F_gw transition to GW states
    - Applying F_spin transition to spin states
    - Keeping timing states unchanged

Note

The state vector x is assumed to have structure [x_gw, x_spin, x_timing] where each component has size determined by gw_size and spin_size parameters.
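Because the transition is block-diagonal with the identity on the timing block, the prediction reduces to two small matrix-vector products plus a copy. A sketch under that reading of the docstring (the name `predicted_state_sketch` is hypothetical, and it mirrors only the documented argument order):

```python
import jax.numpy as jnp


def predicted_state_sketch(F_list, x, gw_size, spin_size):
    """Blockwise x_pred = F x for block-diagonal F = diag(F_gw, F_spin, I)."""
    F_gw, F_spin = F_list
    x_gw = x[:gw_size]
    x_spin = x[gw_size:gw_size + spin_size]
    x_timing = x[gw_size + spin_size:]
    # Timing states are propagated with the identity, so they are copied through
    return jnp.concatenate([F_gw @ x_gw, F_spin @ x_spin, x_timing])
```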

compute_predicted_covariance

compute_predicted_covariance(P: Array, F_list: Tuple[Array, Array], Q_list: Tuple[Array, ...], gw_size: int, spin_size: int) -> jax.Array

Compute the predicted covariance matrix F P F^T + Q in a single call, working block by block.

Parameters:

  • P (Array) –

    Full covariance matrix

  • F_list (Tuple[Array, Array]) –

    Tuple of (F_gw, F_spin) transition matrices

  • Q_list (Tuple[Array, ...]) –

    Tuple of (Q_gw, Q_spin, Q_timing) process noise matrices

  • gw_size (int) –

    Size of GW block

  • spin_size (int) –

    Size of spin block

Returns

jax.Array: Combined predicted covariance matrix

Note

Computing the predicted covariance by slicing P into blocks and multiplying block by block is significantly faster than forming the full product F P F^T + Q. Because F is block-diagonal (F_gw, F_spin, and the identity on the timing states), the full product would spend most of its work multiplying by zero off-diagonal blocks and by identity blocks, which the block structure lets us skip.
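The blockwise computation can be sketched as follows, assuming F = diag(F_gw, F_spin, I) and Q = diag(Q_gw, Q_spin, Q_timing) as the parameter descriptions suggest. Each block (i, j) of the result is F_i P_ij F_j^T, with the multiplication skipped wherever F_i or F_j is the identity, and Q_i is added on the diagonal blocks. `predicted_covariance_sketch` is a hypothetical name; it shows the structure of the computation, not the module's exact code.

```python
import jax.numpy as jnp


def predicted_covariance_sketch(P, F_list, Q_list, gw_size, spin_size):
    """Blockwise F P F^T + Q for F = diag(F_gw, F_spin, I), Q = diag(Q_gw, Q_spin, Q_timing)."""
    F_gw, F_spin = F_list
    Q_gw, Q_spin, Q_timing = Q_list
    n = P.shape[0]
    # Slice boundaries of the [gw, spin, timing] blocks
    bounds = [0, gw_size, gw_size + spin_size, n]
    # Transition for each block; None marks the identity on the timing states
    Fs = [F_gw, F_spin, None]
    Qs = [Q_gw, Q_spin, Q_timing]

    def apply_left(F, M):
        return M if F is None else F @ M

    def apply_right(M, F):
        return M if F is None else M @ F.T

    rows = []
    for i in range(3):
        row = []
        for j in range(3):
            block = P[bounds[i]:bounds[i + 1], bounds[j]:bounds[j + 1]]
            # F_i P_ij F_j^T, skipping identity factors
            block = apply_right(apply_left(Fs[i], block), Fs[j])
            if i == j:
                block = block + Qs[i]
            row.append(block)
        rows.append(row)
    return jnp.block(rows)
```

The result matches the dense F P F^T + Q; the timing-timing block, for example, is just P_tt + Q_timing with no multiplication at all.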