
feat(mlx): pt.random support with mlx backend#1979

Open
williambdean wants to merge 5 commits into pymc-devs:v3 from williambdean:mlx-random

Conversation

@williambdean
Contributor

Description

Basic support for MLX random number generation.

import pytensor.tensor as pt

data = pt.random.normal(size=(52, 3))

# Works now
data.eval(mode="MLX")

MLX has limited distribution support; the Gamma distribution is missing. Additional distributions could be supported with basic transformations, e.g. pt.abs(pt.random.normal(...)) ~ Half-Normal.
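The transformation idea above can be sketched with plain NumPy (no MLX required): taking the absolute value of standard normal draws yields half-normal samples, whose mean is sqrt(2/pi) ≈ 0.798.

```python
import numpy as np

# |N(0, 1)| is half-normal distributed; its mean is sqrt(2/pi) ~= 0.798
rng = np.random.default_rng(0)
half_normal = np.abs(rng.normal(size=100_000))
sample_mean = float(half_normal.mean())
```

The same trick would let the MLX backend cover distributions its random module lacks by composing supported samplers with elementwise ops.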

MLX Reference: https://ml-explore.github.io/mlx/build/html/python/random.html

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@williambdean williambdean force-pushed the mlx-random branch 2 times, most recently from b620bb1 to d602268 Compare March 16, 2026 18:25
Member

@ricardoV94 ricardoV94 left a comment


Missing RNG outputs/updates (so consecutive calls get an updated RNG).

There should be tests in numba/jax you can use as a template. JAX is going to be more similar.
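The update pattern the review asks for can be sketched with a toy key scheme (illustrative names, not the actual PyTensor API): each call returns both the draw and a new key, and the caller threads the returned key into the next call so consecutive draws differ.

```python
import numpy as np

def sample_normal(key, size):
    # toy key scheme: the current key seeds this draw; the incremented key
    # is returned as the updated RNG state for the next call
    draw = np.random.default_rng(key).normal(size=size)
    return key + 1, draw

key = 0
key, a = sample_normal(key, (3,))
key, b = sample_normal(key, (3,))
# a and b differ because the updated key was threaded into the second call
```

In the real backend the returned RNG becomes an extra output of the compiled function, which the linker feeds back in as the shared variable's update.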

thunk_inputs = []
for n in self.fgraph.inputs:
    sinput = storage_map[n]
    if isinstance(sinput[0], Generator):
Member


You need to do the same dance the JAX linker does with shared Generator variables.
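The "dance" amounts to converting shared numpy Generator values into something the backend sampler understands before the thunk runs. A minimal sketch, with illustrative names (the actual conversion in the JAX linker is more involved):

```python
from numpy.random import Generator, default_rng

def prepare_shared_input(value):
    # a numpy Generator can't cross into the backend directly, so reduce
    # its state to an integer key the backend sampler can consume
    # (illustrative stand-in for the real conversion)
    if isinstance(value, Generator):
        return int(value.integers(2**63))
    return value
```

Non-Generator inputs pass through untouched, so the same hook can run over all shared inputs.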

@jessegrabowski
Member

#2010 caused conflicts for this PR. You will need to rebase.

def sample_fn(rng_key, size, dtype, p):
    p = mx.array(p)
    if size is None:
        shape = p.shape
Member

@ricardoV94 ricardoV94 Apr 1, 2026


Do you always need the shape? You didn't need it in the categorical. I would assume you only need it when one of the parameters doesn't go into the random function. If so, that would take a lot of boilerplate away from your dispatches.
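The distinction can be sketched with a Bernoulli-style sampler in plain NumPy (illustrative, not the PR's code): the parameter `p` never reaches the raw uniform sampler, so its shape must be supplied explicitly when no size is given; a sampler that forwards its parameter directly could skip this step.

```python
import numpy as np

def bernoulli_sample_fn(rng, size, p):
    # p is not forwarded to the raw sampler, so when size is None the
    # draw shape must be taken from p explicitly
    shape = np.shape(p) if size is None else size
    return (rng.uniform(size=shape) < p).astype(np.int64)
```

Computing the shape only in this case would avoid the boilerplate in dispatches where the backend already broadcasts the parameters.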

return sample_fn


@mlx_sample_fn.register(ptr.MvNormalRV)
Member

@ricardoV94 ricardoV94 Apr 1, 2026


MvNormal supports different decomposition strategies; if mx.random.multivariate_normal doesn't support them, you may want to implement it like the numba dispatch/op.perform, which is more low-level. Or, if that's unfeasible, issue a warning that the method isn't respected and will fall back to svd (if it wasn't svd to begin with).
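The warning fallback suggested above could look like this (a minimal sketch with illustrative names, assuming, per the review, that the MLX sampler only supports SVD):

```python
import warnings

def resolve_method(method):
    # warn when the requested decomposition isn't supported, rather than
    # silently ignoring it, then fall back to SVD
    if method != "svd":
        warnings.warn(
            f"MLX multivariate_normal does not respect method={method!r}; "
            "falling back to 'svd'.",
            UserWarning,
        )
    return "svd"
```

Emitting the warning once at dispatch time keeps compiled functions quiet while still surfacing the silent behavior change to the user.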

if batch_ndim:
    raise NotImplementedError(
        "MLX random.permutation does not support batch dimensions."
    )
Member


Raise at dispatch time already, so the error surfaces when the function is compiled rather than on the first call.
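The suggestion can be sketched as follows (toy dispatch, illustrative names): the check runs in the outer dispatch function, so an unsupported graph fails while being compiled, and only supported graphs ever produce a sampler.

```python
def dispatch_permutation(batch_ndim):
    # raising here, at dispatch time, surfaces the limitation during
    # compilation instead of inside the compiled sampler
    if batch_ndim:
        raise NotImplementedError(
            "MLX random.permutation does not support batch dimensions."
        )

    def sample_fn(rng_key, x):
        return x  # placeholder body; the real fn would permute x

    return sample_fn
```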



3 participants