feat(mlx): pt.random support with mlx backend#1979
feat(mlx): pt.random support with mlx backend#1979williambdean wants to merge 5 commits intopymc-devs:v3from
Conversation
b620bb1 to
d602268
Compare
ricardoV94
left a comment
There was a problem hiding this comment.
missing rng outputs /updates (so consecutive calls get updated rng)
There should be tests in numba/jax you can use as template. Jax is going to be more similar
| thunk_inputs = [] | ||
| for n in self.fgraph.inputs: | ||
| sinput = storage_map[n] | ||
| if isinstance(sinput[0], Generator): |
There was a problem hiding this comment.
you need to do the same dance jax linker does with shared Generator variables
|
#2010 caused conflicts for this PR. You will need to rebase. |
e6f7371 to
0b4fb85
Compare
| def sample_fn(rng_key, size, dtype, p): | ||
| p = mx.array(p) | ||
| if size is None: | ||
| shape = p.shape |
There was a problem hiding this comment.
you always need the shape? You didn't need it in the categorical. I would assume you only need when one of the parameters doesn't go in the random function. If so that would take a lot of boilerplate away from your dispatches
| return sample_fn | ||
|
|
||
|
|
||
| @mlx_sample_fn.register(ptr.MvNormalRV) |
There was a problem hiding this comment.
MvNormal supports different decomposition strategies, you may want to implement like numba dispatch/op.perform which is more low level if mx.random.multivariate_normal doesn't support them. Or if it's unfeasible issue a warning that it isn't respected and will fallback to svd (if it wasn't svd to begin with)
| if batch_ndim: | ||
| raise NotImplementedError( | ||
| "MLX random.permutation does not support batch dimensions." | ||
| ) |
There was a problem hiding this comment.
raise at dispatch time already
Description
Basic support for
mlxrandom generation.They have limited support. Missing Gamma distribution. Could support additional ones
with basic transformations. i.e.
pt.abs(pt.random.normal(...))~ Half NormalMLX Reference: https://ml-explore.github.io/mlx/build/html/python/random.html
Related Issue
Checklist
Type of change