jax_kalman_filter
Module implementing a JAX-based Kalman filter.
JaxKalmanFilter
A class to implement the linear Kalman filter on scalar inputs using JAX.
Parameters:
- df_psr – DataFrame containing pulsar information, including:
  - dim_M: integer, number of design parameters for that pulsar
  - F0: pulsar spin frequency
- observations – Dictionary containing 'toas', 'residuals', and 'errors' arrays from the data loader
- Peps – The uncertainty matrix for the epsilon states
- hd_correlation_matrix – Precomputed Hellings-Downs correlation matrix
- pulsar_design_matrices – Design matrices for each pulsar
- use_gw (bool, default: True) – If True, include GW terms in the measurement equation.
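For reference, here is a minimal sketch of how a Hellings-Downs correlation matrix like hd_correlation_matrix could be precomputed from pulsar sky positions. The helper name `hellings_downs_matrix`, the (N, 3) unit-vector input, and the normalization (1 on the diagonal, including the pulsar term) are assumptions for illustration; this module may use a different convention or overall scaling.

```python
import jax.numpy as jnp


def hellings_downs_matrix(unit_vectors):
    """Hypothetical helper: HD correlation matrix from (N, 3) pulsar unit vectors."""
    # Cosine of the angular separation zeta between every pair of pulsars.
    cos_zeta = jnp.clip(unit_vectors @ unit_vectors.T, -1.0, 1.0)
    x = 0.5 * (1.0 - cos_zeta)
    # x * ln(x) -> 0 as x -> 0; guard the log so coincident directions give no NaN.
    safe_x = jnp.where(x > 0, x, 1.0)
    x_log_x = jnp.where(x > 0, x * jnp.log(safe_x), 0.0)
    gamma = 1.5 * x_log_x - 0.25 * x + 0.5
    # Assumed normalization: auto-correlation (with pulsar term) set to 1.
    n = unit_vectors.shape[0]
    return jnp.where(jnp.eye(n, dtype=bool), 1.0, gamma)
```

Off-diagonal entries follow Gamma(zeta) = (3/2) x ln x - x/4 + 1/2 with x = (1 - cos zeta)/2, so two orthogonal pulsars correlate at about -0.145 and two antipodal pulsars at 0.25.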
get_likelihood
Run the Kalman filter over all observations and return the log-likelihood.
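The prediction-error decomposition that a Kalman-filter log-likelihood typically rests on can be sketched as below. This is not the class's implementation: the function name `kalman_log_likelihood`, the constant `F` and `Q` (the class's F_matrix and Q_matrix methods presumably build these, possibly per step), and the per-observation measurement row `H[k]` with noise variance `R[k]` are all assumptions.

```python
import jax
import jax.numpy as jnp


def kalman_log_likelihood(x0, P0, F, Q, H, y, R):
    """Hypothetical sketch: log-likelihood of scalar measurements y[k] = H[k] @ x_k + noise.

    x0: (n,) prior state, P0: (n, n) prior covariance (float arrays)
    F, Q: (n, n) transition and process-noise matrices, assumed constant here
    H: (T, n) measurement rows, y: (T,) measurements, R: (T,) noise variances
    """

    def step(carry, inputs):
        x, P, ll = carry
        h, yk, rk = inputs
        # Predict.
        x = F @ x
        P = F @ P @ F.T + Q
        # Innovation and its variance for a single scalar measurement.
        v = yk - h @ x
        S = h @ P @ h + rk
        K = P @ h / S
        # Update (standard form; Joseph form omitted for brevity).
        x = x + K * v
        P = P - jnp.outer(K, h @ P)
        # Accumulate log N(v; 0, S).
        ll = ll - 0.5 * (jnp.log(2.0 * jnp.pi * S) + v**2 / S)
        return (x, P, ll), None

    init = (x0, P0, jnp.zeros((), dtype=P0.dtype))
    (_, _, ll), _ = jax.lax.scan(step, init, (H, y, R))
    return ll
```

For a linear-Gaussian model the summed innovation terms equal the exact marginal log-likelihood of all observations, and `jax.lax.scan` keeps the loop traceable under `jax.jit`.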
F_matrix
Q_matrix
compute_predicted_state
Compute the predicted state vector by applying transition matrices to state blocks.
Parameters:
- F_list – Tuple of (F_gw, F_spin) transition matrices for the GW and spin components
- x – Current state vector containing GW, spin, and timing components
- gw_size – Size of the gravitational-wave state block
- spin_size – Size of the spin state block
Returns
jax.Array: Predicted state vector with the same structure as the input, computed by:
- Applying F_gw transition to GW states
- Applying F_spin transition to spin states
- Keeping timing states unchanged
Note
The state vector x is assumed to have the structure [x_gw, x_spin, x_timing], where x_gw has length gw_size, x_spin has length spin_size, and x_timing holds the remaining entries.
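A minimal sketch of this block-wise prediction, assuming the [x_gw, x_spin, x_timing] layout from the note. `predicted_state_sketch` is a hypothetical stand-in for the method, and marking gw_size and spin_size as static is an assumption needed so the slices have fixed shapes under `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("gw_size", "spin_size"))
def predicted_state_sketch(F_list, x, gw_size, spin_size):
    F_gw, F_spin = F_list
    end = gw_size + spin_size
    # Slice the state into its GW, spin, and timing blocks.
    x_gw, x_spin, x_timing = x[:gw_size], x[gw_size:end], x[end:]
    # Transition GW and spin blocks; timing states pass through unchanged.
    return jnp.concatenate([F_gw @ x_gw, F_spin @ x_spin, x_timing])
```

The result equals block_diag(F_gw, F_spin, I) @ x, but only the two nontrivial blocks are ever multiplied.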
compute_predicted_covariance
compute_predicted_covariance(P: Array, F_list: Tuple[Array, Array], Q_list: Tuple[Array, ...], gw_size: int, spin_size: int) -> jax.Array
Compute the full predicted covariance matrix in a single call.
Parameters:
- P (Array) – Full covariance matrix
- F_list (Tuple[Array, Array]) – Tuple of (F_gw, F_spin) transition matrices
- Q_list (Tuple[Array, ...]) – Tuple of (Q_gw, Q_spin, Q_timing) process noise matrices
- gw_size (int) – Size of the GW block
- spin_size (int) – Size of the spin block
Returns
jax.Array: Combined predicted covariance matrix
Note
Slicing P into GW, spin, and timing blocks and multiplying block by block is significantly faster than forming the full product F P F^T + Q. Because F = diag(F_gw, F_spin, I) is block-diagonal, block (i, j) of F P F^T is F_i P_ij F_j^T: the many products involving the zero off-diagonal blocks of F are skipped entirely, blocks touching the timing states need a multiplication on one side only, and the timing-timing block needs none.
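The block-wise computation described in the note can be sketched as follows. `predicted_covariance_sketch` is hypothetical, and the sketch assumes that Q is block-diagonal with blocks (Q_gw, Q_spin, Q_timing) and that the timing states have an identity transition, as stated for compute_predicted_state.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("gw_size", "spin_size"))
def predicted_covariance_sketch(P, F_list, Q_list, gw_size, spin_size):
    F_gw, F_spin = F_list
    Q_gw, Q_spin, Q_timing = Q_list
    g = slice(0, gw_size)                        # GW block
    s = slice(gw_size, gw_size + spin_size)      # spin block
    t = slice(gw_size + spin_size, P.shape[0])   # timing block
    # Block (i, j) of F P F^T is F_i P_ij F_j^T; the timing transition is I,
    # so timing rows/columns are multiplied on one side only (or not at all).
    row_g = [F_gw @ P[g, g] @ F_gw.T + Q_gw, F_gw @ P[g, s] @ F_spin.T, F_gw @ P[g, t]]
    row_s = [F_spin @ P[s, g] @ F_gw.T, F_spin @ P[s, s] @ F_spin.T + Q_spin, F_spin @ P[s, t]]
    row_t = [P[t, g] @ F_gw.T, P[t, s] @ F_spin.T, P[t, t] + Q_timing]
    # Reassemble the 3x3 grid of blocks into the full matrix.
    return jnp.block([row_g, row_s, row_t])
```

Since P is symmetric, a further saving (not shown here) would be to compute only the upper-triangular blocks and fill the lower ones by transposition.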