Converting NumPyro objects to DataTree#
DataTree is the data format ArviZ relies on.
This page covers multiple ways to generate a DataTree from NumPyro MCMC and SVI objects.
See also
Conversion from Python, numpy or pandas objects
DataTree for Exploratory Analysis of Bayesian Models for an overview of
InferenceData and its role within ArviZ.
We will start by importing the required packages and defining the model: the famous eight schools model.
import arviz_base as az
import numpy as np
from numpy.typing import ArrayLike
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive
from jax import random
import jax.numpy as jnp
J = 8  # number of schools
# Observed treatment effects, one per school
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
# Known standard errors of the observed effects, one per school
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools_model(J, sigma, y=None):
    """Non-centered hierarchical model for the eight schools data.

    Parameters
    ----------
    J : int
        Number of schools.
    sigma : array-like
        Known standard error of the observed effect for each school.
    y : array-like, optional
        Observed treatment effects. When ``None`` the likelihood site is
        sampled instead of conditioned on, e.g. for prior/posterior predictive.
    """
    # Population-level location and scale of the school effects
    population_mean = numpyro.sample("mu", dist.Normal(0, 5))
    population_scale = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        # Standardized per-school offsets (non-centered parameterization)
        offsets = numpyro.sample("eta", dist.Normal(0, 1))
        school_effects = numpyro.deterministic(
            "theta", population_mean + population_scale * offsets
        )
        return numpyro.sample("obs", dist.Normal(school_effects, sigma), obs=y)
def eight_schools_custom_guide(J, sigma, y=None):
    """Hand-written mean-field variational guide for ``eight_schools_model``.

    Mirrors the model's latent sites ("mu", "tau", "eta") with learnable
    variational parameters; "tau" uses a LogNormal to respect its positive
    support. Signature matches the model, as NumPyro's SVI requires.
    """
    # mu ~ Normal(mu_loc, mu_scale)
    loc_mu = numpyro.param("mu_loc", 0.0)
    scale_mu = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(loc_mu, scale_mu))

    # tau ~ LogNormal(log(tau_loc), tau_scale): positive by construction
    loc_tau = numpyro.param("tau_loc", 1.0)
    scale_tau = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
    tau = numpyro.sample("tau", dist.LogNormal(jnp.log(loc_tau), scale_tau))

    # eta: one independent Normal factor per school, inside the "J" plate
    loc_eta = numpyro.param("eta_loc", jnp.zeros(J))
    scale_eta = numpyro.param("eta_scale", jnp.ones(J), constraint=dist.constraints.positive)
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(loc_eta, scale_eta))

    # Record the implied school effects so they appear in the posterior
    numpyro.deterministic("theta", mu + tau * eta)
Convert from MCMC#
This first example shows conversion from an MCMC object.
# Fit the model with MCMC (NUTS): 4 chains, 1000 warmup + 1000 kept draws each
nuts = NUTS(eight_schools_model)
mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
# extra_fields requests per-draw sampler statistics, stored in sample_stats
mcmc.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"),)
# Sample the posterior predictive by conditioning the model on the posterior draws
predictive = Predictive(eight_schools_model, mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)
# Convert the fitted MCMC object (plus predictive samples) to a DataTree
idata_mcmc = az.from_numpyro(mcmc, posterior_predictive=samples_predictive)
idata_mcmc
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_61101/3262796440.py:3: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 3287.85it/s, 7 steps of size 4.01e-01. acc. prob=0.90]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 9001.07it/s, 7 steps of size 4.16e-01. acc. prob=0.86]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7635.89it/s, 15 steps of size 3.71e-01. acc. prob=0.92]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7994.02it/s, 15 steps of size 3.47e-01. acc. prob=0.93]
<xarray.DatasetView> Size: 0B
Dimensions: ()
Data variables:
*empty*Convert from SVI with Autoguide#
# Mean-field normal autoguide; initialize locations at the prior median
eight_schools_guide = autoguide.AutoNormal(eight_schools_model, init_loc_fn=numpyro.infer.init_to_median(num_samples=100))
svi = SVI(
    eight_schools_model,
    guide=eight_schools_guide,
    optim=numpyro.optim.Adam(0.01),
    loss = Trace_ELBO()
)
svi_result = svi.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)
# Sample the posterior predictive from the fitted guide's parameters
predictive_svi = Predictive(eight_schools_model, guide=eight_schools_guide, params=svi_result.params, num_samples=4000)
samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)
# Convert the SVI fit to a DataTree
idata_svi = az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs), # SVI requires providing the fit args/kwargs
    num_samples = 4000, # number of samples to draw in the posterior
    posterior_predictive=samples_predictive_svi
)
idata_svi
100%|██████████| 10000/10000 [00:00<00:00, 11110.33it/s, init loss: 53.6608, avg. loss [9501-10000]: 31.6204]
<xarray.DatasetView> Size: 0B
Dimensions: ()
Data variables:
*empty*Converting from SVI with a custom guide function#
# Fit SVI using the hand-written guide instead of an autoguide
svi_custom_guide = SVI(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO()
)
svi_custom_guide_result = svi_custom_guide.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)
# Sample the posterior predictive using the params from THIS fit.
# BUG FIX: the original passed ``svi_result.params`` (the autoguide fit);
# those param names (auto_*) do not match the custom guide's params
# (mu_loc, tau_loc, ...), so the predictive was not drawn from this fit.
predictive_svi_custom = Predictive(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    params=svi_custom_guide_result.params,
    num_samples=4000,
)
samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)
# Convert the custom-guide SVI fit to a DataTree
idata_svi_custom_guide = az.from_numpyro_svi(
    svi_custom_guide,
    svi_result=svi_custom_guide_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs), # SVI requires providing the fit args/kwargs
    num_samples = 4000, # number of samples to draw in the posterior
    posterior_predictive=samples_predictive_svi_custom
)
idata_svi_custom_guide
100%|██████████| 10000/10000 [00:00<00:00, 10763.36it/s, init loss: 34.9525, avg. loss [9501-10000]: 31.6279]
<xarray.DatasetView> Size: 0B
Dimensions: ()
Data variables:
*empty*Automatically Labelling Event Dims#
NumPyro batch dims are automatically labelled according to their corresponding plate names. In order to label event dims, we add infer={"event_dims": dim_labels} to the numpyro.sample statement as shown below:
def eight_schools_model_zsn(J, sigma, y=None):
    """Eight schools model with zero-sum-constrained school deviations.

    The per-school deviations are drawn jointly from a ``ZeroSumNormal``
    whose event dimension is labelled "J" via ``infer``, so ArviZ names
    that dimension automatically in the resulting DataTree.
    """
    grand_mean = numpyro.sample("mu", dist.Normal(0, 5))
    deviation_scale = numpyro.sample("tau", dist.HalfCauchy(5))
    deviations = numpyro.sample(
        "eta",
        dist.ZeroSumNormal(deviation_scale, event_shape=(J,)),
        # note: this allows arviz to infer the event dimension labels
        infer={"event_dims": ["J"]},
    )
    with numpyro.plate("J", J):
        school_means = numpyro.deterministic("theta", grand_mean + deviations)
        return numpyro.sample("obs", dist.Normal(school_means, sigma), obs=y)
# Fit the ZeroSumNormal variant with MCMC
nuts = NUTS(eight_schools_model_zsn)
mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
mcmc2.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"),)
# Sample the posterior predictive from the SAME model that was fit.
# BUG FIX: the original passed ``eight_schools_model`` here, but mcmc2 was
# fit with ``eight_schools_model_zsn`` whose "eta" site has different
# semantics (tau-scaled ZeroSumNormal vs. standard-normal offsets), so the
# predictive came from a mismatched model.
predictive2 = Predictive(eight_schools_model_zsn, mcmc2.get_samples())
samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)
# Convert the fitted MCMC object to a DataTree
idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_61101/306760900.py:17: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2525.59it/s, 3 steps of size 2.56e-01. acc. prob=0.91]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6867.14it/s, 15 steps of size 1.99e-01. acc. prob=0.90]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6385.59it/s, 15 steps of size 2.85e-01. acc. prob=0.83]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7863.09it/s, 3 steps of size 2.61e-01. acc. prob=0.83]
Notice that eta is labelled appropriately with J
idata_mcmc2
<xarray.DatasetView> Size: 0B
Dimensions: ()
Data variables:
*empty*Extending NumPyro Conversion to other Inference Objects#
NumPyroInferenceAdapter can be leveraged to extend ArviZ conversion to other NumPyro Inference Objects (such as the NestedSampler)
The example below uses the SVI implementation as an example, where an adapter class is created that inherits the NumPyroInferenceAdapter base class
class SVIAutoGuideAdapter(az.NumPyroInferenceAdapter):
    """Adapter for SVI to standardize attributes and methods with other inference objects.

    Wraps a fitted NumPyro ``SVI`` (with an autoguide) so it can be passed
    directly to ``az.from_numpyro`` like an ``MCMC`` object.
    """

    def __init__(
        self, svi, *, svi_result, model_args=None, model_kwargs=None, num_samples = 1000,
    ):
        # Necessary: Specify the inference object, internal model fn, inputs, and sample shape
        super().__init__(
            svi,
            # autoguides expose the wrapped model as guide.model; fall back
            # to svi.model for guides that don't (e.g. plain functions)
            model=getattr(svi.guide, "model", svi.model),
            model_args=model_args,
            model_kwargs=model_kwargs,
            sample_shape=(num_samples,),
        )
        self.result_obj = svi_result  # saving this to help with posterior sampling

    # Necessary: Specify the sample dim names and shape. ie MCMC is ("chain", "draw")
    @property
    def sample_dims(self):
        # SVI draws are i.i.d., so a single "sample" dim (no chains)
        return ["sample"]

    # Necessary: Specify how to get posterior samples from the inference objects
    # for SVI in numpyro, we need to sample from the guide with our SVI params
    def get_samples(self, seed = None, **kwargs):
        # NOTE(review): ``seed or 0`` maps both None and 0 to key 0 —
        # presumably intentional as a default; confirm if seed=0 should differ.
        key = self.prng_key_func(seed or 0)
        # NOTE(review): self.posterior / self.prng_key_func / self._args /
        # self._kwargs are provided by the NumPyroInferenceAdapter base class.
        return self.posterior.guide.sample_posterior(
            key,
            self.result_obj.params,  # the internal SVI params needed to make predictions
            *self._args,
            sample_shape=self.sample_shape,
            **self._kwargs,
        )
The instantiated adapter can now be passed directly into az.from_numpyro.
# Wrap the previously fitted autoguide SVI in the adapter
adapter = SVIAutoGuideAdapter(
    svi,
    svi_result=svi_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs),
    num_samples = 4000
)
# The adapter can now be passed to from_numpyro just like an MCMC object
idata_svi2 = az.from_numpyro(adapter, posterior_predictive=samples_predictive_svi)
idata_svi2
<xarray.DatasetView> Size: 0B
Dimensions: ()
Data variables:
*empty*%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Mar 15 2026
Python implementation: CPython
Python version : 3.12.10
IPython version : 9.4.0
jax : 0.9.0.1
numpy : 2.3.2
numpyro : 0.20.0
arviz_base: 0.7.0.dev0
Watermark: 2.5.0