
Replace join by direct writes in ancestor producers#2014

Draft
ricardoV94 wants to merge 3 commits into pymc-devs:v3 from ricardoV94:join_memory_opt

Conversation


@ricardoV94 ricardoV94 commented Mar 29, 2026

Second time playing around with this idea: #1333

I still don't know if we want to go down this route; I only ever seem to get small benefits.

Eliminate Join buffer copies via WriteSplit/WriteJoin

Join(axis=0, f(x), g(y), ...) allocates N intermediate buffers (one per stream) plus a final output buffer, then copies everything. This PR introduces a rewrite that pre-allocates the output buffer and has each Elemwise stream write directly into its slice, eliminating all intermediate allocations and the concatenation copy.
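To make the before/after concrete, here is a hedged NumPy sketch of the two evaluation strategies (plain NumPy, not the actual PyTensor graphs; names are illustrative):

```python
import numpy as np

xs = [np.arange(3.0), np.arange(3.0, 6.0)]

# Today: each stream materializes an intermediate buffer,
# then concatenate allocates the output and copies every stream into it.
intermediates = [x * 2 + 1 for x in xs]      # N temporary buffers
joined = np.concatenate(intermediates)       # one more buffer + N copies

# After the rewrite: pre-allocate the output once and have each stream
# write directly into its slice, with no intermediates and no final copy.
out = np.empty(6)
offset = 0
for x in xs:
    view = out[offset:offset + len(x)]
    np.multiply(x, 2.0, out=view)            # stream writes into its slice
    view += 1.0                              # in place, still no extra buffer
    offset += len(x)

assert np.array_equal(out, joined)
```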

How it works

The key idea is passing the output buffer slice as an extra input to each Elemwise op, similar to the out argument in NumPy ufuncs. The scalar_op is wrapped in a Composite with a dummy input, and inplace_pattern tells the Elemwise to write its result into that input's buffer. This way each stream writes directly into its region of the pre-allocated output.
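A minimal NumPy analogy of the mechanism (the `stream` function below is a stand-in for the Composite-wrapped Elemwise, and `buf_slice` plays the role of the dummy input that `inplace_pattern` points at; none of this is the PR's actual API):

```python
import numpy as np

def stream(x, buf_slice):
    # inplace_pattern in effect says: output 0 destroys (overwrites)
    # the dummy input `buf_slice`, like NumPy's ufunc `out=` argument.
    np.multiply(x, 2.0, out=buf_slice)
    buf_slice += 1.0
    return buf_slice

out = np.empty(4)                      # the pre-allocated Join output
a = stream(np.arange(2.0), out[:2])    # each stream gets its own slice
b = stream(np.arange(2.0, 4.0), out[2:])

# Both returned views alias the pre-allocated output: no concatenation copy.
assert a.base is out and b.base is out
```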

New ops

  • WriteSplit: splits a buffer into contiguous views without declaring a view_map, so the DestroyHandler treats each slice as independent and multiple
    Elemwise ops can destroy their slices without conflict.
  • WriteJoin: returns buffer with destroy_map for ordering, ensuring all inplace writes complete before the buffer is consumed.
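In NumPy terms, the two ops express roughly the following (a conceptual sketch, not the ops' implementation):

```python
import numpy as np

def write_split(buf, sizes):
    """Hand out contiguous views of one buffer. Because the real op does
    not declare a view_map, the DestroyHandler sees each view as an
    independent, destroyable value."""
    views, offset = [], 0
    for s in sizes:
        views.append(buf[offset:offset + s])
        offset += s
    return views

buf = np.empty(5)
v1, v2 = write_split(buf, [2, 3])
v1[:] = 1.0   # each stream "destroys" (overwrites) only its own slice
v2[:] = 2.0

# WriteJoin then just returns `buf`; its destroy_map orders it after the
# writes, so consumers only ever see the fully written buffer.
assert np.array_equal(buf, [1, 1, 2, 2, 2])
```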

JoinBufferElimination rewrite (position 50.0, inplace phase)

  • Walks up through DimShuffle/Reshape to find Elemwise producers
  • Absorbs view ops by inverse-transforming the buffer slice (e.g. ravel is subsumed by reshaping the buffer slice to the pre-ravel shape)
  • Falls back to set_subtensor for non-Elemwise or shared-client streams
  • Requires >= 2 expandable streams (below that, overhead exceeds savings)
  • Checks for duplicate/shared Elemwise outputs to avoid double-destroy
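The view-op absorption and the fallback can be illustrated with a small NumPy sketch (assumed shapes and values, for illustration only):

```python
import numpy as np

x = np.arange(6.0).reshape(2, 3)
out = np.empty(6 + 1)   # pre-allocated Join output: 6 raveled + 1 extra entry

# Absorbing a view op: for concatenate([f(x).ravel(), ...]) the rewrite does
# not ravel f(x); it inverse-transforms the *buffer slice* to the pre-ravel
# shape and lets the Elemwise write into that. The ravel becomes free.
target = out[:6].reshape(2, 3)       # still a view of `out` (contiguous)
np.multiply(x, 2.0, out=target)      # Elemwise writes directly into the slice

# Fallback for a non-Elemwise (or shared-client) stream: an explicit
# set_subtensor-style copy into the remaining slice.
out[6:] = np.array([-1.0])

assert np.array_equal(out, [0, 2, 4, 6, 8, 10, -1])
```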

Benchmarks

NUMBA, concat([x_i * 2 + 1 for x_i in xs]), 1000 elements per stream:

| Streams | Without | With   | Savings |
|---------|---------|--------|---------|
| 2       | 2.2 µs  | 2.0 µs | 10%     |
| 4       | 3.0 µs  | 2.4 µs | 18%     |
| 8       | 4.6 µs  | 3.4 µs | 24%     |

With heavier ops like tanh, savings are ~3-4% per stream (compute dominates over allocation overhead).

Radon model (173 elements, 7 streams): roughly at parity; the model is compute-dominated, so allocation overhead is a small fraction of runtime.
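For reference, the benchmark methodology can be reproduced in spirit with a plain NumPy stand-in (absolute numbers will differ from the Numba-backend table above; this only mirrors the shape of the comparison):

```python
import timeit
import numpy as np

# concat([x * 2 + 1 for x in xs]), 1000 elements per stream, 8 streams.
xs = [np.random.rand(1000) for _ in range(8)]
out = np.empty(8000)

def without():
    # baseline: intermediates + concatenation copy
    return np.concatenate([x * 2 + 1 for x in xs])

def with_buffer():
    # rewritten form: each stream writes into its slice of one buffer
    for i, x in enumerate(xs):
        view = out[i * 1000:(i + 1) * 1000]
        np.multiply(x, 2.0, out=view)
        view += 1.0
    return out

t_without = timeit.timeit(without, number=1000)
t_with = timeit.timeit(with_buffer, number=1000)
assert np.allclose(without(), with_buffer())
```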

Future work

  • Extend to other producers: The out buffer pattern could be extended beyond Elemwise to RVs, Blockwise, and Dot — any op that could accept a pre-allocated output buffer. This is actually what Scan's C backend tries hard to do with its internal buffers, but it can't guarantee the write actually happened and has to do a costly check at the end to verify outputs were reused rather than ignored. A stronger op-agnostic out= API would help both here and in the Scan case.
  • User-controlled buffer reuse: After this rewrite, an advanced user could lift the AllocEmpty buffer as an outer input and pass it in every call, eliminating even the output allocation overhead across repeated invocations.
  • Chained set_subtensor: A follow-up could identify implicit joins from chained set_subtensor(set_subtensor(buf, ...), ...) and apply the same approach.
  • Common patterns: This optimization targets concatenate([f(x).ravel() for x in params]), which is very common in gradient computations (the gradient is often a concatenate of raveled per-parameter gradients) and in code that uses pack/unpack helpers to flatten parameter vectors.
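The targeted pack pattern, combined with the user-controlled buffer-reuse idea above, looks roughly like this in NumPy (the `pack_grads` helper and its `out=` parameter are hypothetical, for illustration):

```python
import numpy as np

def pack_grads(grads, out=None):
    """Flatten per-parameter gradients into one vector, writing each raveled
    gradient directly into its slice of a (possibly caller-supplied) buffer."""
    total = sum(g.size for g in grads)
    if out is None:
        out = np.empty(total)   # the AllocEmpty this PR would pre-allocate
    offset = 0
    for g in grads:
        out[offset:offset + g.size] = g.ravel()
        offset += g.size
    return out

grads = [np.ones((2, 2)), np.full(3, 2.0)]
buf = pack_grads(grads)             # first call allocates the buffer
buf2 = pack_grads(grads, out=buf)   # later calls reuse it: no allocation

assert buf2 is buf
assert np.array_equal(buf, [1, 1, 1, 1, 2, 2, 2])
```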

Also included

Fixes local_useless_composite_outputs to preserve Composite inputs referenced by Elemwise.inplace_pattern even when unused in the scalar graph.

@ricardoV94 ricardoV94 changed the title Replace join by direct writes in ancestors Replace join by direct writes in ancestor producers Mar 29, 2026
