DataTree for Exploratory Analysis of Bayesian Models#
Here we present a collection of common manipulations you can use while working with xarray.DataTree.
import arviz_base as az
import numpy as np
import xarray as xr
xr.set_options(display_expand_data=False, display_expand_attrs=False);
display_expand_data=False makes the default view for xarray.DataArray fold the data values to a single line. To explore the values, click on the icon on the left of the view, right under the xarray.DataArray text. It has no effect on Dataset objects, whose data values are already folded by default.
display_expand_attrs=False folds the attributes in both DataArray and Dataset objects to keep the views shorter. On this page we print DataArrays and Datasets several times, and they always have the same attributes.
idata = az.load_arviz_data("centered_eight")
idata
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 16kB 1.716 1.903 1.903 ... 5.409 7.721 10.24
│ theta (chain, draw, school) float64 128kB 2.317 1.45 ... 14.92 14.02
│ tau (chain, draw) float64 16kB 0.8775 0.8027 0.8027 ... 2.99 3.052
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 128kB 38.88 -14.98 ... 27.05 20.99
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 128kB -5.093 -3.436 ... -3.269 -3.816
│ Attributes: (4)
├── Group: /sample_stats
│ Dimensions: (chain: 4, draw: 500)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
│ Data variables: (12/17)
│ step_size (chain, draw) float64 16kB 0.1427 0.1427 ... 0.1233
│ reached_max_treedepth (chain, draw) bool 2kB False False ... False False
│ perf_counter_start (chain, draw) float64 16kB 3.931e+04 ... 3.931e+04
│ energy_error (chain, draw) float64 16kB 1.896 -1.479 ... 0.1372
│ perf_counter_diff (chain, draw) float64 16kB 0.0004726 ... 0.001483
│ tree_depth (chain, draw) int64 16kB 2 3 3 3 3 5 ... 4 4 4 4 4 4
│ ... ...
│ index_in_trajectory (chain, draw) int64 16kB 1 -1 0 0 1 ... -6 9 13 7 -15
│ acceptance_rate (chain, draw) float64 16kB 0.05665 0.1429 ... 0.8901
│ smallest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan
│ energy (chain, draw) float64 16kB 47.98 49.69 ... 60.07
│ diverging (chain, draw) bool 2kB False False ... False False
│ largest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan
│ Attributes: (6)
├── Group: /prior
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ theta (chain, draw, school) float64 32kB -8.435 24.12 ... 54.57 52.29
│ tau (chain, draw) float64 4kB 11.93 17.76 4.732 ... 2.231 3.319 93.69
│ mu (chain, draw) float64 4kB 4.714 3.853 1.709 ... -2.245 -2.435
│ Attributes: (4)
├── Group: /prior_predictive
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 32kB 22.03 26.95 ... 58.23 39.78
│ Attributes: (4)
├── Group: /observed_data
│ Dimensions: (school: 8)
│ Coordinates:
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (school) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
│ Attributes: (4)
└── Group: /constant_data
Dimensions: (school: 8)
Coordinates:
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
sigma (school) float64 64B 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
Attributes: (4)
Get a specific group#
post = idata["posterior"]
post
<xarray.DataTree 'posterior'>
Group: /posterior
Dimensions: (chain: 4, draw: 500, school: 8)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
mu (chain, draw) float64 16kB 1.716 1.903 1.903 ... 5.409 7.721 10.24
theta (chain, draw, school) float64 128kB 2.317 1.45 ... 14.92 14.02
tau (chain, draw) float64 16kB 0.8775 0.8027 0.8027 ... 2.99 3.052
Attributes: (6)
Tip
You’ll have noticed that we stored the posterior group in a new variable, post. Because .copy() was not called, post is the same node as idata["posterior"], not a copy. Using either one is equivalent, and changes made through post are visible in idata.
Use this to keep your code short yet easy to read: store the groups you need often as separate variables, but don’t delete the parent DataTree. Many ArviZ functions need it to work properly. For example, plot_pair needs data from the sample_stats group to show divergences, compare needs data from both the log_likelihood and posterior groups, and plot_loo_pit needs not 2 but 3 groups: log_likelihood, posterior_predictive and posterior.
Add a new variable#
post["log_tau"] = np.log(post["tau"])
idata.posterior
<xarray.DataTree 'posterior'>
Group: /posterior
Dimensions: (chain: 4, draw: 500, school: 8)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
mu (chain, draw) float64 16kB 1.716 1.903 1.903 ... 5.409 7.721 10.24
theta (chain, draw, school) float64 128kB 2.317 1.45 ... 14.92 14.02
tau (chain, draw) float64 16kB 0.8775 0.8027 0.8027 ... 2.99 3.052
log_tau (chain, draw) float64 16kB -0.1307 -0.2198 -0.2198 ... 1.095 1.116
Attributes: (6)
Combine chains and draws#
stacked = az.extract(idata)
stacked
<xarray.Dataset> Size: 225kB
Dimensions: (sample: 2000, school: 8)
Coordinates:
* sample (sample) object 16kB MultiIndex
* chain (sample) int64 16kB 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
* draw (sample) int64 16kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
mu (sample) float64 16kB 1.716 1.903 1.903 1.903 ... 5.409 7.721 10.24
theta (school, sample) float64 128kB 2.317 0.8892 0.8892 ... 9.754 14.02
tau (sample) float64 16kB 0.8775 0.8027 0.8027 ... 2.236 2.99 3.052
log_tau (sample) float64 16kB -0.1307 -0.2198 -0.2198 ... 1.095 1.116
Attributes: (7)
arviz.extract is a convenience function aimed at taking care of the most common subsetting operations with MCMC samples. It can:
Combine chains and draws
Return a subset of variables (with optional filtering with regular expressions or string matching)
Return a subset of samples. When a subset is requested, it is drawn at random by default. This avoids non-representative samples, such as consecutive draws from a single chain, when mixing is poor.
Access any group
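You can approximate the first operation in that list with plain xarray. This sketch is not how az.extract is implemented. It only illustrates the idea, using synthetic posterior data and placeholder school labels. Dataset.stack merges chain and draw into a single MultiIndexed sample dimension and appends it after the remaining dimensions. This matches the stacked output above, where theta has dims (school, sample).

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)
schools = [f"school_{i}" for i in range(8)]  # placeholder labels
posterior = xr.Dataset(
    {
        "mu": (("chain", "draw"), rng.normal(size=(4, 500))),
        "theta": (("chain", "draw", "school"), rng.normal(size=(4, 500, 8))),
    },
    coords={"chain": np.arange(4), "draw": np.arange(500), "school": schools},
)

# Merge chain and draw into one MultiIndexed "sample" dimension.
# stack appends the new dimension at the end, so theta becomes (school, sample)
stacked = posterior.stack(sample=("chain", "draw"))
print(stacked.sizes["sample"])  # 2000
print(stacked["theta"].dims)  # ('school', 'sample')
```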
Get a random subset of the samples#
az.extract(idata, num_samples=100)
<xarray.Dataset> Size: 12kB
Dimensions: (sample: 100, school: 8)
Coordinates:
* sample (sample) object 800B MultiIndex
* chain (sample) int64 800B 2 2 0 3 2 0 0 3 2 2 0 ... 0 2 0 2 1 3 1 2 0 0 3
* draw (sample) int64 800B 42 389 488 381 231 148 ... 474 303 97 45 319
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
mu (sample) float64 800B 5.867 6.017 3.748 ... 6.123 -1.702 6.205
theta (school, sample) float64 6kB 4.476 4.615 -5.496 ... -4.86 7.694
tau (sample) float64 800B 3.236 3.063 6.51 3.265 ... 1.297 4.752 1.869
log_tau (sample) float64 800B 1.174 1.119 1.873 1.183 ... 0.26 1.559 0.6255
Attributes: (7)
Tip
Use a random seed to get the same subset from multiple groups: az.extract(idata, num_samples=100, rng=3) and az.extract(idata, group="log_likelihood", num_samples=100, rng=3) return matching samples.
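A hypothetical helper, random_subset, shows why a shared seed keeps groups aligned. It is written in plain xarray and NumPy and is not the arviz_base implementation. When both groups have the same chains and draws, drawing the indices with the same seed selects the same (chain, draw) pairs in each group.

```python
import numpy as np
import xarray as xr


def random_subset(ds, num_samples, seed):
    """Hypothetical helper: stack chain/draw, then keep num_samples random samples."""
    stacked = ds.stack(sample=("chain", "draw"))
    rng = np.random.default_rng(seed)
    idx = rng.choice(stacked.sizes["sample"], size=num_samples, replace=False)
    return stacked.isel(sample=idx)


# Two synthetic groups sharing the same chain/draw coordinates
coords = {"chain": np.arange(4), "draw": np.arange(500)}
gen = np.random.default_rng(0)
posterior = xr.Dataset({"mu": (("chain", "draw"), gen.normal(size=(4, 500)))}, coords=coords)
log_lik = xr.Dataset({"obs": (("chain", "draw"), gen.normal(size=(4, 500)))}, coords=coords)

sub_post = random_subset(posterior, 100, seed=3)
sub_ll = random_subset(log_lik, 100, seed=3)

# Same seed and same number of samples -> same (chain, draw) pairs in both groups
print(np.array_equal(sub_post["chain"].values, sub_ll["chain"].values))  # True
print(np.array_equal(sub_post["draw"].values, sub_ll["draw"].values))  # True
```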
Obtain a NumPy array for a given parameter#
Let’s say we want to get the values for mu as a NumPy array.
stacked.mu.values
array([ 1.71572331, 1.90348113, 1.90348113, ..., 5.40883573,
7.72143998, 10.23715678], shape=(2000,))
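.values (or .to_numpy()) returns the array with its dimensions in whatever order they are stored. The small sketch below uses a synthetic array. Calling .transpose first makes the axis order explicit, which is safer than assuming positional axes when you pass the array to NumPy code.

```python
import numpy as np
import xarray as xr

theta = xr.DataArray(
    np.arange(4 * 500 * 8, dtype=float).reshape(4, 500, 8),
    dims=("chain", "draw", "school"),
)

# Name the axis order explicitly instead of relying on the stored order
arr = theta.transpose("school", "chain", "draw").to_numpy()
print(arr.shape)  # (8, 4, 500)
print(arr[2, 1, 10] == theta.isel(chain=1, draw=10, school=2).item())  # True
```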
Get the dimension lengths#
Let’s check how many groups (schools) are in our hierarchical model.
idata.observed_data.sizes["school"]
8
Get coordinate values#
What are the names of the groups (schools) in our hierarchical model? In this case, you can access them from the school coordinate:
idata.observed_data.school
<xarray.DataArray 'school' (school: 8)> Size: 512B
'Choate' 'Deerfield' 'Phillips Andover' ... "St. Paul's" 'Mt. Hermon'
Coordinates:
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
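The coordinate is itself a DataArray. Sometimes you need the labels as a plain Python list, or as a pandas.Index to look up positions. This minimal sketch rebuilds only the first three schools and their observed values from above.

```python
import xarray as xr

observed = xr.Dataset(
    {"obs": ("school", [28.0, 8.0, -3.0])},
    coords={"school": ["Choate", "Deerfield", "Phillips Andover"]},
)

names = observed["school"].values.tolist()  # plain Python list of str
index = observed.indexes["school"]  # pandas.Index backing the coordinate

print(names)  # ['Choate', 'Deerfield', 'Phillips Andover']
print(index.get_loc("Deerfield"))  # 1
```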
Get a subset of chains#
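Let’s keep only chains 0 and 2. The subset should take effect on all the relevant DataTree groups: posterior, sample_stats, log_likelihood and posterior_predictive. To do that, we use xarray.DataTree.filter to keep only those groups before calling .sel. The prior groups are left out because they hold a single, unrelated chain (labelled 0), so selecting chain 2 there would not make sense.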
posterior_groups = {"posterior", "posterior_predictive", "sample_stats", "log_likelihood"}
idata.filter(lambda node: node.name in posterior_groups).sel(chain=[0, 2])
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 2, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 16B 0 2
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 8kB 1.716 1.903 1.903 ... 7.887 5.518 7.041
│ theta (chain, draw, school) float64 64kB 2.317 1.45 2.086 ... 7.081 6.192
│ tau (chain, draw) float64 8kB 0.8775 0.8027 0.8027 ... 5.464 4.773
│ log_tau (chain, draw) float64 8kB -0.1307 -0.2198 -0.2198 ... 1.698 1.563
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 2, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 16B 0 2
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 64kB 38.88 -14.98 ... 14.84 26.65
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 2, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 16B 0 2
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 64kB -5.093 -3.436 ... -3.818 -3.861
│ Attributes: (4)
└── Group: /sample_stats
Dimensions: (chain: 2, draw: 500)
Coordinates:
* chain (chain) int64 16B 0 2
* draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
Data variables: (12/17)
step_size (chain, draw) float64 8kB 0.1427 0.1427 ... 0.2552
reached_max_treedepth (chain, draw) bool 1kB False False ... False False
perf_counter_start (chain, draw) float64 8kB 3.931e+04 ... 3.931e+04
energy_error (chain, draw) float64 8kB 1.896 -1.479 ... -0.04605
perf_counter_diff (chain, draw) float64 8kB 0.0004726 ... 0.00193
tree_depth (chain, draw) int64 8kB 2 3 3 3 3 5 4 ... 5 4 4 4 4 5
... ...
index_in_trajectory (chain, draw) int64 8kB 1 -1 0 0 1 ... -11 -7 -6 6 1
acceptance_rate (chain, draw) float64 8kB 0.05665 0.1429 ... 0.9809
smallest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan
energy (chain, draw) float64 8kB 47.98 49.69 ... 62.64 63.12
diverging (chain, draw) bool 1kB False False ... False False
largest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan
Attributes: (6)
Remove the first n draws (burn-in)#
Let’s say we want to remove the first 100 draws from all chains and from all DataTree groups that have a draw dimension.
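The code in this section uses .sel, which slices by draw label, and label slices include both endpoints. In centered_eight the draw labels start at 0, so labels and positions coincide. The sketch below assumes hypothetical draw labels starting at 1 to show when .isel (positional) and .sel (label-based) give different results.

```python
import numpy as np
import xarray as xr

# Synthetic draws whose labels start at 1 instead of 0 (assumption for illustration)
ds = xr.Dataset(
    {"mu": (("chain", "draw"), np.zeros((2, 500)))},
    coords={"chain": [0, 1], "draw": np.arange(1, 501)},
)

by_label = ds.sel(draw=slice(100, None))  # label slice includes both ends: draws 100..500
by_position = ds.isel(draw=slice(100, None))  # drops the first 100 positions: draws 101..500

print(by_label.sizes["draw"])  # 401
print(by_position.sizes["draw"])  # 400
```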
idata.filter(lambda node: "draw" in node.dims).sel(draw=slice(100, None))
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 4, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 13kB 6.497 1.921 3.822 ... 5.409 7.721 10.24
│ theta (chain, draw, school) float64 102kB 8.271 7.066 ... 14.92 14.02
│ tau (chain, draw) float64 13kB 1.386 2.267 1.846 ... 2.236 2.99 3.052
│ log_tau (chain, draw) float64 13kB 0.3267 0.8184 0.6132 ... 1.095 1.116
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 4, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 102kB -17.56 4.567 ... 27.05 20.99
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 4, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 102kB -4.492 -3.226 ... -3.269 -3.816
│ Attributes: (4)
├── Group: /sample_stats
│ Dimensions: (chain: 4, draw: 400)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 ... 496 497 498 499
│ Data variables: (12/17)
│ step_size (chain, draw) float64 13kB 0.1427 0.1427 ... 0.1233
│ reached_max_treedepth (chain, draw) bool 2kB False False ... False False
│ perf_counter_start (chain, draw) float64 13kB 3.931e+04 ... 3.931e+04
│ energy_error (chain, draw) float64 13kB -0.9934 0.2091 ... 0.1372
│ perf_counter_diff (chain, draw) float64 13kB 0.0009694 ... 0.001483
│ tree_depth (chain, draw) int64 13kB 3 5 4 4 4 4 ... 4 4 4 4 4 4
│ ... ...
│ index_in_trajectory (chain, draw) int64 13kB 4 -8 -3 7 12 ... 9 13 7 -15
│ acceptance_rate (chain, draw) float64 13kB 0.9029 0.9153 ... 0.8901
│ smallest_eigval (chain, draw) float64 13kB nan nan nan ... nan nan
│ energy (chain, draw) float64 13kB 54.35 56.89 ... 60.07
│ diverging (chain, draw) bool 2kB False False ... False False
│ largest_eigval (chain, draw) float64 13kB nan nan nan ... nan nan
│ Attributes: (6)
├── Group: /prior
│ Dimensions: (chain: 1, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ theta (chain, draw, school) float64 26kB -9.565 2.252 ... 54.57 52.29
│ tau (chain, draw) float64 3kB 5.428 2.633 0.7054 ... 2.231 3.319 93.69
│ mu (chain, draw) float64 3kB 1.102 14.77 -7.669 ... -2.245 -2.435
│ Attributes: (4)
└── Group: /prior_predictive
Dimensions: (chain: 1, draw: 400, school: 8)
Coordinates:
* chain (chain) int64 8B 0
* draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
obs (chain, draw, school) float64 26kB -23.9 1.67 ... 58.23 39.78
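Attributes: (4)
The result shows that the posterior, posterior_predictive, log_likelihood, sample_stats, prior and prior_predictive groups have 400 draws, compared to 500 in idata. idata itself is not modified, so assign the result to a new variable (for example, burnin) if you want to keep it. Alternatively, you can specify which group or groups you want to change. Note that filter drops the groups it does not select from the returned DataTree: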
idata.filter(lambda node: node.name in posterior_groups).sel(draw=slice(100, None))
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 4, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 13kB 6.497 1.921 3.822 ... 5.409 7.721 10.24
│ theta (chain, draw, school) float64 102kB 8.271 7.066 ... 14.92 14.02
│ tau (chain, draw) float64 13kB 1.386 2.267 1.846 ... 2.236 2.99 3.052
│ log_tau (chain, draw) float64 13kB 0.3267 0.8184 0.6132 ... 1.095 1.116
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 4, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 102kB -17.56 4.567 ... 27.05 20.99
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 4, draw: 400, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 3kB 100 101 102 103 104 105 ... 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 102kB -4.492 -3.226 ... -3.269 -3.816
│ Attributes: (4)
└── Group: /sample_stats
Dimensions: (chain: 4, draw: 400)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 3kB 100 101 102 103 ... 496 497 498 499
Data variables: (12/17)
step_size (chain, draw) float64 13kB 0.1427 0.1427 ... 0.1233
reached_max_treedepth (chain, draw) bool 2kB False False ... False False
perf_counter_start (chain, draw) float64 13kB 3.931e+04 ... 3.931e+04
energy_error (chain, draw) float64 13kB -0.9934 0.2091 ... 0.1372
perf_counter_diff (chain, draw) float64 13kB 0.0009694 ... 0.001483
tree_depth (chain, draw) int64 13kB 3 5 4 4 4 4 ... 4 4 4 4 4 4
... ...
index_in_trajectory (chain, draw) int64 13kB 4 -8 -3 7 12 ... 9 13 7 -15
acceptance_rate (chain, draw) float64 13kB 0.9029 0.9153 ... 0.8901
smallest_eigval (chain, draw) float64 13kB nan nan nan ... nan nan
energy (chain, draw) float64 13kB 54.35 56.89 ... 60.07
diverging (chain, draw) bool 2kB False False ... False False
largest_eigval (chain, draw) float64 13kB nan nan nan ... nan nan
Attributes: (6)
Compute posterior mean values along draw and chain dimensions#
To compute the mean value of the posterior samples, do the following:
post.mean()
<xarray.DataTree 'posterior'>
Group: /
Dimensions: ()
Data variables:
mu float64 8B 4.171
theta float64 8B 4.749
tau float64 8B 4.321
log_tau float64 8B 1.256
Attributes: (6)
This computes the mean along all dimensions. This is probably what you want for mu and tau, which have two dimensions (chain and draw), but maybe not for theta, which has an additional school dimension.
You can specify along which dimension you want to compute the mean (or other functions).
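The dim argument works the same way for other reductions. This sketch uses a synthetic theta array and computes a standard deviation and two quantiles per school. Note that quantile adds a leading quantile dimension.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(4)
theta = xr.DataArray(rng.normal(size=(4, 500, 8)), dims=("chain", "draw", "school"))

sample_dims = ["chain", "draw"]
theta_sd = theta.std(dim=sample_dims)  # one value per school
theta_q = theta.quantile([0.05, 0.95], dim=sample_dims)  # adds a "quantile" dimension

print(theta_sd.dims)  # ('school',)
print(theta_q.dims)  # ('quantile', 'school')
```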
post.mean(dim=["chain", "draw"])
<xarray.DataTree 'posterior'>
Group: /
Dimensions: (school: 8)
Coordinates:
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
mu float64 8B 4.171
theta (school) float64 64B 6.42 4.954 3.423 4.754 3.453 3.663 6.505 4.82
tau float64 8B 4.321
log_tau float64 8B 1.256
Attributes: (6)
Compute and store posterior pushforward quantities#
We use “posterior pushforward quantities” to refer to quantities that are not variables in the posterior but deterministic computations using posterior variables.
You can use xarray for these pushforward operations and store them as new variables in the posterior group. You’ll then be able to plot them with ArviZ functions, calculate stats and diagnostics on them (such as the MCSE), or save and share the xarray.DataTree object with the pushforward quantities included.
Compute the rolling mean of \(\log(\tau)\) with xarray.DataArray.rolling and store the result in the posterior:
post["mlogtau"] = post["log_tau"].rolling({"draw": 50}).mean()
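With a window of 50 draws, xarray needs 50 values before it can compute the first mean. By default, the first 49 draws of each chain are therefore NaN, which is why mlogtau starts with nan values in the output further down. This sketch on synthetic data compares that behaviour with min_periods=1, which averages over the partial windows instead.

```python
import numpy as np
import xarray as xr

log_tau = xr.DataArray(
    np.random.default_rng(5).normal(size=(4, 500)),
    dims=("chain", "draw"),
)

# Window of 50 draws; by default min_periods equals the window size
rolled = log_tau.rolling({"draw": 50}).mean()
# Accept partial windows at the start of each chain
rolled_partial = log_tau.rolling({"draw": 50}, min_periods=1).mean()

print(int(rolled.isel(chain=0).isnull().sum()))  # 49
print(int(rolled_partial.isnull().sum()))  # 0
```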
Using xarray for pushforward calculations gives you all the advantages of working with xarray. It also inherits its disadvantages, but we believe the advantages outweigh them, and we have already shown how to extract the data as NumPy arrays when needed. Working with xarray.DataTree mostly means working with xarray objects, which is what this guide shows.
Some examples of these advantages are specifying operations with named dimensions instead of positional ones (as seen in some previous sections), automatic alignment and broadcasting of arrays (as we’ll see now), or integration with Dask.
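Automatic alignment means that arithmetic matches values by coordinate label rather than by position. This minimal sketch uses two hypothetical arrays whose school labels (A, B, C) are stored in different orders.

```python
import xarray as xr

a = xr.DataArray([1.0, 2.0, 3.0], dims="school", coords={"school": ["A", "B", "C"]})
b = xr.DataArray([30.0, 10.0, 20.0], dims="school", coords={"school": ["C", "A", "B"]})

aligned = b - a  # values are matched by school label before subtracting
positional = b.values - a.values  # NumPy ignores the labels

print(aligned.sel(school="C").item())  # 27.0
print(positional.tolist())  # [29.0, 8.0, 17.0]
```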
In this cell you will compute the pairwise differences between schools in their mean effects (variable theta). To do so, subtract from theta a copy of itself in which the school dimension has been renamed to school_bis. Because the two operands have different dimensions, xarray aligns and broadcasts them, and the result is a 4d variable with all the pairwise differences. Finally, store the result in the theta_school_diff variable:
post["theta_school_diff"] = post.theta - post.theta.rename(school="school_bis")
Note
This same operation using NumPy would require manually aligning the two arrays to make sure they broadcast correctly. The code would be something like:
theta = post["theta"].values  # NumPy array with dims (chain, draw, school)
theta_school_diff = theta[:, :, :, None] - theta[:, :, None, :]
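The self-contained sketch below checks that the NumPy expression in the note computes the same thing as the xarray version. It uses synthetic data with placeholder school labels. Both approaches produce an array laid out as (chain, draw, school, school_bis), where entry [i, j] is theta for school i minus theta for school j.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(6)
schools = ["A", "B", "C", "D"]  # placeholder labels
theta = xr.DataArray(
    rng.normal(size=(2, 50, 4)),
    dims=("chain", "draw", "school"),
    coords={"school": schools},
)

# xarray: broadcasting by dimension name
diff_xr = theta - theta.rename(school="school_bis")
# NumPy: insert the new axes by hand
vals = theta.to_numpy()
diff_np = vals[:, :, :, None] - vals[:, :, None, :]

print(diff_xr.dims)  # ('chain', 'draw', 'school', 'school_bis')
print(np.allclose(diff_xr.to_numpy(), diff_np))  # True
```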
The theta_school_diff variable in the posterior has kept the named dimensions and coordinates:
post
<xarray.DataTree 'posterior'>
Group: /posterior
Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
* school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
Data variables:
mu (chain, draw) float64 16kB 1.716 1.903 ... 7.721 10.24
theta (chain, draw, school) float64 128kB 2.317 1.45 ... 14.02
tau (chain, draw) float64 16kB 0.8775 0.8027 ... 2.99 3.052
log_tau (chain, draw) float64 16kB -0.1307 -0.2198 ... 1.116
mlogtau (chain, draw) float64 16kB nan nan nan ... 1.335 1.335
theta_school_diff (chain, draw, school, school_bis) float64 1MB 0.0 ... 0.0
Attributes: (6)
Advanced subsetting#
To select the values corresponding to the difference between the Choate and Deerfield schools, do:
post["theta_school_diff"].sel(school="Choate", school_bis="Deerfield")
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500)> Size: 16kB
0.8672 0.1462 0.1462 0.1462 0.2902 1.08 ... 5.899 -5.935 6.959 -5.226 -3.25
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
school <U16 64B 'Choate'
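school_bis <U16 64B 'Deerfield'
For more advanced subsetting (the equivalent of what is sometimes called “fancy indexing” in NumPy), provide the indices as DataArray objects that share a dimension. xarray then selects pointwise along that shared dimension, which it calls vectorized indexing: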
school_idx = xr.DataArray(["Choate", "Hotchkiss", "Mt. Hermon"], dims=["pairwise_school_diff"])
school_bis_idx = xr.DataArray(
["Deerfield", "Choate", "Lawrenceville"], dims=["pairwise_school_diff"]
)
post["theta_school_diff"].sel(school=school_idx, school_bis=school_bis_idx)
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500,
pairwise_school_diff: 3)> Size: 48kB
0.8672 0.7541 -1.253 0.1462 1.946 0.4254 ... 2.964 -2.105 -3.25 0.4516 -1.048
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
school (pairwise_school_diff) <U16 192B 'Choate' ... 'Mt. Hermon'
school_bis (pairwise_school_diff) <U16 192B 'Deerfield' ... 'Lawrenceville'
Dimensions without coordinates: pairwise_school_diff
Using lists or NumPy arrays instead of DataArrays performs orthogonal (row/column-based) indexing. As you can see, the result has 9 values of theta_school_diff per chain and draw instead of the 3 pairwise differences selected in the previous cell:
post["theta_school_diff"].sel(
school=["Choate", "Hotchkiss", "Mt. Hermon"],
school_bis=["Deerfield", "Choate", "Lawrenceville"],
)
<xarray.DataArray 'theta_school_diff' (chain: 4, draw: 500, school: 3,
school_bis: 3)> Size: 144kB
0.8672 0.0 -0.3956 1.621 0.7541 0.3585 ... 0.4516 -4.155 0.3088 3.559 -1.048
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* school (school) <U16 192B 'Choate' 'Hotchkiss' 'Mt. Hermon'
* school_bis (school_bis) <U16 192B 'Deerfield' 'Choate' 'Lawrenceville'
Add new chains using concat#
Suppose that after checking the MCSE you realize you need more samples, so you rerun the model with two more chains and obtain an idata_rerun object.
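When you concatenate along chain, the new chains need labels that do not clash with the existing ones. This is why the example below relabels the reused chains as 4 and 5. The sketch below performs that check on plain Datasets, using a hypothetical concat_chains helper and synthetic data.

```python
import numpy as np
import xarray as xr


def concat_chains(old, new):
    """Hypothetical helper: append new chains, refusing duplicate chain labels."""
    overlap = set(old["chain"].values.tolist()) & set(new["chain"].values.tolist())
    if overlap:
        raise ValueError(f"chain labels already present: {sorted(overlap)}")
    return xr.concat([old, new], dim="chain")


def chains(labels, n_draws=500, seed=0):
    """Synthetic posterior with the given chain labels."""
    values = np.random.default_rng(seed).normal(size=(len(labels), n_draws))
    return xr.Dataset(
        {"mu": (("chain", "draw"), values)},
        coords={"chain": labels, "draw": np.arange(n_draws)},
    )


combined = concat_chains(chains([0, 1, 2, 3]), chains([4, 5], seed=1))
print(combined["chain"].values.tolist())  # [0, 1, 2, 3, 4, 5]
```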
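# Example object with "new" samples: chains 0 and 1 relabelled as chains 4 and 5
idata_rerun = (
    idata.filter(lambda node: node.name in posterior_groups)
    .sel(chain=[0, 1])
    .copy()
    # Skip empty nodes (such as the root); relabel the chains everywhere else
    .map_over_datasets(lambda ds: ds.assign_coords(chain=[4, 5]) if ds else ds)
)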
idata_rerun
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 2, draw: 500, school: 8, school_bis: 8)
│ Coordinates:
│ * chain (chain) int64 16B 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ * school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 8kB 1.716 1.903 ... 2.499 4.127
│ theta (chain, draw, school) float64 64kB 2.317 1.45 ... 2.815
│ tau (chain, draw) float64 8kB 0.8775 0.8027 ... 6.288 6.8
│ log_tau (chain, draw) float64 8kB -0.1307 -0.2198 ... 1.839 1.917
│ mlogtau (chain, draw) float64 8kB nan nan nan ... 1.314 1.324
│ theta_school_diff (chain, draw, school, school_bis) float64 512kB 0.0 .....
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 2, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 16B 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 64kB 38.88 -14.98 ... -1.931 35.85
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 2, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 16B 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 64kB -5.093 -3.436 ... -3.913 -3.939
│ Attributes: (4)
└── Group: /sample_stats
Dimensions: (chain: 2, draw: 500)
Coordinates:
* chain (chain) int64 16B 4 5
* draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
Data variables: (12/17)
step_size (chain, draw) float64 8kB 0.1427 0.1427 ... 0.1403
reached_max_treedepth (chain, draw) bool 1kB False False ... False False
perf_counter_start (chain, draw) float64 8kB 3.931e+04 ... 3.931e+04
energy_error (chain, draw) float64 8kB 1.896 -1.479 ... 0.02623
perf_counter_diff (chain, draw) float64 8kB 0.0004726 ... 0.001291
tree_depth (chain, draw) int64 8kB 2 3 3 3 3 5 4 ... 4 4 4 4 4 4
... ...
index_in_trajectory (chain, draw) int64 8kB 1 -1 0 0 1 1 ... 3 -12 9 -6 2
acceptance_rate (chain, draw) float64 8kB 0.05665 0.1429 ... 0.9915
smallest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan
energy (chain, draw) float64 8kB 47.98 49.69 ... 64.34 62.41
diverging (chain, draw) bool 1kB False False ... False False
largest_eigval (chain, draw) float64 8kB nan nan nan ... nan nan nan
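Attributes: (6)
Then concatenate each group present in both trees along the chain dimension, and copy the remaining groups unchanged:
idata_complete = xr.DataTree()
for group, group_data in idata.children.items():
    if group in idata_rerun.children:
        # Append the new chains to the existing ones
        ds_complete = xr.concat([group_data.dataset, idata_rerun[group].dataset], dim="chain")
        idata_complete[group] = ds_complete
    else:
        # Groups without new samples, such as prior or observed_data, are kept as they are
        idata_complete[group] = group_data.dataset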
idata_complete
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 6, draw: 500, school: 8, school_bis: 8)
│ Coordinates:
│ * chain (chain) int64 48B 0 1 2 3 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ * school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 24kB 1.716 1.903 ... 2.499 4.127
│ theta (chain, draw, school) float64 192kB 2.317 1.45 ... 2.815
│ tau (chain, draw) float64 24kB 0.8775 0.8027 ... 6.288 6.8
│ log_tau (chain, draw) float64 24kB -0.1307 -0.2198 ... 1.917
│ mlogtau (chain, draw) float64 24kB nan nan nan ... 1.314 1.324
│ theta_school_diff (chain, draw, school, school_bis) float64 2MB 0.0 ... 0.0
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 6, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 48B 0 1 2 3 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 192kB 38.88 -14.98 ... -1.931 35.85
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 6, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 48B 0 1 2 3 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 192kB -5.093 -3.436 ... -3.913 -3.939
│ Attributes: (4)
├── Group: /sample_stats
│ Dimensions: (chain: 6, draw: 500)
│ Coordinates:
│ * chain (chain) int64 48B 0 1 2 3 4 5
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
│ Data variables: (12/17)
│ step_size (chain, draw) float64 24kB 0.1427 0.1427 ... 0.1403
│ reached_max_treedepth (chain, draw) bool 3kB False False ... False False
│ perf_counter_start (chain, draw) float64 24kB 3.931e+04 ... 3.931e+04
│ energy_error (chain, draw) float64 24kB 1.896 -1.479 ... 0.02623
│ perf_counter_diff (chain, draw) float64 24kB 0.0004726 ... 0.001291
│ tree_depth (chain, draw) int64 24kB 2 3 3 3 3 5 ... 4 4 4 4 4 4
│ ... ...
│ index_in_trajectory (chain, draw) int64 24kB 1 -1 0 0 1 ... 3 -12 9 -6 2
│ acceptance_rate (chain, draw) float64 24kB 0.05665 0.1429 ... 0.9915
│ smallest_eigval (chain, draw) float64 24kB nan nan nan ... nan nan
│ energy (chain, draw) float64 24kB 47.98 49.69 ... 62.41
│ diverging (chain, draw) bool 3kB False False ... False False
│ largest_eigval (chain, draw) float64 24kB nan nan nan ... nan nan
│ Attributes: (6)
├── Group: /prior
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ theta (chain, draw, school) float64 32kB -8.435 24.12 ... 54.57 52.29
│ tau (chain, draw) float64 4kB 11.93 17.76 4.732 ... 2.231 3.319 93.69
│ mu (chain, draw) float64 4kB 4.714 3.853 1.709 ... -2.245 -2.435
│ Attributes: (4)
├── Group: /prior_predictive
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 32kB 22.03 26.95 ... 58.23 39.78
│ Attributes: (4)
├── Group: /observed_data
│ Dimensions: (school: 8)
│ Coordinates:
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (school) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
│ Attributes: (4)
└── Group: /constant_data
Dimensions: (school: 8)
Coordinates:
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
sigma (school) float64 64B 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
Attributes: (4)
Add a new group to a DataTree#
You can also add new groups to a DataTree by assigning a Dataset to a new key, the same way idata_complete was built in the previous section.
The code below creates an example dataset and adds it to the idata DataTree.
rng = np.random.default_rng(3)
ds = az.dict_to_dataset(
{"obs": rng.normal(size=(4, 500, 2))},
dims={"obs": ["new_school"]},
coords={"new_school": ["Essex College", "Moordale"]},
)
idata["predictions"] = ds
idata
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ * school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
│ Data variables:
│ mu (chain, draw) float64 16kB 1.716 1.903 ... 7.721 10.24
│ theta (chain, draw, school) float64 128kB 2.317 1.45 ... 14.02
│ tau (chain, draw) float64 16kB 0.8775 0.8027 ... 2.99 3.052
│ log_tau (chain, draw) float64 16kB -0.1307 -0.2198 ... 1.116
│ mlogtau (chain, draw) float64 16kB nan nan nan ... 1.335 1.335
│ theta_school_diff (chain, draw, school, school_bis) float64 1MB 0.0 ... 0.0
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 128kB 38.88 -14.98 ... 27.05 20.99
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 128kB -5.093 -3.436 ... -3.269 -3.816
│ Attributes: (4)
├── Group: /sample_stats
│ Dimensions: (chain: 4, draw: 500)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
│ Data variables: (12/17)
│ step_size (chain, draw) float64 16kB 0.1427 0.1427 ... 0.1233
│ reached_max_treedepth (chain, draw) bool 2kB False False ... False False
│ perf_counter_start (chain, draw) float64 16kB 3.931e+04 ... 3.931e+04
│ energy_error (chain, draw) float64 16kB 1.896 -1.479 ... 0.1372
│ perf_counter_diff (chain, draw) float64 16kB 0.0004726 ... 0.001483
│ tree_depth (chain, draw) int64 16kB 2 3 3 3 3 5 ... 4 4 4 4 4 4
│ ... ...
│ index_in_trajectory (chain, draw) int64 16kB 1 -1 0 0 1 ... -6 9 13 7 -15
│ acceptance_rate (chain, draw) float64 16kB 0.05665 0.1429 ... 0.8901
│ smallest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan
│ energy (chain, draw) float64 16kB 47.98 49.69 ... 60.07
│ diverging (chain, draw) bool 2kB False False ... False False
│ largest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan
│ Attributes: (6)
├── Group: /prior
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ theta (chain, draw, school) float64 32kB -8.435 24.12 ... 54.57 52.29
│ tau (chain, draw) float64 4kB 11.93 17.76 4.732 ... 2.231 3.319 93.69
│ mu (chain, draw) float64 4kB 4.714 3.853 1.709 ... -2.245 -2.435
│ Attributes: (4)
├── Group: /prior_predictive
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 32kB 22.03 26.95 ... 58.23 39.78
│ Attributes: (4)
├── Group: /observed_data
│ Dimensions: (school: 8)
│ Coordinates:
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (school) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
│ Attributes: (4)
├── Group: /constant_data
│ Dimensions: (school: 8)
│ Coordinates:
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ sigma (school) float64 64B 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
│ Attributes: (4)
└── Group: /predictions
Dimensions: (chain: 4, draw: 500, new_school: 2)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* new_school (new_school) <U13 104B 'Essex College' 'Moordale'
Data variables:
obs (chain, draw, new_school) float64 32kB 2.041 -2.556 ... -0.2822
Attributes: (5)
Add Transformations to Multiple Groups
You can apply a transformation to multiple groups of an xarray.DataTree using xarray.DataTree.map_over_datasets. It takes a function as input, applies it to the dataset of every node in the tree, and returns a new DataTree with the results; the original DataTree is not modified.
Usually you only want to transform some of the groups. Select them first with the DataTree.filter method, which keeps the nodes for which a predicate returns True, or with DataTree.match, which keeps the nodes whose path matches a glob-like pattern.
selected_groups = ("posterior", "prior")
def calc_mean(dataset):
    # The root node kept by filter holds no data variables, so only compute
    # the mean for nodes that do; other nodes return None and stay empty
    if dataset:
        return dataset.mean(dim="chain")
    return None
means = idata.filter(lambda node: node.name in selected_groups).map_over_datasets(calc_mean)
means
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (draw: 500, school: 8, school_bis: 8)
│ Coordinates:
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ * school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
│ Data variables:
│ mu (draw) float64 4kB 2.083 2.353 2.475 ... 4.46 5.277 7.085
│ theta (draw, school) float64 32kB 7.452 2.307 ... 10.9 8.16
│ tau (draw) float64 4kB 6.268 3.318 3.69 ... 5.005 6.214 5.676
│ log_tau (draw) float64 4kB 1.486 0.9877 1.061 ... 1.737 1.671
│ mlogtau (draw) float64 4kB nan nan nan nan ... 1.441 1.454 1.467
│ theta_school_diff (draw, school, school_bis) float64 256kB 0.0 ... 0.0
│ Attributes: (6)
└── Group: /prior
Dimensions: (draw: 500, school: 8)
Coordinates:
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
Data variables:
theta (draw, school) float64 32kB -8.435 24.12 -2.102 ... 54.57 52.29
tau (draw) float64 4kB 11.93 17.76 4.732 2.161 ... 2.231 3.319 93.69
mu (draw) float64 4kB 4.714 3.853 1.709 1.59 ... -2.329 -2.245 -2.435
Attributes: (4)
Note
Because idata is filtered before calling xarray.DataTree.map_over_datasets, the result contains only the selected groups. To get a DataTree with all the groups, including the modified ones, use update, which modifies the DataTree in place (like dict.update):
idata.update(new_data)
idata_with_means = idata.copy()
idata_with_means.update(means)
idata_with_means
<xarray.DataTree>
Group: /
├── Group: /posterior
│ Dimensions: (draw: 500, school: 8, school_bis: 8)
│ Coordinates:
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ * school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
│ Data variables:
│ mu (draw) float64 4kB 2.083 2.353 2.475 ... 4.46 5.277 7.085
│ theta (draw, school) float64 32kB 7.452 2.307 ... 10.9 8.16
│ tau (draw) float64 4kB 6.268 3.318 3.69 ... 5.005 6.214 5.676
│ log_tau (draw) float64 4kB 1.486 0.9877 1.061 ... 1.737 1.671
│ mlogtau (draw) float64 4kB nan nan nan nan ... 1.441 1.454 1.467
│ theta_school_diff (draw, school, school_bis) float64 256kB 0.0 ... 0.0
│ Attributes: (6)
├── Group: /posterior_predictive
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 128kB 38.88 -14.98 ... 27.05 20.99
│ Attributes: (4)
├── Group: /log_likelihood
│ Dimensions: (chain: 4, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 128kB -5.093 -3.436 ... -3.269 -3.816
│ Attributes: (4)
├── Group: /sample_stats
│ Dimensions: (chain: 4, draw: 500)
│ Coordinates:
│ * chain (chain) int64 32B 0 1 2 3
│ * draw (draw) int64 4kB 0 1 2 3 4 5 ... 495 496 497 498 499
│ Data variables: (12/17)
│ step_size (chain, draw) float64 16kB 0.1427 0.1427 ... 0.1233
│ reached_max_treedepth (chain, draw) bool 2kB False False ... False False
│ perf_counter_start (chain, draw) float64 16kB 3.931e+04 ... 3.931e+04
│ energy_error (chain, draw) float64 16kB 1.896 -1.479 ... 0.1372
│ perf_counter_diff (chain, draw) float64 16kB 0.0004726 ... 0.001483
│ tree_depth (chain, draw) int64 16kB 2 3 3 3 3 5 ... 4 4 4 4 4 4
│ ... ...
│ index_in_trajectory (chain, draw) int64 16kB 1 -1 0 0 1 ... -6 9 13 7 -15
│ acceptance_rate (chain, draw) float64 16kB 0.05665 0.1429 ... 0.8901
│ smallest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan
│ energy (chain, draw) float64 16kB 47.98 49.69 ... 60.07
│ diverging (chain, draw) bool 2kB False False ... False False
│ largest_eigval (chain, draw) float64 16kB nan nan nan ... nan nan
│ Attributes: (6)
├── Group: /prior
│ Dimensions: (draw: 500, school: 8)
│ Coordinates:
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ theta (draw, school) float64 32kB -8.435 24.12 -2.102 ... 54.57 52.29
│ tau (draw) float64 4kB 11.93 17.76 4.732 2.161 ... 2.231 3.319 93.69
│ mu (draw) float64 4kB 4.714 3.853 1.709 1.59 ... -2.329 -2.245 -2.435
│ Attributes: (4)
├── Group: /prior_predictive
│ Dimensions: (chain: 1, draw: 500, school: 8)
│ Coordinates:
│ * chain (chain) int64 8B 0
│ * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (chain, draw, school) float64 32kB 22.03 26.95 ... 58.23 39.78
│ Attributes: (4)
├── Group: /observed_data
│ Dimensions: (school: 8)
│ Coordinates:
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ obs (school) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
│ Attributes: (4)
├── Group: /constant_data
│ Dimensions: (school: 8)
│ Coordinates:
│ * school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
│ Data variables:
│ sigma (school) float64 64B 15.0 10.0 16.0 11.0 9.0 11.0 10.0 18.0
│ Attributes: (4)
└── Group: /predictions
Dimensions: (chain: 4, draw: 500, new_school: 2)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* new_school (new_school) <U13 104B 'Essex College' 'Moordale'
Data variables:
obs (chain, draw, new_school) float64 32kB 2.041 -2.556 ... -0.2822
Attributes: (5)
You can also pass a lambda function to map_over_datasets.
idata_shifted = idata.match("posterior").map_over_datasets(lambda x: x + 3)
idata_shifted
<xarray.DataTree>
Group: /
└── Group: /posterior
Dimensions: (chain: 4, draw: 500, school: 8, school_bis: 8)
Coordinates:
* chain (chain) int64 32B 0 1 2 3
* draw (draw) int64 4kB 0 1 2 3 4 5 ... 494 495 496 497 498 499
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
* school_bis (school_bis) <U16 512B 'Choate' ... 'Mt. Hermon'
Data variables:
mu (chain, draw) float64 16kB 4.716 4.903 ... 10.72 13.24
theta (chain, draw, school) float64 128kB 5.317 4.45 ... 17.02
tau (chain, draw) float64 16kB 3.877 3.803 ... 5.99 6.052
log_tau (chain, draw) float64 16kB 2.869 2.78 ... 4.095 4.116
mlogtau (chain, draw) float64 16kB nan nan nan ... 4.335 4.335
theta_school_diff (chain, draw, school, school_bis) float64 1MB 3.0 ... 3.0
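Attributes: (6)
You can also add extra coordinates using map_over_datasets. Extra positional arguments given to map_over_datasets, here the dictionary of new coordinates, are forwarded to the function.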
_upper = np.array([x.upper() for x in idata.observed_data.school.values])
idata_with_upper = idata.match("observed_data").map_over_datasets(
lambda ds, coords: ds.assign_coords(coords) if ds else ds, {"Upper": _upper}
)
idata_with_upper
<xarray.DataTree>
Group: /
└── Group: /observed_data
Dimensions: (school: 8, Upper: 8)
Coordinates:
* school (school) <U16 512B 'Choate' 'Deerfield' ... 'Mt. Hermon'
* Upper (Upper) <U16 512B 'CHOATE' 'DEERFIELD' ... 'MT. HERMON'
Data variables:
obs (school) float64 64B 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
Attributes: (4)
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Tue, 21 Apr 2026
Python implementation: CPython
Python version : 3.13.7
IPython version : 9.3.0
arviz_base: 1.0.1.dev0
numpy : 2.3.0
xarray : 2026.2.0
Watermark: 2.6.0