Replace join by direct writes in ancestor producers #2014
Draft
ricardoV94 wants to merge 3 commits into pymc-devs:v3
Second time playing around with this idea (first attempt: #1333).
I still don't know if we want to go down this route; I only ever seem to get small benefits.
Eliminate Join buffer copies via WriteSplit/WriteJoin
Join(axis=0, f(x), g(y), ...) allocates N intermediate buffers (one per stream) plus a final output buffer, then copies everything. This PR introduces a rewrite that pre-allocates the output buffer and has each Elemwise stream write directly into its slice, eliminating all intermediate allocations and the concatenation copy.
How it works
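As a rough analogy in plain NumPy (not the PyTensor rewrite itself), the two strategies can be sketched as below. The producers `tanh` and `exp` and the function names `join_baseline` and `join_direct` are hypothetical stand-ins chosen for illustration; the only NumPy feature relied on is the `out=` argument of ufuncs.

```python
import numpy as np


def join_baseline(x, y):
    # Each producer allocates its own temporary; concatenate then allocates
    # the final buffer and copies both temporaries into it.
    return np.concatenate([np.tanh(x), np.exp(y)], axis=0)


def join_direct(x, y):
    # Single allocation for the joined result.
    n_x = x.shape[0]
    out = np.empty(n_x + y.shape[0], dtype=np.result_type(x, y))
    # Basic slices are views, so each ufunc writes straight into its region.
    np.tanh(x, out=out[:n_x])
    np.exp(y, out=out[n_x:])
    return out


x = np.linspace(-1.0, 1.0, 4)
y = np.linspace(0.0, 1.0, 3)
assert np.allclose(join_baseline(x, y), join_direct(x, y))
```

Both functions return the same values; the second simply never materializes the per-stream temporaries or performs the final copy.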
The key idea is passing the output buffer slice as an extra input to each Elemwise op, similar to the out argument in NumPy ufuncs. The scalar_op is wrapped in a Composite with a dummy input, and inplace_pattern tells the Elemwise to write its result into that input's buffer. This way each stream writes directly into its region of the pre-allocated output.
New ops
view_map, so the DestroyHandler treats each slice as independent and multiple Elemwise ops can destroy their slices without conflict.
destroy_map for ordering, ensuring all inplace writes complete before the buffer is consumed.
JoinBufferElimination rewrite (position 50.0, inplace phase)
ravel is subsumed by reshaping the buffer slice to the pre-ravel shape)
set_subtensor for non-Elemwise or shared-client streams
Benchmarks
NUMBA, concat([x_i * 2 + 1 for x_i in xs]), 1000 elements per stream:
With heavier ops like tanh, savings are ~3-4% per stream (compute dominates over allocation overhead).
Radon model (173 elements, 7 streams): roughly at parity, since the model is compute-dominated and allocation overhead is a small fraction.
Future work
out buffer pattern could be extended beyond Elemwise to RVs, Blockwise, and Dot: any op that could accept a pre-allocated output buffer. This is actually what Scan's C backend tries hard to do with its internal buffers, but it can't guarantee the write actually happened and has to do a costly check at the end to verify outputs were reused rather than ignored. A stronger op-agnostic out= API would help both here and in the Scan case.
AllocEmpty buffer as an outer input and pass it in every call, eliminating even the output allocation overhead across repeated invocations.
set_subtensor(set_subtensor(buf, ...), ...) and apply the same approach.
concatenate([f(x).ravel() for x in params]) which is very common in gradient computations (grad → concatenate of raveled parameter gradients) and in code that uses pack/unpack helpers to flatten parameter vectors.
Also included
Fixes local_useless_composite_outputs to preserve Composite inputs referenced by Elemwise.inplace_pattern even when unused in the scalar graph.
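For the concatenate([f(x).ravel() for x in params]) pattern mentioned earlier, the idea of reshaping the buffer slice to the pre-ravel shape can be illustrated in plain NumPy. This is only a sketch of the memory layout trick, not the rewrite in this PR; `tanh` stands in for an arbitrary elementwise `f`, and the function names are hypothetical.

```python
import numpy as np


def flatten_baseline(params):
    # One temporary per parameter, then a concatenation copy.
    return np.concatenate([np.tanh(p).ravel() for p in params])


def flatten_direct(params):
    sizes = [p.size for p in params]
    out = np.empty(sum(sizes), dtype=np.result_type(*params))
    offset = 0
    for p, size in zip(params, sizes):
        # A contiguous 1-D slice can be reshaped to the parameter's shape
        # without copying, so the ufunc writes into `out` in C order,
        # which matches what ravel() would have produced.
        target = out[offset:offset + size].reshape(p.shape)
        np.tanh(p, out=target)
        offset += size
    return out


params = [np.arange(6.0).reshape(2, 3), np.ones(4), np.zeros((2, 2, 2))]
assert np.allclose(flatten_baseline(params), flatten_direct(params))
```

The reshape replaces the explicit ravel: instead of flattening each result after computing it, the destination slice is given the result's shape up front.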