diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index aa0e1b3e28..26e47be278 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -467,7 +467,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): MLX = Mode( MLXLinker(), - RewriteDatabaseQuery(include=["fast_run"]), + RewriteDatabaseQuery(include=["fast_run", "mlx"]), ) FAST_COMPILE = Mode( diff --git a/pytensor/link/mlx/dispatch/__init__.py b/pytensor/link/mlx/dispatch/__init__.py index ac59f1809c..8b6447a1fd 100644 --- a/pytensor/link/mlx/dispatch/__init__.py +++ b/pytensor/link/mlx/dispatch/__init__.py @@ -9,6 +9,7 @@ import pytensor.link.mlx.dispatch.shape import pytensor.link.mlx.dispatch.subtensor import pytensor.link.mlx.dispatch.tensor_basic +import pytensor.link.mlx.dispatch.einsum import pytensor.link.mlx.dispatch.signal import pytensor.link.mlx.dispatch.signal.conv import pytensor.link.mlx.dispatch.blockwise diff --git a/pytensor/link/mlx/dispatch/einsum.py b/pytensor/link/mlx/dispatch/einsum.py new file mode 100644 index 0000000000..7264455689 --- /dev/null +++ b/pytensor/link/mlx/dispatch/einsum.py @@ -0,0 +1,24 @@ +import mlx.core as mx + +from pytensor.link.mlx.dispatch import mlx_funcify +from pytensor.tensor.einsum import AbstractEinsum, Einsum + + +@mlx_funcify.register(Einsum) +def mlx_funcify_Einsum(op, **kwargs): + subscripts = op.subscripts + + def einsum(*operands): + return mx.einsum(subscripts, *operands) + + return einsum + + +@mlx_funcify.register(AbstractEinsum) +def mlx_funcify_AbstractEinsum(op, **kwargs): + subscripts = op.subscripts + + def einsum(*operands): + return mx.einsum(subscripts, *operands) + + return einsum diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index a6d5a358f1..c4b1aecc2d 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -13,6 +13,9 @@ from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple from pytensor.compile.builders import OpFromGraph +from pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.scalar.basic import upcast from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( arange, @@ -28,6 +31,7 @@ from pytensor.tensor.functional import vectorize from pytensor.tensor.math import and_, eq, tensordot from pytensor.tensor.shape import shape_padright +from pytensor.tensor.type import TensorType from pytensor.tensor.variable import TensorVariable @@ -35,6 +39,38 @@ CONTRACTION_STEP = tuple[tuple[int, ...], set[str], str] +class AbstractEinsum(Op): + """Thin einsum Op that holds only the subscript string. + + Unlike :class:`Einsum` (an ``OpFromGraph``), this Op has no inner graph. + Backends that natively support einsum (e.g. MLX, JAX) can dispatch it + directly to their own ``einsum`` implementation, avoiding decomposition + into lower-level ops that may not be supported. + + ``perform`` falls back to :func:`numpy.einsum` so the Op is always + executable on the Python backend. + """ + + __props__ = ("subscripts", "out_ndim") + + def __init__(self, subscripts: str, out_ndim: int): + self.subscripts = subscripts + self.out_ndim = out_ndim + super().__init__() + + def make_node(self, *operands): + operands = [as_tensor(op) for op in operands] + dtype = upcast(*[op.dtype for op in operands]) + out_type = TensorType(dtype=dtype, shape=(None,) * self.out_ndim) + return Apply(self, list(operands), [out_type()]) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.einsum(self.subscripts, *inputs) + + def __str__(self): + return f"AbstractEinsum{{{self.subscripts}}}" + + class Einsum(OpFromGraph): """ Wrapper Op for Einsum graphs diff --git a/pytensor/tensor/rewriting/einsum.py b/pytensor/tensor/rewriting/einsum.py index 5e9fe2d026..c2c1fc95f0 100644 --- a/pytensor/tensor/rewriting/einsum.py +++ b/pytensor/tensor/rewriting/einsum.py @@ -1,8 +1,9 @@ from typing import cast +from pytensor.compile import optdb from pytensor.graph import Apply, FunctionGraph, node_rewriter -from pytensor.graph.rewriting.basic import copy_stack_trace -from pytensor.tensor.einsum import Einsum, einsum +from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter +from pytensor.tensor.einsum import AbstractEinsum, Einsum, einsum from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.ofg import inline_ofg_node from pytensor.tensor.variable import TensorVariable @@ -51,3 +52,27 @@ def inline_optimized_einsum( return None return cast(list[TensorVariable], inline_ofg_node(node)) + + +@node_rewriter([Einsum]) +def einsum_to_abstract( + fgraph: FunctionGraph, node: Apply +) -> list[TensorVariable] | None: + """Replace ``Einsum`` with ``AbstractEinsum``. + + Backends that natively support einsum can dispatch ``AbstractEinsum`` to its native implementation, + rather than using the OpFromGraph defined by Pytensor. + """ + op: Einsum = node.op + out_ndim = node.outputs[0].ndim + new_out = AbstractEinsum(subscripts=op.subscripts, out_ndim=out_ndim)(*node.inputs) + copy_stack_trace(node.outputs[0], new_out) + return [new_out] + + +optdb.register( + "einsum_to_abstract", + dfs_rewriter(einsum_to_abstract), + "mlx", + position=1.9, # Before specialize (2.0) which inlines the Einsum OFG +) diff --git a/tests/link/mlx/test_einsum.py b/tests/link/mlx/test_einsum.py new file mode 100644 index 0000000000..712045d492 --- /dev/null +++ b/tests/link/mlx/test_einsum.py @@ -0,0 +1,54 @@ +import numpy as np +import pytest + +import pytensor.tensor as pt +from tests.link.mlx.test_basic import compare_mlx_and_py + + +mx = pytest.importorskip("mlx.core") + + +def test_mlx_einsum(): + subscripts = "ij, jk, kl -> il" + x = np.random.rand(3, 5) + y = np.random.rand(5, 2) + z = np.random.rand(2, 4) + + shapes = { + "x": (3, 5), + "y": (5, 2), + "z": (2, 4), + } + x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) + out = pt.einsum(subscripts, x_pt, y_pt, z_pt) + compare_mlx_and_py([x_pt, y_pt, z_pt], [out], [x, y, z]) + + +def test_ellipsis_einsum(): + subscripts = "...i,...i->..." + x = np.random.rand(2, 5) + y = np.random.rand(2, 5) + + x_pt = pt.tensor("x", shape=x.shape) + y_pt = pt.tensor("y", shape=y.shape) + out = pt.einsum(subscripts, x_pt, y_pt) + compare_mlx_and_py([x_pt, y_pt], [out], [x, y]) + + +def test_einsum_trace(): + subscripts = "ii->" + x_pt = pt.matrix("x") + x_val = np.random.rand(5, 5) + out = pt.einsum(subscripts, x_pt) + compare_mlx_and_py([x_pt], [out], [x_val]) + + +def test_einsum_batched_outer_product(): + a = pt.matrix("a", dtype="float32") + b = pt.matrix("b", dtype="float32") + out = pt.einsum("bi,bj->bij", a, b) + + a_val = np.random.normal(size=(5, 3)).astype("float32") + b_val = np.random.normal(size=(5, 2)).astype("float32") + + compare_mlx_and_py([a, b], [out], [a_val, b_val])