diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py
index 2c17020cd9..52349bed68 100644
--- a/pytensor/tensor/rewriting/linalg.py
+++ b/pytensor/tensor/rewriting/linalg.py
@@ -28,6 +28,7 @@
 from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
 from pytensor.tensor.nlinalg import (
     SVD,
+    Eig,
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
@@ -1145,3 +1146,115 @@ def scalar_solve_to_division(fgraph, node):
     copy_stack_trace(old_out, new_out)
 
     return [new_out]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([blockwise_of(Eig)])
+def rewrite_eig_eye(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that for any identity matrix, all the eigenvalues are 1 and the eigenvectors are the standard basis.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    # Check whether input to Eig is Eye and the 1's are on main diagonal
+    potential_eye = node.inputs[0]
+    if not (
+        potential_eye.owner
+        and isinstance(potential_eye.owner.op, Eye)
+        and Eye.is_offset_zero(potential_eye.owner)
+    ):
+        return None
+
+    eigval_rewritten = pt.ones(potential_eye.shape[-1], dtype=node.outputs[0].dtype)
+    eigvec_rewritten = pt.eye(potential_eye.shape[-1], dtype=node.outputs[1].dtype)
+
+    return [eigval_rewritten, eigvec_rewritten]
+
+
+@register_canonicalize
+@register_stabilize
+@node_rewriter([blockwise_of(Eig)])
+def rewrite_eig_diag(fgraph, node):
+    """
+    This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the standard basis.
+
+    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
+    that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
+    make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar,
+    vector or a matrix.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        Function graph being optimized
+    node: Apply
+        Node of the function graph to be optimized
+
+    Returns
+    -------
+    list of Variable, optional
+        List of optimized variables, or None if no optimization was performed
+    """
+    inputs = node.inputs[0]
+
+    # Check for use of pt.diag first
+    if (
+        inputs.owner
+        and isinstance(inputs.owner.op, AllocDiag)
+        and AllocDiag.is_offset_zero(inputs.owner)
+    ):
+        eigval_rewritten = inputs.owner.inputs[0].astype(node.outputs[0].dtype)
+        base_eye = pt.eye(inputs.shape[-1], dtype=node.outputs[1].dtype)
+        eigvec_rewritten = pt.broadcast_to(base_eye, inputs.shape)
+        return [eigval_rewritten, eigvec_rewritten]
+
+    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
+    inputs_or_none = _find_diag_from_eye_mul(inputs)
+    if inputs_or_none is None:
+        return None
+
+    eye_input, non_eye_inputs = inputs_or_none
+
+    # Dealing with only one other input
+    if len(non_eye_inputs) != 1:
+        return None
+
+    _eye_input, non_eye_input = eye_input, non_eye_inputs[0]
+
+    n = inputs.shape[-1]
+    base_eye = pt.eye(n, dtype=node.outputs[1].dtype)
+    eigvec_rewritten = (
+        base_eye
+        if inputs.ndim == 2
+        else pt.broadcast_to(pt.shape_padleft(base_eye, inputs.ndim - 2), inputs.shape)
+    )
+
+    # Checking if original x was scalar/vector/matrix
+    if non_eye_input.type.broadcastable[-2:] == (True, True):
+        # For scalar
+        eigval_rewritten = pt.full(
+            node.outputs[0].shape,
+            non_eye_input.squeeze(axis=(-1, -2)),
+            dtype=node.outputs[0].dtype,
+        )
+    elif non_eye_input.type.broadcastable[-2:] == (False, False):
+        # For Matrix (including batched matrices)
+        eigval_rewritten = non_eye_input.diagonal(axis1=-1, axis2=-2).astype(
+            node.outputs[0].dtype
+        )
+    else:
+        # For vector
+        eigval_rewritten = non_eye_input.squeeze().astype(node.outputs[0].dtype)
+
+    return [eigval_rewritten, eigvec_rewritten]
diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py
index 9e8783e51a..9ed07196b5 100644
--- a/tests/tensor/rewriting/test_linalg.py
+++ b/tests/tensor/rewriting/test_linalg.py
@@ -19,6 +19,7 @@
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
+    Eig,
     KroneckerProduct,
     MatrixInverse,
     MatrixPinv,
@@ -1128,3 +1129,115 @@ def solve_op_in_graph(graph):
     np.testing.assert_allclose(
         f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
     )
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [(), (7,), (1, 7), (7, 1), (7, 7)],
+    ids=["scalar", "vector", "row_vec", "col_vec", "matrix"],
+)
+def test_eig_diag_from_eye_mul(shape):
+    # Initializing x based on scalar/vector/matrix
+    x = pt.tensor("x", shape=shape)
+    y = pt.eye(7) * x
+
+    # Calculating eigval and eigvec using pt.linalg.eig
+    eigval, eigvec = pt.linalg.eig(y)
+
+    # REWRITE TEST
+    f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+
+    assert not any(
+        isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig)
+        for node in nodes
+    )
+
+    # NUMERIC VALUE TEST
+    if len(shape) == 0:
+        x_test = np.array(np.random.rand()).astype(config.floatX)
+    elif len(shape) == 1:
+        x_test = np.random.rand(*shape).astype(config.floatX)
+    else:
+        x_test = np.random.rand(*shape).astype(config.floatX)
+
+    x_test_matrix = np.eye(7) * x_test
+    eigval, _ = np.linalg.eig(x_test_matrix)
+    rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)
+
+    assert_allclose(
+        eigval,
+        rewritten_eigval,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        x_test_matrix @ rewritten_eigvec,
+        rewritten_eigvec @ np.diag(rewritten_eigval),
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_eig_eye():
+    n = pt.iscalar("n")
+    x = pt.eye(n)
+    eigval, eigvec = pt.linalg.eig(x)
+
+    # REWRITE TEST
+    f_rewritten = function([n], [eigval, eigvec], mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(
+        isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig)
+        for node in nodes
+    )
+
+    # NUMERIC VALUE TEST
+    n_test = 10
+    x_test = np.eye(n_test)
+    eigval, _ = np.linalg.eig(x_test)
+    rewritten_eigval, rewritten_eigvec = f_rewritten(n_test)
+    assert_allclose(
+        eigval,
+        rewritten_eigval,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        x_test @ rewritten_eigvec,
+        rewritten_eigvec @ np.diag(rewritten_eigval),
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+
+
+def test_eig_diag():
+    x = pt.tensor("x", shape=(None,))
+    x_diag = pt.diag(x)
+    eigval, eigvec = pt.linalg.eig(x_diag)
+
+    # REWRITE TEST
+    f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
+    nodes = f_rewritten.maker.fgraph.apply_nodes
+    assert not any(
+        isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig)
+        for node in nodes
+    )
+
+    # NUMERIC VALUE TEST
+    x_test = np.random.rand(7).astype(config.floatX)
+    x_test_matrix = np.eye(7) * x_test
+    eigval, _ = np.linalg.eig(x_test_matrix)
+    rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)
+    assert_allclose(
+        eigval,
+        rewritten_eigval,
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )
+    assert_allclose(
+        x_test_matrix @ rewritten_eigvec,
+        rewritten_eigvec @ np.diag(rewritten_eigval),
+        atol=1e-3 if config.floatX == "float32" else 1e-8,
+        rtol=1e-3 if config.floatX == "float32" else 1e-8,
+    )