From e1420da5eab094b38e25c448291023b1c9dcc7ce Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 15:40:34 -0500 Subject: [PATCH 1/2] Add dispatches for `Real` Op --- pytensor/link/jax/dispatch/scalar.py | 9 ++++++++ pytensor/link/mlx/dispatch/scalar.py | 9 ++++++++ pytensor/link/numba/dispatch/scalar.py | 10 +++++++++ pytensor/link/pytorch/dispatch/scalar.py | 9 ++++++++ tests/link/jax/test_scalar.py | 11 ++++++++++ tests/link/mlx/test_scalar.py | 11 ++++++++++ tests/link/numba/test_scalar.py | 11 ++++++++++ tests/link/pytorch/test_elemwise.py | 28 ++++++++++++++++++++++++ tests/tensor/test_math.py | 19 ++++++++++++++++ 9 files changed, 117 insertions(+) diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 44764764f2..2d647ef429 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -16,6 +16,7 @@ IntDiv, Mod, Mul, + Real, ScalarOp, Second, Sub, @@ -330,3 +331,11 @@ def softplus(x): ) return softplus + + +@jax_funcify.register(Real) +def jax_funcify_Real(op, node, **kwargs): + def real(x): + return jnp.real(x) + + return real diff --git a/pytensor/link/mlx/dispatch/scalar.py b/pytensor/link/mlx/dispatch/scalar.py index 83fd858546..638b4b33f1 100644 --- a/pytensor/link/mlx/dispatch/scalar.py +++ b/pytensor/link/mlx/dispatch/scalar.py @@ -6,6 +6,7 @@ Composite, Identity, Mod, + Real, ScalarOp, Second, ) @@ -178,3 +179,11 @@ def log1mexp(x): @mlx_funcify.register(Composite) def mlx_funcify_Composite(op, node=None, **kwargs): return mlx_funcify(op.fgraph, squeeze_output=True) + + +@mlx_funcify.register(Real) +def mlx_funcify_Real(op, node, **kwargs): + def real(x): + return mx.real(x) + + return real diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 777b4d5a6c..62cdc76218 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -24,6 +24,7 @@ Identity, Mul, Pow, + Real, Reciprocal, ScalarOp, Second, @@ -412,3 +413,12 @@ def for_loop(n_steps, *inputs): return carry return for_loop, loop_cache_key + + +@register_funcify_and_cache_key(Real) +def numba_funcify_Real(op, node, **kwargs): + @numba_basic.numba_njit + def real(x): + return np.real(x) + + return real, scalar_op_cache_key(op) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index a64dcf23ba..c7626c632f 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -7,6 +7,7 @@ Cast, Clip, Invert, + Real, ScalarOp, ) from pytensor.scalar.loop import ScalarLoop @@ -112,3 +113,11 @@ def scalar_loop(steps, *start_and_constants): return carry return scalar_loop + + +@pytorch_funcify.register(Real) +def pytorch_funcify_Real(op, node, **kwargs): + def real(x): + return torch.real(x) + + return real diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py index ebbf5e7bf4..626d6c8a6d 100644 --- a/tests/link/jax/test_scalar.py +++ b/tests/link/jax/test_scalar.py @@ -3,6 +3,7 @@ import pytensor.scalar.basic as ps import pytensor.tensor as pt +from pytensor.compile.io import In from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import Composite @@ -324,3 +325,13 @@ def test_jax_logp(): value_test_value, ], ) + + +def test_jax_real(): + x = pt.zvector("x") + out = pt.real(x)[0].set(99.0) + x_val = np.array([1 + 2j, 3 + 4j]) + _, output = compare_jax_and_py([In(x, mutable=True)], [out], [x_val]) + + # Verify that the real Op does not return a view, resulting in mutation of the input + assert output[0].item(0) != x_val.real.item(0) diff --git a/tests/link/mlx/test_scalar.py b/tests/link/mlx/test_scalar.py index 56a940251d..aaada88dfc 100644 --- a/tests/link/mlx/test_scalar.py +++ b/tests/link/mlx/test_scalar.py @@ -3,6 +3,7 @@ import pytensor.scalar.basic as ps import pytensor.tensor as pt +from pytensor.compile.io import In from pytensor.configdefaults import config from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import Composite @@ -238,3 +239,13 @@ def test_mlx_logp(): value_test_value, ], ) + + +def test_mlx_real(): + x = pt.tensor("x", dtype="complex64", shape=(None,)) + out = pt.real(x)[0].set(np.float32(99.0)) + x_val = np.array([1 + 2j, 3 + 4j], dtype="complex64") + _, output = compare_mlx_and_py([In(x, mutable=True)], [out], [x_val]) + + # Verify that the real Op does not return a view, resulting in mutation of the input + assert output[0][0].item() != x_val.real[0].item() diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 040405cb51..9bcf446530 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -7,6 +7,7 @@ import pytensor.scalar.math as psm import pytensor.tensor as pt from pytensor import config, function +from pytensor.compile.io import In from pytensor.graph import Apply from pytensor.scalar import ScalarLoop, UnaryScalarOp from pytensor.scalar.basic import Composite @@ -312,3 +313,13 @@ def test_loop_with_cython_wrapped_op(self): res = fn(x_test) expected_res = ps.psi(x).eval({x: x_test}) np.testing.assert_allclose(res, expected_res) + + +def test_numba_real(): + x = pt.zvector("x") + out = pt.real(x)[0].set(99.0) + x_val = np.array([1 + 2j, 3 + 4j]) + _, output = compare_numba_and_py([In(x, mutable=True)], [out], [x_val]) + + # Verify that the real Op does not return a view, resulting in mutation of the input + assert output[0].item(0) != x_val.real.item(0) diff --git a/tests/link/pytorch/test_elemwise.py b/tests/link/pytorch/test_elemwise.py index 1db6c67e35..76853e9391 100644 --- a/tests/link/pytorch/test_elemwise.py +++ b/tests/link/pytorch/test_elemwise.py @@ -192,3 +192,31 @@ def relu(row): vals = torch.zeros(2, 3).normal_() np.testing.assert_allclose(f(vals), torch.relu(vals)) assert op.call_shapes == [torch.Size([])], op.call_shapes + + +def test_Real(): + x = pt.zvector("x") + out = pt.real(x) + compare_pytorch_and_py([x], [out], [np.array([1 + 2j, 3 + 4j])]) + + +def test_Real_no_input_mutation(): + """Regression test: real() must not return a view that lets inplace ops corrupt the input. + + numpy.real returns a view, so if the backend dispatch also returns a view + and the inplace optimization fires on a downstream SetSubtensor, the + original mutable input can be silently corrupted. + """ + from pytensor import In, function + from tests.link.pytorch.test_basic import pytorch_mode + + x = pt.zvector("x") + out = pt.real(x)[0].set(99.0) + f = function([In(x, mutable=True)], out, mode=pytorch_mode) + + test_input = np.array([1 + 2j, 3 + 4j]) + original = test_input.copy() + result = f(test_input) + + np.testing.assert_allclose(result, [99.0, 3.0]) + np.testing.assert_array_equal(test_input, original) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 77e4a72723..9ab2119b39 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -2897,6 +2897,25 @@ def test_real_imag(self): assert_array_equal(Z.real.eval({Z: z}), x) assert_array_equal(Z.imag.eval({Z: z}), y) + def test_real_no_input_mutation(self): + """Regression test: real() must not return a view that lets inplace ops corrupt the input. + + numpy.real returns a view, so Elemwise.perform copies non-owned data. + This test ensures that pattern stays safe with mutable inputs. + """ + from pytensor import In, function + + Z = zvector("z") + out = Z.real[0].set(99.0) + f = function([In(Z, mutable=True)], out) + + z = np.array([1 + 2j, 3 + 4j]) + original = z.copy() + result = f(z) + + np.testing.assert_allclose(result, [99.0, 3.0]) + np.testing.assert_array_equal(z, original) + def test_conj(self): X, Y = self.vars x, y = self.vals From 06f767753be2a4191d99fc828c3b4cad728a06d0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 29 Mar 2026 18:11:14 -0500 Subject: [PATCH 2/2] compare_xxx_and_py accepts `In` objects as graph_inputs --- tests/link/jax/test_basic.py | 7 ++++++- tests/link/mlx/test_basic.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py index c80bd2a1e3..06551ba9d5 100644 --- a/tests/link/jax/test_basic.py +++ b/tests/link/jax/test_basic.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from pytensor.compile import SymbolicInput from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.mode import JAX, Mode @@ -71,7 +72,11 @@ def compare_jax_and_py( if assert_fn is None: assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) - if any(inp.owner is not None for inp in graph_inputs): + if any( + inp.owner is not None + for inp in graph_inputs + if not isinstance(inp, SymbolicInput) + ): raise ValueError("Inputs must be root variables") pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode) diff --git a/tests/link/mlx/test_basic.py b/tests/link/mlx/test_basic.py index 2e36496edd..43a2ccc9bb 100644 --- a/tests/link/mlx/test_basic.py +++ b/tests/link/mlx/test_basic.py @@ -11,6 +11,7 @@ import pytensor from pytensor import config from pytensor import tensor as pt +from pytensor.compile import SymbolicInput from pytensor.compile.function import function from pytensor.compile.mode import MLX, Mode from pytensor.graph import RewriteDatabaseQuery @@ -67,7 +68,11 @@ def compare_mlx_and_py( if assert_fn is None: assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) - if any(inp.owner is not None for inp in graph_inputs): + if any( + inp.owner is not None + for inp in graph_inputs + if not isinstance(inp, SymbolicInput) + ): raise ValueError("Inputs must be root variables") pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode)