Batch Processing
dynaris supports fitting and forecasting multiple time series in parallel
using jax.vmap.
Batch fitting
Pass a batch of series as a 3D array with shape (n_series, T, obs_dim):
import jax.numpy as jnp
from dynaris import LocalLevel, DLM
model = LocalLevel()
dlm = DLM(model)
# y_batch: (n_series, T, 1); zeros are placeholder data, substitute your own series
y_batch = jnp.zeros((8, 100, 1))
batch_result = dlm.fit_batch(y_batch)
print(batch_result.log_likelihood) # shape (n_series,)
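Each series is filtered independently. jax.vmap vectorizes the filter over the leading n_series axis, so all series are processed in one batched computation on the available device (CPU or GPU). Because the batch is a single array, every series must share the same length T and the same obs_dim.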
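To make the idea concrete, here is a minimal sketch of batching a filter with jax.vmap. It is not dynaris's implementation: the function `local_level_loglik`, the variances `level_var` and `obs_var`, and the diffuse prior are all assumptions chosen for illustration. It runs a scalar local-level Kalman filter over one series with `jax.lax.scan`, then maps it over the leading axis of a batch.

```python
import jax
import jax.numpy as jnp


def local_level_loglik(y, level_var=0.1, obs_var=1.0):
    """Kalman-filter log-likelihood of one series y with shape (T, 1).

    Hypothetical helper for illustration; the variances are assumed known.
    """

    def step(carry, y_t):
        m, c = carry                  # filtered mean and variance of the level
        r = c + level_var             # predicted state variance
        q = r + obs_var               # predicted observation variance
        e = y_t - m                   # one-step-ahead forecast error
        k = r / q                     # Kalman gain
        ll = -0.5 * (jnp.log(2.0 * jnp.pi * q) + e**2 / q)
        # r * obs_var / q equals (1 - k) * r but is more stable in float32
        return (m + k * e, r * obs_var / q), ll

    # Assumed near-diffuse prior on the initial level
    init = (jnp.zeros((), dtype=jnp.float32), jnp.array(1e6, dtype=jnp.float32))
    _, lls = jax.lax.scan(step, init, y[:, 0])
    return lls.sum()


# Map the single-series function over the leading n_series axis
batch_loglik = jax.jit(jax.vmap(local_level_loglik))

y_batch = jax.random.normal(jax.random.PRNGKey(0), (4, 50, 1))
ll = batch_loglik(y_batch)            # one log-likelihood per series
```

The single-series function never sees the batch axis; vmap adds it, which is why each series is still filtered independently.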
Batch forecasting
After batch fitting, generate forecasts for all series at once:
from dynaris.forecast import forecast_batch
fc = forecast_batch(batch_result, model, steps=12)
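As a rough illustration of what a batched forecast computes, the sketch below forecasts a local-level model from its final filtered state. The helper `local_level_forecast`, its parameters, and the example states are hypothetical and do not reflect the structure of dynaris's forecast objects. For a local-level model the forecast mean stays at the last filtered level, while the forecast variance grows linearly with the horizon.

```python
import jax
import jax.numpy as jnp


def local_level_forecast(m, c, steps, level_var=0.1, obs_var=1.0):
    """h-step-ahead forecast from the final filtered mean m and variance c.

    Hypothetical helper; level_var and obs_var are assumed known.
    """
    h = jnp.arange(1, steps + 1)              # horizons 1..steps
    mean = jnp.broadcast_to(m, (steps,))      # level is a random walk: flat mean
    var = c + h * level_var + obs_var         # state uncertainty plus observation noise
    return mean, var


# Assumed final filtered states for three series
m_batch = jnp.array([1.0, -2.0, 0.5])
c_batch = jnp.array([0.3, 0.2, 0.4])

# steps is fixed through the closure, so vmap only maps over m and c
forecast_all = jax.vmap(lambda m, c: local_level_forecast(m, c, steps=12))
fc_mean, fc_var = forecast_all(m_batch, c_batch)   # each has shape (3, 12)
```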
Low-level API
The batch functions apply jax.vmap to their single-series equivalents and can also be called directly:
from dynaris.forecast import fit_batch, forecast_batch
batch_filter = fit_batch(model, y_batch)
batch_fc = forecast_batch(batch_filter, model, steps=12)
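When a function wrapped in jax.vmap returns a structured result, every array field gains a leading n_series axis. The sketch below shows this with a hypothetical `Summary` container and `summarize` function (neither is part of dynaris), and uses `jax.tree_util.tree_map` to pull out the result for a single series.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Summary(NamedTuple):
    """Hypothetical result container, standing in for a filter result."""
    mean: jnp.ndarray
    centered: jnp.ndarray


def summarize(y):
    """Hypothetical single-series function; y has shape (T, 1)."""
    mu = y.mean()
    return Summary(mean=mu, centered=y - mu)


# vmap maps over the leading n_series axis of the input
summarize_batch = jax.vmap(summarize)

y_batch = jnp.arange(12.0).reshape(3, 4, 1)   # (n_series=3, T=4, obs_dim=1)
batch = summarize_batch(y_batch)
# batch.mean has shape (3,); batch.centered has shape (3, 4, 1)

# Index every field at position 1 to recover the result for the second series
second = jax.tree_util.tree_map(lambda x: x[1], batch)
```

Slicing the batched result this way gives the same values as calling the single-series function on `y_batch[1]`.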
See Forecasting for the full API.