Skip to content

Commas in variable names cause InferenceData (idata) writing to fail when combined with transforms #284

@velochy

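Description

@velochy

Sampling with `nuts_sampler="nutpie"` fails when a variable name contains a comma and the variable also has a transform (here `ZeroSumTransform`): converting the draws to an ArviZ InferenceData object (`arviz.from_dict`) raises a `KeyError` for an auto-generated dimension name.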

Reproducible example

# Minimal reproduction: sampling with nutpie fails while converting the draws
# to ArviZ InferenceData when a variable name contains a comma AND the variable
# has a transform (see the KeyError traceback that follows this script).
import numpy as np
import pymc as pm

coords = {
    "age_group": [1, 2],
    "education": ["basic", "secondary", "higher"],
    "language_outp": ["et", "other"],
}

with pm.Model(coords=coords):
    # The comma inside this variable name is the trigger. The missing key in the
    # traceback, 'education_offset_language_zerosum___dim_1', contains only the
    # text after the comma -- presumably the name is split on ',' somewhere when
    # the transformed variable's dim names are generated; TODO confirm in nutpie.
    x = pm.Normal(
        "alpha_age_group,education_offset_language",
        mu=0.0,
        sigma=1.0,
        dims=("age_group", "education", "language_outp"),
        # Per the issue title, the failure needs a transform in addition to the
        # comma; [0, 1] is presumably the list of zero-sum axes.
        transform=pm.distributions.transforms.ZeroSumTransform([0, 1]),
    )
    # Tiny observed node so the model has a likelihood; only x[0, 0, 0] is used.
    pm.Normal("y", mu=x[0, 0, 0], sigma=1.0, observed=np.array(0.0))

    # Short single-chain run through nutpie: enough to reach nutpie's
    # _arrow_to_arviz -> arviz.from_dict conversion, where the KeyError is raised.
    pm.sample(
        draws=20,
        tune=20,
        chains=1,
        cores=1,
        progressbar=False,
        nuts_sampler="nutpie",
    )

and fails with the following traceback:

Traceback (most recent call last):
  File "/home/velochy/salk/salk_internal_package/one_offs/replicate/nutpie_comma_dim_bug_repro.py", line 22, in <module>
    pm.sample(
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 832, in sample
    return _sample_external_nuts(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 377, in _sample_external_nuts
    idata = nutpie.sample(
            ^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/nutpie/sample.py", line 865, in sample
    result = sampler.wait()
             ^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/nutpie/sample.py", line 528, in wait
    return self._extract(results)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/nutpie/sample.py", line 573, in _extract
    return _arrow_to_arviz(
           ^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/nutpie/sample.py", line 99, in _arrow_to_arviz
    return arviz.from_dict(
           ^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/arviz/data/io_dict.py", line 459, in from_dict
    ).to_inference_data()
      ^^^^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/arviz/data/io_dict.py", line 334, in to_inference_data
    "posterior": self.posterior_to_xarray(),
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/arviz/data/base.py", line 71, in wrapped
    return func(cls)
           ^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/arviz/data/io_dict.py", line 97, in posterior_to_xarray
    dict_to_dataset(
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/arviz/data/base.py", line 406, in dict_to_dataset
    key: numpy_to_data_array(
         ^^^^^^^^^^^^^^^^^^^^
  File "/home/velochy/miniconda3/envs/salk/lib/python3.12/site-packages/arviz/data/base.py", line 305, in numpy_to_data_array
    coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
                                                            ~~~~~~^^^^^
KeyError: 'education_offset_language_zerosum___dim_1'

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions