Numba: fuse AdvancedSubtensor->Elemwise->AdvancedIncSubtensor #2015
Draft: ricardoV94 wants to merge 3 commits into pymc-devs:v3
Conversation
Force-pushed from 6d875d8 to 0ad6e2e
Fuse single-client AdvancedSubtensor1 nodes into Elemwise loops,
replacing indirect array reads with a single iteration loop that
uses index arrays for input access.
Before (2 nodes):
    temp = x[idx]      # AdvancedSubtensor1, shape (919,)
    result = temp + y  # Elemwise
After (1 fused loop, x is read directly via idx):
    for k in range(919):
        result[k] = x[idx[k]] + y[k]
- Introduce IndexedElemwise Op (in rewriting/indexed_elemwise.py)
- Add FuseIndexedElemwise rewrite with SequenceDB
- Merge _vectorized intrinsics into one with NO_SIZE/NO_INDEXED sentinels
- Fix Numba missing getitem(0d_array, Ellipsis)
- Index arrays participate in iter_shape with correct static bc
- zext for unsigned index types
- Add op_debug_information for dprint(print_op_info=True)
- Add correctness tests and benchmarks
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
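The commit-1 read fusion can be sketched in plain Python (a hypothetical standalone helper, not the PR's generated code; the Numba backend would JIT-compile a loop of this shape):

```python
def fused_indexed_add(x, idx, y):
    """One loop replaces the gather temp = x[idx] followed by the
    Elemwise temp + y: x is read indirectly through idx, so the
    intermediate temp array is never materialized."""
    out = [0.0] * len(idx)
    for k in range(len(idx)):
        out[k] = x[idx[k]] + y[k]
    return out

fused_indexed_add([0.0, 1.0, 2.0, 3.0, 4.0], [0, 2, 2, 4], [1.0, 1.0, 1.0, 1.0])
# [1.0, 3.0, 3.0, 5.0]
```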
Extend the IndexedElemwise fusion to also absorb
AdvancedIncSubtensor1 (indexed set/inc) on the output side.
Before (3 nodes):
    temp = Elemwise(x[idx], y)                # shape (919,)
    result = IncSubtensor(target, temp, idx)  # target shape (85,)
After (1 fused loop, target is an input):
    for k in range(919):
        target[idx[k]] += scalar_fn(x[idx[k]], y[k])
- FuseIndexedElemwise now detects AdvancedIncSubtensor1 consumers
- Reject fusion when val broadcasts against target's non-indexed axes
- store_core_outputs supports inc mode via o[...] += val
- Inner fgraph always uses inplace IncSubtensor
- op_debug_information shows buf_N / idx_N linkage
- Add indexed-update tests, broadcast guard test, and benchmarks
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
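The commit-2 indexed update can be sketched the same way (hypothetical helper; scalar_fn stands for the Elemwise's scalar function and is taken to be addition here). Duplicate idx entries accumulate into the same target slot, which is exactly the scatter-add the gradient needs:

```python
def fused_indexed_inc(target, x, idx, y):
    """Elemwise and AdvancedIncSubtensor1 share one loop: the
    intermediate Elemwise output is never materialized and target
    is updated in place."""
    for k in range(len(idx)):
        # scalar_fn(x[idx[k]], y[k]) with scalar_fn = add
        target[idx[k]] += x[idx[k]] + y[k]
    return target

target = [0.0, 0.0, 0.0]
fused_indexed_inc(target, [1.0, 2.0, 3.0], [0, 2, 2], [10.0, 20.0, 30.0])
# target is now [11.0, 0.0, 56.0]
```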
Support AdvancedSubtensor on any axis (not just axis 0) and multi-index patterns like x[idx_row, idx_col] where multiple 1D index arrays address consecutive source axes.
Arbitrary axis: x[:, idx] + y → fused loop with indirect indexing on axis 1
Multi-index: x[idx0, idx1] + y → out[i, j] = x[idx0[i], idx1[i], j] + y[i, j]
- Add undo_take_dimshuffle_for_fusion pre-fusion rewrite
- Generalize indexed_inputs encoding: ((positions, axis, idx_bc), ...)
- input_read_spec uses tuple of (idx_k, axis) pairs per input
- source_input_types for array struct access, input_types (effective) for core_ndim / _compute_vectorized_types
- n_index_loop_dims = max(idx.ndim for group) for future ND support
- Index arrays participate in iter_shape with correct per-index static bc
- Reject fusion when val broadcasts against target's non-indexed axes
- Add correctness, broadcast, and shape validation tests
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
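The commit-3 multi-index case can be sketched as follows (hypothetical helper using nested lists in place of arrays): two 1D index arrays address consecutive axes of x and collapse into the single index loop dimension i, while the trailing axis j stays elementwise:

```python
def fused_multi_index_add(x, idx0, idx1, y):
    """Computes out[i, j] = x[idx0[i], idx1[i], j] + y[i, j] in one
    fused loop, never materializing the gathered x[idx0, idx1]."""
    n_i, n_j = len(idx0), len(y[0])
    out = [[0.0] * n_j for _ in range(n_i)]
    for i in range(n_i):
        for j in range(n_j):
            out[i][j] = x[idx0[i]][idx1[i]][j] + y[i][j]
    return out

x = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]  # shape (2, 2, 2)
fused_multi_index_add(x, [1, 0], [0, 1], [[10, 10], [10, 10]])
# [[14, 15], [12, 13]]
```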
Force-pushed from 0ad6e2e to 9e32400
Summary
Introduce IndexedElemwise, an OpFromGraph that wraps AdvancedSubtensor + Elemwise + AdvancedIncSubtensor subgraphs so the Numba backend can generate a single loop with indirect indexing. This avoids materializing the AdvancedSubtensor input arrays and writes directly into the output buffer, doing the work of AdvancedIncSubtensor in the same loop instead of looping again over the intermediate Elemwise output.
Commit 1 fuses indexed reads (AdvancedSubtensor1 on inputs).
Commit 2 fuses indexed updates (AdvancedIncSubtensor1 on outputs).
Commit 3 extends to AdvancedSubtensor inputs on arbitrary (consecutive) 1D-indexed axes.
Motivation
In hierarchical models with mu = beta[idx] * x + ..., the logp+gradient graph combines indexed reads and indexed updates in the same Elemwise (the forward reads county-level parameters via an index, and the gradient accumulates back into county-level buffers via the same index).
A simpler
The next step would be to also fuse the sum directly into the Elemwise loop, so we end up with a single pass over the data. This matters because a sum can easily break the fusion: we don't fuse when the Elemwise output is needed elsewhere (such as in a sum).
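That proposed sum fusion would amount to a loop like the following (a hypothetical sketch, assuming the reduction is a plain sum over the Elemwise output):

```python
def fused_indexed_sum(x, idx, y):
    """sum(x[idx] + y) in one pass: the Elemwise output never exists
    as an array, so the sum cannot act as a second client of that
    output and block the fusion."""
    acc = 0.0
    for k in range(len(idx)):
        acc += x[idx[k]] + y[k]
    return acc

fused_indexed_sum([0.0, 1.0, 2.0], [2, 2, 0], [1.0, 1.0, 1.0])
# 7.0
```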