"""State-space model representation (West & Harrison notation)."""
from __future__ import annotations
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from jax import Array
from dynaris.core.types import GaussianState
@dataclass(frozen=True)
class StateSpaceModel:
    """Linear-Gaussian Dynamic Linear Model.

    Following West & Harrison (1997):

        System equation:       theta_t = G @ theta_{t-1} + omega_t,  omega_t ~ N(0, W)
        Observation equation:  Y_t     = F @ theta_t     + nu_t,     nu_t    ~ N(0, V)

    West & Harrison write the observation equation as ``Y_t = F' theta_t`` with
    F of shape (state_dim, obs_dim). Here F is stored already transposed, with
    shape (obs_dim, state_dim), so it multiplies the state vector directly.

    Attributes:
        observation_matrix: F, shape (obs_dim, state_dim).
        system_matrix: G, shape (state_dim, state_dim).
        obs_cov: V, shape (obs_dim, obs_dim).
        evolution_cov: W, shape (state_dim, state_dim).
        input_matrix: B, shape (state_dim, input_dim) or None. How B enters the
            system equation is defined by the code that consumes this model.
    """

    # NOTE(review): with frozen=True and the default eq=True, dataclass generates
    # __eq__ and __hash__ from the array fields. hash() of a jax Array raises
    # TypeError, and == between models with distinct non-scalar arrays generally
    # cannot be reduced to a bool. Confirm whether eq=False is intended.
    observation_matrix: Array  # F: (m, n) with m = obs_dim, n = state_dim
    system_matrix: Array  # G: (n, n)
    obs_cov: Array  # V: (m, m)
    evolution_cov: Array  # W: (n, n)
    input_matrix: Array | None = None  # B: (n, p) or None

    # --- Dimension properties ---

    @property
    def state_dim(self) -> int:
        """Dimension of the state (parameter) vector, taken from G's last axis."""
        return int(self.system_matrix.shape[-1])

    @property
    def obs_dim(self) -> int:
        """Dimension of the observation vector, taken from F's row axis."""
        return int(self.observation_matrix.shape[-2])

    # --- Short aliases (West & Harrison notation) ---

    @property
    def F(self) -> Array:  # noqa: N802
        """Observation/regression matrix (F in W&H, stored as (obs_dim, state_dim))."""
        return self.observation_matrix

    @property
    def G(self) -> Array:  # noqa: N802
        """System/evolution matrix (G in W&H)."""
        return self.system_matrix

    @property
    def V(self) -> Array:  # noqa: N802
        """Observational variance/covariance (V in W&H)."""
        return self.obs_cov

    @property
    def W(self) -> Array:  # noqa: N802
        """Evolution covariance (W in W&H)."""
        return self.evolution_cov

    @property
    def B(self) -> Array | None:  # noqa: N802
        """Alias for input_matrix."""
        return self.input_matrix

    # --- Factory methods ---

    def initial_state(
        self,
        mean: Array | None = None,
        cov: Array | None = None,
    ) -> GaussianState:
        """Create a default initial GaussianState for this model.

        Args:
            mean: Initial state mean (m_0). Defaults to zeros of length state_dim.
            cov: Initial state covariance (C_0). Defaults to 1e6 * I, a diffuse
                prior.

        Returns:
            GaussianState with the specified or default initial conditions.
        """
        n = self.state_dim
        if mean is None:
            mean = jnp.zeros(n)
        if cov is None:
            cov = jnp.eye(n) * 1e6
        return GaussianState(mean=mean, cov=cov)

    # --- Composition (superposition principle) ---

    def __add__(self, other: StateSpaceModel) -> StateSpaceModel:
        """Compose two models via superposition (West & Harrison).

        The resulting model has:
          - G_new = block_diag(G1, G2)
          - F_new = [F1, F2] (horizontal concatenation along the state axis)
          - W_new = block_diag(W1, W2)
          - V_new = V1 + V2 (the independent observation noises add)
          - B_new = block_diag(B1, B2) when both models have inputs. When only
            one has inputs, that B is zero-padded to the combined state dim.

        Args:
            other: Model to superpose onto this one.

        Returns:
            A new StateSpaceModel whose state is the concatenation of both states.

        Raises:
            ValueError: If the two models have different observation dimensions.
        """
        if not isinstance(other, StateSpaceModel):
            # Let Python try other.__radd__ or raise its standard TypeError.
            return NotImplemented
        if self.obs_dim != other.obs_dim:
            raise ValueError(
                "Cannot superpose models with different observation dimensions: "
                f"{self.obs_dim} vs {other.obs_dim}."
            )
        n1, n2 = self.state_dim, other.state_dim

        system = self._block_diag(self.G, other.G)
        observation = jnp.concatenate([self.F, other.F], axis=-1)
        evolution = self._block_diag(self.W, other.W)
        obs = self.V + other.V

        b1, b2 = self.input_matrix, other.input_matrix
        input_mat: Array | None
        if b1 is not None and b2 is not None:
            input_mat = self._block_diag(b1, b2)
        elif b1 is not None:
            # Only the first model is driven by inputs: zero rows for the second state.
            input_mat = jnp.concatenate(
                [b1, jnp.zeros((n2, b1.shape[-1]), dtype=b1.dtype)], axis=-2
            )
        elif b2 is not None:
            # Only the second model is driven by inputs: zero rows for the first state.
            input_mat = jnp.concatenate(
                [jnp.zeros((n1, b2.shape[-1]), dtype=b2.dtype), b2], axis=-2
            )
        else:
            input_mat = None

        return StateSpaceModel(
            observation_matrix=observation,
            system_matrix=system,
            obs_cov=obs,
            evolution_cov=evolution,
            input_matrix=input_mat,
        )

    @staticmethod
    def _block_diag(upper: Array, lower: Array) -> Array:
        """Return [[upper, 0], [0, lower]] with zero blocks in the promoted dtype."""
        # Zero blocks use the operands' dtype so low-precision models are not up-cast.
        dtype = jnp.result_type(upper, lower)
        top_right = jnp.zeros((upper.shape[-2], lower.shape[-1]), dtype=dtype)
        bottom_left = jnp.zeros((lower.shape[-2], upper.shape[-1]), dtype=dtype)
        return jnp.block([[upper, top_right], [bottom_left, lower]])

    def __repr__(self) -> str:
        b_info = (
            f", input_dim={self.input_matrix.shape[-1]}" if self.input_matrix is not None else ""
        )
        return f"StateSpaceModel(state_dim={self.state_dim}, obs_dim={self.obs_dim}{b_info})"

    # --- JAX pytree registration ---

    def tree_flatten(self) -> tuple[list[Array], bool]:
        """Flatten into JAX pytree children and auxiliary data.

        Returns:
            A pair of (children, has_input). The children are [F, G, V, W],
            followed by B only when it is present. has_input is a plain bool
            because JAX requires pytree auxiliary data to be hashable.
        """
        has_input = self.input_matrix is not None
        children: list[Array] = [
            self.observation_matrix,
            self.system_matrix,
            self.obs_cov,
            self.evolution_cov,
        ]
        if has_input:
            children.append(self.input_matrix)  # type: ignore[arg-type]
        return children, has_input

    @classmethod
    def tree_unflatten(
        cls, aux_data: bool | dict[str, bool], children: list[Array]
    ) -> StateSpaceModel:
        """Reconstruct from JAX pytree children.

        Whether B is present is read from the number of children. Auxiliary data
        in either the current bool form or the older {"has_input": bool} form is
        therefore accepted.
        """
        obs_mat, sys_mat, obs_cov, evo_cov, *rest = children
        return cls(
            observation_matrix=obs_mat,
            system_matrix=sys_mat,
            obs_cov=obs_cov,
            evolution_cov=evo_cov,
            input_matrix=rest[0] if rest else None,
        )
# Register StateSpaceModel as a JAX pytree node, so instances can be passed
# through jax transformations and jax.tree_util utilities. Registration uses the
# class's tree_flatten / tree_unflatten methods. It must happen exactly once per
# type; registering the same class again raises.
jax.tree_util.register_pytree_node_class(StateSpaceModel)