Source code for dynaris.core.state_space

"""State-space model representation (West & Harrison notation)."""

from __future__ import annotations

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import Array

from dynaris.core.types import GaussianState


@dataclass(frozen=True)
class StateSpaceModel:
    """Linear-Gaussian Dynamic Linear Model.

    Following West & Harrison (1997) notation:

        System equation:  theta_t = G @ theta_{t-1} + omega_t,  omega_t ~ N(0, W)
        Observation eq:   Y_t     = F' @ theta_t    + nu_t,     nu_t   ~ N(0, V)

    Instances are immutable (frozen dataclass). Two models can be combined
    with ``+`` (superposition), and the class provides ``tree_flatten`` /
    ``tree_unflatten`` so it can be registered as a JAX pytree node.

    Attributes:
        observation_matrix: F, shape (obs_dim, state_dim).
        system_matrix: G, shape (state_dim, state_dim).
        obs_cov: V, shape (obs_dim, obs_dim).
        evolution_cov: W, shape (state_dim, state_dim).
        input_matrix: B, shape (state_dim, input_dim) or None.
    """

    observation_matrix: Array  # F: (m, n)
    system_matrix: Array  # G: (n, n)
    obs_cov: Array  # V: (m, m)
    evolution_cov: Array  # W: (n, n)
    input_matrix: Array | None = None  # B: (n, p) or None

    # --- Dimension properties ---

    @property
    def state_dim(self) -> int:
        """Dimension of the state (parameter) vector (last axis of G)."""
        return int(self.system_matrix.shape[-1])

    @property
    def obs_dim(self) -> int:
        """Dimension of the observation vector (second-to-last axis of F)."""
        return int(self.observation_matrix.shape[-2])

    # --- Short aliases (West & Harrison notation) ---

    @property
    def F(self) -> Array:  # noqa: N802
        """Observation/regression matrix (F in W&H)."""
        return self.observation_matrix

    @property
    def G(self) -> Array:  # noqa: N802
        """System/evolution matrix (G in W&H)."""
        return self.system_matrix

    @property
    def V(self) -> Array:  # noqa: N802
        """Observational variance/covariance (V in W&H)."""
        return self.obs_cov

    @property
    def W(self) -> Array:  # noqa: N802
        """Evolution covariance (W in W&H)."""
        return self.evolution_cov

    @property
    def B(self) -> Array | None:  # noqa: N802
        """Alias for input_matrix."""
        return self.input_matrix

    # --- Factory methods ---

    def initial_state(
        self,
        mean: Array | None = None,
        cov: Array | None = None,
    ) -> GaussianState:
        """Create a default initial GaussianState for this model.

        Args:
            mean: Initial state mean (m_0). Defaults to zeros.
            cov: Initial state covariance (C_0). Defaults to 1e6 * I
                (diffuse prior).

        Returns:
            GaussianState with the specified or default initial conditions.
        """
        n = self.state_dim
        if mean is None:
            mean = jnp.zeros(n)
        if cov is None:
            cov = jnp.eye(n) * 1e6
        return GaussianState(mean=mean, cov=cov)

    # --- Composition (superposition principle) ---

    @staticmethod
    def _block_diag(upper_left: Array, lower_right: Array) -> Array:
        """Return ``[[upper_left, 0], [0, lower_right]]`` as a single matrix."""
        rows_a, cols_a = upper_left.shape[-2], upper_left.shape[-1]
        rows_b, cols_b = lower_right.shape[-2], lower_right.shape[-1]
        return jnp.block(
            [
                [upper_left, jnp.zeros((rows_a, cols_b))],
                [jnp.zeros((rows_b, cols_a)), lower_right],
            ]
        )

    def __add__(self, other: StateSpaceModel) -> StateSpaceModel:
        """Compose two models via superposition (West & Harrison).

        The resulting model has:
            - G_new = block_diag(G1, G2)
            - F_new = [F1, F2]  (horizontal concatenation)
            - W_new = block_diag(W1, W2)
            - V_new = V1 + V2  (shared observation noise adds)

        Input matrices are combined as follows: if both models have one the
        result is block_diag(B1, B2); if only one has one, it is padded with
        zero rows for the other model's states; if neither has one the
        result has no input matrix.

        Args:
            other: The model to superpose onto this one.

        Returns:
            The combined StateSpaceModel, or ``NotImplemented`` when
            ``other`` is not a StateSpaceModel (so Python raises the usual
            ``TypeError`` or tries the reflected operation).
        """
        if not isinstance(other, StateSpaceModel):
            # Standard binary-operator protocol: defer instead of failing
            # with an unrelated AttributeError on ``other.state_dim``.
            return NotImplemented

        n1, n2 = self.state_dim, other.state_dim
        b1, b2 = self.input_matrix, other.input_matrix

        input_mat: Array | None
        if b1 is not None and b2 is not None:
            input_mat = self._block_diag(b1, b2)
        elif b1 is not None:
            # Only the left model is driven: zero rows for the right states.
            input_mat = jnp.concatenate(
                [b1, jnp.zeros((n2, b1.shape[-1]))], axis=-2
            )
        elif b2 is not None:
            # Only the right model is driven: zero rows for the left states.
            input_mat = jnp.concatenate(
                [jnp.zeros((n1, b2.shape[-1])), b2], axis=-2
            )
        else:
            input_mat = None

        return StateSpaceModel(
            observation_matrix=jnp.concatenate([self.F, other.F], axis=-1),
            system_matrix=self._block_diag(self.G, other.G),
            obs_cov=self.V + other.V,
            evolution_cov=self._block_diag(self.W, other.W),
            input_matrix=input_mat,
        )

    def __repr__(self) -> str:
        """Return a compact summary showing dimensions only, not array data."""
        b_info = (
            f", input_dim={self.input_matrix.shape[-1]}"
            if self.input_matrix is not None
            else ""
        )
        return f"StateSpaceModel(state_dim={self.state_dim}, obs_dim={self.obs_dim}{b_info})"

    # --- JAX pytree registration ---

    def tree_flatten(self) -> tuple[list[Array], dict[str, bool]]:
        """Flatten into JAX pytree leaves and auxiliary data.

        Returns:
            A pair ``(leaves, aux)``. ``leaves`` is ``[F, G, V, W]`` followed
            by ``B`` when an input matrix is present; ``aux`` is
            ``{"has_input": bool}`` recording whether ``B`` was included.
        """
        leaves: list[Array] = [
            self.observation_matrix,
            self.system_matrix,
            self.obs_cov,
            self.evolution_cov,
        ]
        if self.input_matrix is not None:
            leaves.append(self.input_matrix)
        return leaves, {"has_input": self.input_matrix is not None}

    @classmethod
    def tree_unflatten(cls, aux_data: dict[str, bool], children: list[Array]) -> StateSpaceModel:
        """Reconstruct from JAX pytree leaves.

        Args:
            aux_data: The auxiliary dict produced by ``tree_flatten``.
            children: Leaves in the order produced by ``tree_flatten``.

        Returns:
            A StateSpaceModel built from the given leaves.
        """
        # The fifth leaf exists only when the flattened model had an input matrix.
        input_matrix = children[4] if aux_data["has_input"] else None
        return cls(
            observation_matrix=children[0],
            system_matrix=children[1],
            obs_cov=children[2],
            evolution_cov=children[3],
            input_matrix=input_matrix,
        )
# Register StateSpaceModel with JAX as a custom pytree node; the class's
# tree_flatten / tree_unflatten methods define how it is (de)structured.
jax.tree_util.register_pytree_node_class(StateSpaceModel)