Skip to content

Linear algebra rewrites: diag sum rewrite #2022

Open
Jasjeet-Singh-S wants to merge 1 commit intopymc-devs:v3from
Jasjeet-Singh-S:diag-sum-rewrite
Open

Linear algebra rewrites: diag sum rewrite #2022
Jasjeet-Singh-S wants to merge 1 commit intopymc-devs:v3from
Jasjeet-Singh-S:diag-sum-rewrite

Conversation

@Jasjeet-Singh-S
Copy link
Copy Markdown

Summary

This PR adds a new linalg rewrite that simplifies addition of diagonal matrices by applying:

diag(A + B) = diag(A) + diag(B)

in reverse, so we avoid dense matrix addition when all inputs are provably diagonal.

Related to issue #573

What changed

  • Added _extract_diagonal in pytensor/tensor/rewriting/linalg.py.

    • Detects diagonal structure for:
      • AllocDiag with zero offset (pt.diag(v)-style)
      • eye * x patterns (including broadcasted/batched forms)
    • Returns a compact diagonal representation (scalar/vector/batched vector), or None if diagonal structure is not guaranteed.
  • Added rewrite_add_diag_to_diag_add rewrite in pytensor/tensor/rewriting/linalg.py.

    • Triggered on add nodes.
    • Rewrites only when every input is provably diagonal.
    • Sums extracted diagonals first, then reconstructs one diagonal matrix:
      • scalar case: eye * scalar
      • non-scalar case: alloc_diag(summed_diag)
    • Preserves stack traces and output dtype.

Why this helps

For diagonal matrices, dense matrix addition does unnecessary work on known zeros. This rewrite reduces the operation to diagonal-value addition plus a single diagonal reconstruction, which is cheaper and keeps graphs simpler.

Tests

Added/updated tests in tests/tensor/rewriting/test_linalg.py:

  • test_add_diag_rewrite_from_diag_inputs
  • test_add_diag_rewrite_for_eye_mul_cases, with cases:
    • scalar_eye_mul
    • batched_scalar_eye_mul
    • batched_vector_eye_mul

These validate:

  • rewrite activation for diagonal-add patterns
  • no dense matrix add remains in the rewritten graph
  • numerical equivalence with expected outputs
  • batched behavior

Copilot AI review requested due to automatic review settings March 30, 2026 12:27
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces a new linear algebra graph rewrite that detects when all inputs to an elementwise add are provably diagonal matrices and rewrites the expression to add only their diagonals, then reconstruct a single diagonal matrix—avoiding dense matrix addition work.

Changes:

  • Added _extract_diagonal helper to recognize diagonal structure from AllocDiag(k=0) and Eye * x patterns (including broadcasted/batched forms).
  • Added rewrite_add_diag_to_diag_add canonicalize/stabilize rewrite for add nodes to sum extracted diagonals and reconstruct a diagonal result.
  • Added tests asserting the rewrite removes dense matrix Elemwise{Add} on matrix outputs and preserves numerical equivalence across several cases.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
pytensor/tensor/rewriting/linalg.py Adds diagonal-extraction logic and a new add rewrite that rebuilds a diagonal matrix from summed diagonal entries.
tests/tensor/rewriting/test_linalg.py Adds tests verifying the rewrite triggers and eliminates dense matrix addition for diag + diag and eye * x patterns.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +1178 to +1181
# For batched scalar * eye, return batched diagonal entries (B, N),
# not batch scalars (B), so downstream alloc_diag reconstructs (B, N, N).
return scalar_input[..., None] * pt.ones(
(eye_input.shape[-1],), dtype=scalar_input.dtype
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For eye * x patterns, _extract_diagonal builds a length-eye_input.shape[-1] diagonal vector for the batched-scalar case. If eye_input is rectangular (n != m), the true diagonal length is min(n, m), and reconstructing via alloc_diag later will produce the wrong shape/values. Consider either restricting this rewrite to square Eye inputs (or square old_out) or changing the representation/reconstruction to preserve rectangular identity shapes.

Suggested change
# For batched scalar * eye, return batched diagonal entries (B, N),
# not batch scalars (B), so downstream alloc_diag reconstructs (B, N, N).
return scalar_input[..., None] * pt.ones(
(eye_input.shape[-1],), dtype=scalar_input.dtype
# For batched scalar * eye, return batched diagonal entries (B, K),
# where K = min(n, m), not batch scalars (B), so downstream alloc_diag
# reconstructs the correct diagonal shape.
diag_len = pt.minimum(eye_input.shape[-2], eye_input.shape[-1])
return scalar_input[..., None] * pt.ones(
(diag_len,), dtype=scalar_input.dtype

Copilot uses AI. Check for mistakes.
pt.eye(old_out.shape[-2], old_out.shape[-1], dtype=old_out.dtype)
* summed_diag
)
else:
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

alloc_diag(summed_diag, axis1=-2, axis2=-1) always constructs a square matrix of size summed_diag.shape[-1]. If this rewrite triggers for diagonal matrices coming from Eye(n, m) * ... with n != m (especially the batched cases where summed_diag.ndim > 0), the replacement will silently change the output shape from (…, n, m) to (…, m, m). Add a guard to only apply this rewrite when the output is guaranteed square (or implement a rectangular-diagonal reconstruction path instead of alloc_diag).

Suggested change
else:
else:
# alloc_diag always constructs a square matrix based on the diagonal length.
# Only apply this rewrite when the original output is guaranteed square,
# to avoid silently changing shapes for rectangular diagonal matrices.
out_type_shape = getattr(old_out.type, "shape", None)
if (
out_type_shape is None
or out_type_shape[-2] is None
or out_type_shape[-1] is None
or out_type_shape[-2] != out_type_shape[-1]
):
# We cannot prove the output is square; skip this optimization.
return None

Copilot uses AI. Check for mistakes.
Comment on lines +1207 to +1211
z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :]

x_test = rng.normal(size=(4, 5)).astype(config.floatX)
y_test = rng.normal(size=(4, 5)).astype(config.floatX)
expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :])
Copy link

Copilot AI Mar 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parametrized eye * x tests only cover square pt.eye(5) inputs. Given the new rewrite matches Eye patterns, it would be good to add a non-square pt.eye(n, m) case (e.g. n!=m) to ensure the rewrite either does not apply or preserves the correct (n, m) output shape—this would catch the rectangular-shape failure mode introduced by the alloc_diag reconstruction path.

Suggested change
z = pt.eye(5) * x[:, None, :] + pt.eye(5) * y[:, None, :]
x_test = rng.normal(size=(4, 5)).astype(config.floatX)
y_test = rng.normal(size=(4, 5)).astype(config.floatX)
expected = np.eye(5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :])
z = pt.eye(3, 5) * x[:, None, :] + pt.eye(3, 5) * y[:, None, :]
x_test = rng.normal(size=(4, 5)).astype(config.floatX)
y_test = rng.normal(size=(4, 5)).astype(config.floatX)
expected = np.eye(3, 5)[None, :, :] * (x_test[:, None, :] + y_test[:, None, :])

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants