Skip to content

Add einsum dispatch for mlx#2017

Open
jessegrabowski wants to merge 1 commit intopymc-devs:v3from
jessegrabowski:mlx-einsum
Open

Add einsum dispatch for mlx#2017
jessegrabowski wants to merge 1 commit intopymc-devs:v3from
jessegrabowski:mlx-einsum

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

MLX can't handle some of the behaviors of our einsum OpFromGraph, so I added a rewrite to a dummy AbstractEinsum Op and dispatch on that. It's mostly a corner case (it can't do the trace pattern, "ii->", because it uses a symbolic Arange ). If my solution is too cute I can remove it and we can just error on the unsupported case. I benchmarked native mx.einsum against the OpFromGraph approach and they're pretty similar (OFG is a bit faster, but like 5%).

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Mar 30, 2026

i don't get it. you can dispatch on einsum Op directly (and you did). direct dispatches on subclasses (Einsum) take precedence over dispatches on parent class (OFG)

subscripts = op.subscripts

def einsum(*operands):
return mx.einsum(subscripts, *operands)
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

they are both going to mx einsum?

Did you mean one goes through funcify ofg and the other mx einsum? Do you need a rewrite/ separate op for that? can you anslyze the subscripts and decide here?

@jessegrabowski
Copy link
Copy Markdown
Member Author

i don't get it. you can dispatch on einsum Op directly (and you did). direct dispatches on subclasses (Einsum) take precedence over dispatches on parent class (OFG)

Oh nice, I didn't think of that.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants