Linear algebra rewrites: diag sum rewrite #2022
Linear algebra rewrites: diag sum rewrite #2022Jasjeet-Singh-S wants to merge 1 commit intopymc-devs:v3from
Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces a new linear algebra graph rewrite that detects when all inputs to an elementwise add are provably diagonal matrices and rewrites the expression to add only their diagonals, then reconstruct a single diagonal matrix—avoiding dense matrix addition work.
Changes:
- Added
_extract_diagonalhelper to recognize diagonal structure fromAllocDiag(k=0)andEye * xpatterns (including broadcasted/batched forms). - Added
rewrite_add_diag_to_diag_addcanonicalize/stabilize rewrite foraddnodes to sum extracted diagonals and reconstruct a diagonal result. - Added tests asserting the rewrite removes dense matrix
Elemwise{Add}on matrix outputs and preserves numerical equivalence across several cases.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
pytensor/tensor/rewriting/linalg.py |
Adds diagonal-extraction logic and a new add rewrite that rebuilds a diagonal matrix from summed diagonal entries. |
tests/tensor/rewriting/test_linalg.py |
Adds tests verifying the rewrite triggers and eliminates dense matrix addition for diag + diag and eye * x patterns. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # For batched scalar * eye, return batched diagonal entries (B, N), | ||
| # not batch scalars (B), so downstream alloc_diag reconstructs (B, N, N). | ||
| return scalar_input[..., None] * pt.ones( | ||
| (eye_input.shape[-1],), dtype=scalar_input.dtype |
There was a problem hiding this comment.
For eye * x patterns, _extract_diagonal builds a length-eye_input.shape[-1] diagonal vector for the batched-scalar case. If eye_input is rectangular (n != m), the true diagonal length is min(n, m), and reconstructing via alloc_diag later will produce the wrong shape/values. Consider either restricting this rewrite to square Eye inputs (or square old_out) or changing the representation/reconstruction to preserve rectangular identity shapes.
| # For batched scalar * eye, return batched diagonal entries (B, N), | |
| # not batch scalars (B), so downstream alloc_diag reconstructs (B, N, N). | |
| return scalar_input[..., None] * pt.ones( | |
| (eye_input.shape[-1],), dtype=scalar_input.dtype | |
| # For batched scalar * eye, return batched diagonal entries (B, K), | |
| # where K = min(n, m), not batch scalars (B), so downstream alloc_diag | |
| # reconstructs the correct diagonal shape. | |
| diag_len = pt.minimum(eye_input.shape[-2], eye_input.shape[-1]) | |
| return scalar_input[..., None] * pt.ones( | |
| (diag_len,), dtype=scalar_input.dtype |
| pt.eye(old_out.shape[-2], old_out.shape[-1], dtype=old_out.dtype) | ||
| * summed_diag | ||
| ) | ||
| else: |
There was a problem hiding this comment.
alloc_diag(summed_diag, axis1=-2, axis2=-1) always constructs a square matrix of size summed_diag.shape[-1]. If this rewrite triggers for diagonal matrices coming from Eye(n, m) * ... with n != m (especially the batched cases where summed_diag.ndim > 0), the replacement will silently change the output shape from (…, n, m) to (…, m, m). Add a guard to only apply this rewrite when the output is guaranteed square (or implement a rectangular-diagonal reconstruction path instead of alloc_diag).
| else: | |
| else: | |
| # alloc_diag always constructs a square matrix based on the diagonal length. | |
| # Only apply this rewrite when the original output is guaranteed square, | |
| # to avoid silently changing shapes for rectangular diagonal matrices. | |
| out_type_shape = getattr(old_out.type, "shape", None) | |
| if ( | |
| out_type_shape is None | |
| or out_type_shape[-2] is None | |
| or out_type_shape[-1] is None | |
| or out_type_shape[-2] != out_type_shape[-1] | |
| ): | |
| # We cannot prove the output is square; skip this optimization. | |
| return None |
| z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :] | ||
|
|
||
| x_test = rng.normal(size=(4, 5)).astype(config.floatX) | ||
| y_test = rng.normal(size=(4, 5)).astype(config.floatX) | ||
| expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :]) |
There was a problem hiding this comment.
The parametrized eye * x tests only cover square pt.eye(5) inputs. Given the new rewrite matches Eye patterns, it would be good to add a non-square pt.eye(n, m) case (e.g. n!=m) to ensure the rewrite either does not apply or preserves the correct (n, m) output shape—this would catch the rectangular-shape failure mode introduced by the alloc_diag reconstruction path.
| z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :] | |
| x_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| y_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :]) | |
| z = pt.eye(3, 5) * x[:, None, :] + pt.eye(3, 5) * y[:, None, :] | |
| x_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| y_test = rng.normal(size=(4, 5)).astype(config.floatX) | |
| expected = np.eye(3, 5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :]) |
Summary
This PR adds a new linalg rewrite that simplifies addition of diagonal matrices by applying:
diag(A + B) = diag(A) + diag(B)in reverse, so we avoid dense matrix addition when all inputs are provably diagonal.
Related to issue #573
What changed
Added
_extract_diagonalinpytensor/tensor/rewriting/linalg.py.AllocDiagwith zero offset (pt.diag(v)-style)eye * xpatterns (including broadcasted/batched forms)Noneif diagonal structure is not guaranteed.Added
rewrite_add_diag_to_diag_addrewrite inpytensor/tensor/rewriting/linalg.py.addnodes.eye * scalaralloc_diag(summed_diag)Why this helps
For diagonal matrices, dense matrix addition does unnecessary work on known zeros. This rewrite reduces the operation to diagonal-value addition plus a single diagonal reconstruction, which is cheaper and keeps graphs simpler.
Tests
Added/updated tests in
tests/tensor/rewriting/test_linalg.py:test_add_diag_rewrite_from_diag_inputstest_add_diag_rewrite_for_eye_mul_cases, with cases:scalar_eye_mulbatched_scalar_eye_mulbatched_vector_eye_mulThese validate: