Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):

MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(include=["fast_run"]),
RewriteDatabaseQuery(include=["fast_run", "mlx"]),
)

FAST_COMPILE = Mode(
Expand Down
1 change: 1 addition & 0 deletions pytensor/link/mlx/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.tensor_basic
import pytensor.link.mlx.dispatch.einsum
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
Expand Down
24 changes: 24 additions & 0 deletions pytensor/link/mlx/dispatch/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.einsum import AbstractEinsum, Einsum


@mlx_funcify.register(Einsum)
def mlx_funcify_Einsum(op, **kwargs):
    """Dispatch a PyTensor ``Einsum`` Op to MLX's native ``mx.einsum``.

    The subscript string is captured once from the Op; the returned closure
    forwards its operands straight to MLX.
    """
    equation = op.subscripts

    def einsum(*operands):
        return mx.einsum(equation, *operands)

    return einsum


@mlx_funcify.register(AbstractEinsum)
def mlx_funcify_AbstractEinsum(op, **kwargs):
    """Dispatch an ``AbstractEinsum`` Op to MLX's native ``mx.einsum``.

    Like :func:`mlx_funcify_Einsum`, this simply hands the stored subscript
    string and the runtime operands to MLX.
    """
    equation = op.subscripts

    def einsum(*operands):
        return mx.einsum(equation, *operands)

    return einsum
36 changes: 36 additions & 0 deletions pytensor/tensor/einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple

from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.scalar.basic import upcast
from pytensor.tensor import TensorLike
from pytensor.tensor.basic import (
arange,
Expand All @@ -28,13 +31,46 @@
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import and_, eq, tensordot
from pytensor.tensor.shape import shape_padright
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorVariable


PATH = tuple[tuple[int] | tuple[int, int], ...]
CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str]


class AbstractEinsum(Op):
    """Thin einsum Op that holds only the subscript string.

    Unlike :class:`Einsum` (an ``OpFromGraph``), this Op has no inner graph.
    Backends that natively support einsum (e.g. MLX, JAX) can dispatch it
    directly to their own ``einsum`` implementation, avoiding decomposition
    into lower-level ops that may not be supported.

    ``perform`` falls back to :func:`numpy.einsum` so the Op is always
    executable on the Python backend.
    """

    __props__ = ("subscripts", "out_ndim")

    def __init__(self, subscripts: str, out_ndim: int):
        # The einsum equation (e.g. "ij,jk->ik") and the rank of the result.
        self.subscripts = subscripts
        self.out_ndim = out_ndim
        super().__init__()

    def make_node(self, *operands):
        # Coerce every operand to a tensor variable and upcast their dtypes
        # to a common output dtype. Static output shape is unknown here, so
        # every output dimension is left as None.
        tensors = [as_tensor(operand) for operand in operands]
        out_dtype = upcast(*(t.dtype for t in tensors))
        out_var = TensorType(dtype=out_dtype, shape=(None,) * self.out_ndim)()
        return Apply(self, tensors, [out_var])

    def perform(self, node, inputs, output_storage):
        # Pure-Python fallback: evaluate with NumPy's einsum.
        output_storage[0][0] = np.einsum(self.subscripts, *inputs)

    def __str__(self):
        return "AbstractEinsum{" + self.subscripts + "}"


class Einsum(OpFromGraph):
"""
Wrapper Op for Einsum graphs
Expand Down
29 changes: 27 additions & 2 deletions pytensor/tensor/rewriting/einsum.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import cast

from pytensor.compile import optdb
from pytensor.graph import Apply, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.tensor.einsum import Einsum, einsum
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.tensor.einsum import AbstractEinsum, Einsum, einsum
from pytensor.tensor.rewriting.basic import register_specialize
from pytensor.tensor.rewriting.ofg import inline_ofg_node
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -51,3 +52,27 @@ def inline_optimized_einsum(
return None

return cast(list[TensorVariable], inline_ofg_node(node))


@node_rewriter([Einsum])
def einsum_to_abstract(
    fgraph: FunctionGraph, node: Apply
) -> list[TensorVariable] | None:
    """Replace ``Einsum`` with ``AbstractEinsum``.

    Backends that natively support einsum can dispatch ``AbstractEinsum`` to its native implementation,
    rather than using the OpFromGraph defined by Pytensor.
    """
    subscripts = node.op.subscripts
    ndim = node.outputs[0].ndim
    abstract_op = AbstractEinsum(subscripts=subscripts, out_ndim=ndim)
    replacement = abstract_op(*node.inputs)
    # Preserve debugging provenance from the original output.
    copy_stack_trace(node.outputs[0], replacement)
    return [replacement]


# Register the Einsum -> AbstractEinsum substitution under the "mlx" tag so it
# only runs for modes whose rewrite query includes "mlx" (see the MLX Mode).
optdb.register(
    "einsum_to_abstract",
    dfs_rewriter(einsum_to_abstract),
    "mlx",
    position=1.9,  # Before specialize (2.0) which inlines the Einsum OFG
)
54 changes: 54 additions & 0 deletions tests/link/mlx/test_einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import pytest

import pytensor.tensor as pt
from tests.link.mlx.test_basic import compare_mlx_and_py


mx = pytest.importorskip("mlx.core")


def test_mlx_einsum():
    """Chained two-operand contractions compile and match the Python backend."""
    subscripts = "ij, jk, kl -> il"
    values = [np.random.rand(3, 5), np.random.rand(5, 2), np.random.rand(2, 4)]

    symbols = [
        pt.tensor(name, shape=val.shape)
        for name, val in zip(["x", "y", "z"], values)
    ]
    result = pt.einsum(subscripts, *symbols)
    compare_mlx_and_py(symbols, [result], values)


def test_ellipsis_einsum():
    """Ellipsis (batched) subscripts are handled by the MLX backend."""
    subscripts = "...i,...i->..."
    x_val = np.random.rand(2, 5)
    y_val = np.random.rand(2, 5)

    x_sym = pt.tensor("x", shape=x_val.shape)
    y_sym = pt.tensor("y", shape=y_val.shape)
    result = pt.einsum(subscripts, x_sym, y_sym)
    compare_mlx_and_py([x_sym, y_sym], [result], [x_val, y_val])


def test_einsum_trace():
    """A trace contraction ("ii->") matches between MLX and Python backends."""
    mat = pt.matrix("x")
    trace = pt.einsum("ii->", mat)
    compare_mlx_and_py([mat], [trace], [np.random.rand(5, 5)])


def test_einsum_batched_outer_product():
    """Batched outer product ("bi,bj->bij") matches the Python backend."""
    lhs = pt.matrix("a", dtype="float32")
    rhs = pt.matrix("b", dtype="float32")
    batched_outer = pt.einsum("bi,bj->bij", lhs, rhs)

    lhs_val = np.random.normal(size=(5, 3)).astype("float32")
    rhs_val = np.random.normal(size=(5, 2)).astype("float32")

    compare_mlx_and_py([lhs, rhs], [batched_outer], [lhs_val, rhs_val])
Loading