Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pytensor/link/jax/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
IntDiv,
Mod,
Mul,
Real,
ScalarOp,
Second,
Sub,
Expand Down Expand Up @@ -330,3 +331,11 @@ def softplus(x):
)

return softplus


@jax_funcify.register(Real)
def jax_funcify_Real(op, node, **kwargs):
    """Return a JAX implementation of the ``Real`` scalar op."""

    def real_part(x):
        # Extract the real component of a (possibly complex) value.
        return jnp.real(x)

    return real_part
9 changes: 9 additions & 0 deletions pytensor/link/mlx/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Composite,
Identity,
Mod,
Real,
ScalarOp,
Second,
)
Expand Down Expand Up @@ -178,3 +179,11 @@ def log1mexp(x):
@mlx_funcify.register(Composite)
def mlx_funcify_Composite(op, node=None, **kwargs):
    # Compile the Composite op's inner function graph for MLX.
    # NOTE(review): squeeze_output=True presumably unwraps a single-output
    # result instead of returning a 1-tuple — confirm against mlx_funcify.
    return mlx_funcify(op.fgraph, squeeze_output=True)


@mlx_funcify.register(Real)
def mlx_funcify_Real(op, node, **kwargs):
    """Return an MLX implementation of the ``Real`` scalar op."""

    def real_part(x):
        # Extract the real component of a (possibly complex) value.
        return mx.real(x)

    return real_part
10 changes: 10 additions & 0 deletions pytensor/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Identity,
Mul,
Pow,
Real,
Reciprocal,
ScalarOp,
Second,
Expand Down Expand Up @@ -412,3 +413,12 @@ def for_loop(n_steps, *inputs):
return carry

return for_loop, loop_cache_key


@register_funcify_and_cache_key(Real)
def numba_funcify_Real(op, node, **kwargs):
    """Return a Numba implementation of the ``Real`` scalar op and its cache key."""

    @numba_basic.numba_njit
    def real_part(x):
        # Scalar-level extraction of the real component.
        return np.real(x)

    return real_part, scalar_op_cache_key(op)
9 changes: 9 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Cast,
Clip,
Invert,
Real,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
Expand Down Expand Up @@ -112,3 +113,11 @@ def scalar_loop(steps, *start_and_constants):
return carry

return scalar_loop


@pytorch_funcify.register(Real)
def pytorch_funcify_Real(op, node, **kwargs):
    """Return a PyTorch implementation of the ``Real`` scalar op."""

    def real_part(x):
        # Extract the real component of a (possibly complex) tensor.
        return torch.real(x)

    return real_part
7 changes: 6 additions & 1 deletion tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest

from pytensor.compile import SymbolicInput
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import JAX, Mode
Expand Down Expand Up @@ -71,7 +72,11 @@ def compare_jax_and_py(
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

if any(inp.owner is not None for inp in graph_inputs):
if any(
inp.owner is not None
for inp in graph_inputs
if not isinstance(inp, SymbolicInput)
):
raise ValueError("Inputs must be root variables")

pytensor_jax_fn = function(graph_inputs, graph_outputs, mode=jax_mode)
Expand Down
11 changes: 11 additions & 0 deletions tests/link/jax/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytensor.scalar.basic as ps
import pytensor.tensor as pt
from pytensor.compile.io import In
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite
Expand Down Expand Up @@ -324,3 +325,13 @@ def test_jax_logp():
value_test_value,
],
)


def test_jax_real():
    """Check Real on the JAX backend and that it does not alias its input."""
    inp = pt.zvector("x")
    graph_out = pt.real(inp)[0].set(99.0)
    test_val = np.array([1 + 2j, 3 + 4j])
    _, res = compare_jax_and_py([In(inp, mutable=True)], [graph_out], [test_val])

    # If Real returned a view, the inplace set-subtensor above would have
    # written 99.0 into the mutable input buffer as well.
    assert res[0].item(0) != test_val.real.item(0)
7 changes: 6 additions & 1 deletion tests/link/mlx/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytensor
from pytensor import config
from pytensor import tensor as pt
from pytensor.compile import SymbolicInput
from pytensor.compile.function import function
from pytensor.compile.mode import MLX, Mode
from pytensor.graph import RewriteDatabaseQuery
Expand Down Expand Up @@ -67,7 +68,11 @@ def compare_mlx_and_py(
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

if any(inp.owner is not None for inp in graph_inputs):
if any(
inp.owner is not None
for inp in graph_inputs
if not isinstance(inp, SymbolicInput)
):
raise ValueError("Inputs must be root variables")

pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode)
Expand Down
11 changes: 11 additions & 0 deletions tests/link/mlx/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytensor.scalar.basic as ps
import pytensor.tensor as pt
from pytensor.compile.io import In
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite
Expand Down Expand Up @@ -238,3 +239,13 @@ def test_mlx_logp():
value_test_value,
],
)


def test_mlx_real():
    """Check Real on the MLX backend and that it does not alias its input."""
    inp = pt.tensor("x", dtype="complex64", shape=(None,))
    graph_out = pt.real(inp)[0].set(np.float32(99.0))
    test_val = np.array([1 + 2j, 3 + 4j], dtype="complex64")
    _, res = compare_mlx_and_py([In(inp, mutable=True)], [graph_out], [test_val])

    # If Real returned a view, the inplace set-subtensor above would have
    # written 99.0 into the mutable input buffer as well.
    assert res[0][0].item() != test_val.real[0].item()
11 changes: 11 additions & 0 deletions tests/link/numba/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytensor.scalar.math as psm
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.compile.io import In
from pytensor.graph import Apply
from pytensor.scalar import ScalarLoop, UnaryScalarOp
from pytensor.scalar.basic import Composite
Expand Down Expand Up @@ -312,3 +313,13 @@ def test_loop_with_cython_wrapped_op(self):
res = fn(x_test)
expected_res = ps.psi(x).eval({x: x_test})
np.testing.assert_allclose(res, expected_res)


def test_numba_real():
    """Check Real on the Numba backend and that it does not alias its input."""
    inp = pt.zvector("x")
    graph_out = pt.real(inp)[0].set(99.0)
    test_val = np.array([1 + 2j, 3 + 4j])
    _, res = compare_numba_and_py([In(inp, mutable=True)], [graph_out], [test_val])

    # If Real returned a view, the inplace set-subtensor above would have
    # written 99.0 into the mutable input buffer as well.
    assert res[0].item(0) != test_val.real.item(0)
28 changes: 28 additions & 0 deletions tests/link/pytorch/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,31 @@ def relu(row):
vals = torch.zeros(2, 3).normal_()
np.testing.assert_allclose(f(vals), torch.relu(vals))
assert op.call_shapes == [torch.Size([])], op.call_shapes


def test_Real():
    """Check that Real compiles and matches Python-mode on the PyTorch backend."""
    inp = pt.zvector("x")
    compare_pytorch_and_py([inp], [pt.real(inp)], [np.array([1 + 2j, 3 + 4j])])


def test_Real_no_input_mutation():
    """Regression test: real() must not return a view that lets inplace ops corrupt the input.

    numpy.real returns a view, so if the backend dispatch also returns a view
    and the inplace optimization fires on a downstream SetSubtensor, the
    original mutable input can be silently corrupted.
    """
    from pytensor import In, function
    from tests.link.pytorch.test_basic import pytorch_mode

    inp = pt.zvector("x")
    graph_out = pt.real(inp)[0].set(99.0)
    fn = function([In(inp, mutable=True)], graph_out, mode=pytorch_mode)

    buf = np.array([1 + 2j, 3 + 4j])
    snapshot = buf.copy()
    res = fn(buf)

    # The output reflects the inplace write, while the input stays untouched.
    np.testing.assert_allclose(res, [99.0, 3.0])
    np.testing.assert_array_equal(buf, snapshot)
19 changes: 19 additions & 0 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2897,6 +2897,25 @@ def test_real_imag(self):
assert_array_equal(Z.real.eval({Z: z}), x)
assert_array_equal(Z.imag.eval({Z: z}), y)

def test_real_no_input_mutation(self):
    """Regression test: real() must not return a view that lets inplace ops corrupt the input.

    numpy.real returns a view, so Elemwise.perform copies non-owned data.
    This test ensures that pattern stays safe with mutable inputs.
    """
    from pytensor import In, function

    zv = zvector("z")
    graph_out = zv.real[0].set(99.0)
    fn = function([In(zv, mutable=True)], graph_out)

    buf = np.array([1 + 2j, 3 + 4j])
    snapshot = buf.copy()
    res = fn(buf)

    # The output reflects the inplace write, while the input stays untouched.
    np.testing.assert_allclose(res, [99.0, 3.0])
    np.testing.assert_array_equal(buf, snapshot)

def test_conj(self):
X, Y = self.vars
x, y = self.vals
Expand Down
Loading