Optimize CPU deform_conv2d forward pass with parallel im2col#9442
developer0hye wants to merge 2 commits into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9442
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 137d2b7 with merge base 8a5946e. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Three changes to the CPU deformable convolution forward kernel:

1. Replace `at::zeros` with `at::empty` for the `columns` and `out_buf` buffers. The `deformable_im2col_kernel` writes every element of the `columns` buffer, and `out_buf` is fully written by `addmm_`, so zero-initialization is wasted work.
2. Use `addmm_` with `beta=0` instead of the default `beta=1`. This avoids accumulating into uninitialized memory while preserving in-place operation (no extra allocation, unlike `at::mm`).
3. Parallelize `deformable_im2col_kernel` with `at::parallel_for`. The im2col loop was the only single-threaded phase in the forward pass (GEMM is already parallelized by BLAS). Each loop iteration writes to a non-overlapping region of the `columns` buffer, so parallelization is safe.

Benchmark results on Apple M2 (CPU, float32):

| Config    | Before (ms) | After (ms) | Change |
|-----------|-------------|------------|--------|
| small-b1  | 9.76        | 2.44       | -75%   |
| small-b8  | 91.77       | 33.88      | -63%   |
| medium-b1 | 216.70      | 75.80      | -65%   |
| medium-b8 | 1152.09     | 650.00     | -44%   |
| large-b1  | 348.86      | 302.70     | -13%   |
| large-b4  | 1342.75     | 1289.96    | -4%    |

Signed-off-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
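As a side note on change 2 above, the `beta=0` point is easy to see from the semantics of `addmm_`: it computes `out = beta*out + alpha*(m1 @ m2)` in place, so with `beta=0` the prior contents of `out` are never read and the buffer may come from `at::empty`. A toy stand-in in plain Python (not the ATen implementation; `addmm_` here is a hypothetical pure-Python model of the op):

```python
import math

def addmm_(out, m1, m2, beta=1.0, alpha=1.0):
    """Toy model of ATen's addmm_: out = beta*out + alpha*(m1 @ m2), in place."""
    n, k = len(m1), len(m1[0])
    m = len(m2[0])
    for i in range(n):
        for j in range(m):
            acc = sum(m1[i][p] * m2[p][j] for p in range(k))
            # With beta == 0 the old value of out[i][j] is never read, so out
            # may start uninitialized (modeled here as NaN) without polluting
            # the result. With the default beta == 1, NaN would propagate.
            old = beta * out[i][j] if beta != 0.0 else 0.0
            out[i][j] = old + alpha * acc

# "Uninitialized" buffer, standing in for what at::empty returns:
out = [[math.nan, math.nan], [math.nan, math.nan]]
m1 = [[1.0, 2.0], [3.0, 4.0]]
m2 = [[5.0, 6.0], [7.0, 8.0]]
addmm_(out, m1, m2, beta=0.0)
print(out)  # [[19.0, 22.0], [43.0, 50.0]]
```

This mirrors why the PR can drop the zero-fill: `beta=0` makes the previous buffer contents irrelevant.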
Force-pushed from e653cad to 8a89fb8.
@developer0hye Hi, thanks a lot for this PR! May I ask what's the motivation for optimizing the CPU path for deform_conv2d? It's almost always used on GPU. Is there a specific application in your use case?
@zy1git Great question!
On a personal note, I've been working on in-browser ML inference — things like humanblur and bgremover. These are built with Candle + WASM rather than PyTorch, so admittedly a different stack, but the experience taught me that CPU-side efficiency matters more than you'd expect — not every user has GPU acceleration available, even in a browser. That mindset carried over here: if a CPU path exists and there's a straightforward way to make it 3x faster, it's worth doing.
Summary

The CPU `deform_conv2d` forward pass spends 89–97% of its time in the `deformable_im2col_kernel` (confirmed via `torch.profiler`), yet this kernel runs entirely single-threaded. GEMM (`addmm_`) accounts for only 3–10% and is already parallelized by BLAS.

This PR introduces three changes to `torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp` that together yield a 2.5–3.3x end-to-end speedup on the forward pass:

1. Parallelize `deformable_im2col_kernel` with `at::parallel_for`. Each loop iteration writes to a non-overlapping region of the columns buffer (the write offset is uniquely determined by `(in_c, out_b, out_y, out_x)`), so parallelization is safe with no synchronization needed. Results are bit-for-bit identical regardless of thread count.
2. Replace `at::zeros` with `at::empty` for the `columns` buffer. `deformable_im2col_kernel` writes every element of this buffer (`n_in_channels × kH × kW × parallel_imgs × out_h × out_w` elements total), so zero-initialization is wasted work.
3. Replace `at::zeros` with `at::empty` for `out_buf` and use `addmm_` with `beta=0`. Each `out_buf[b][g]` is written exactly once per `(batch_block, weight_group)` pair. Using `beta=0` skips the accumulation of uninitialized values while preserving in-place operation (unlike `at::mm`, which allocates a new tensor).

Benchmark
All measurements use `time.perf_counter()`, 10 warmup + 100 timed iterations, reporting the median.

- Hardware: Apple M2, `torch.get_num_threads() = 4`
- Dtype: float32, with mask (DCNv2 mode)
- Config format: `s{spatial}-b{batch}`, e.g. `s32-b4` = 64 in/out channels, 3×3 kernel, stride 1, padding 1, 32×32 spatial, batch 4. `s64-*` uses 256 in/out channels.

Profiler breakdown (baseline, s32-b1)
Benchmark script
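The full benchmark script is collapsed in the PR; as a hedged sketch only, the timing protocol described above (10 warmup iterations, 100 timed iterations, median of `time.perf_counter()` deltas) might look like the following, where `fn` stands in for the actual `torchvision.ops.deform_conv2d` call (a placeholder, not the author's script):

```python
import time
import statistics

def bench(fn, warmup=10, iters=100):
    """Return the median wall-clock time of fn() in milliseconds."""
    for _ in range(warmup):
        fn()  # warm caches, thread pools, JIT paths
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)

# Trivial stand-in workload; in the PR's benchmark, fn would invoke
# deform_conv2d on the s{spatial}-b{batch} configs listed above.
ms = bench(lambda: sum(range(10_000)))
print(f"{ms:.3f} ms")
```

The median (rather than the mean) is used so occasional scheduler hiccups don't skew the reported numbers.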
Numerical correctness

Output is bit-for-bit identical between 1-thread and 8-thread execution (`torch.equal` returns `True`). Each thread operates on a disjoint slice of the columns buffer, so floating-point evaluation order is unchanged.

All existing `TestDeformConv` tests pass (forward, backward, scripting, opcheck).

Related
- `deform_conv2d` kernels are sequential and don't utilize multicore resources

cc @NicolasHug
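As a footnote to the numerical-correctness claim above: when every worker writes only to its own disjoint slice of a shared buffer, the result is independent of the worker count. A miniature Python analogue (illustrative only — the real kernel partitions over `(in_c, out_b, out_y, out_x)` in C++ via `at::parallel_for`; the names below are made up for this sketch):

```python
from concurrent.futures import ThreadPoolExecutor

def fill_slice(buf, start, end):
    # Each task writes a disjoint half-open range [start, end) of buf;
    # no two tasks ever touch the same index, so no locks are needed and
    # per-element results don't depend on scheduling order.
    for i in range(start, end):
        buf[i] = i * i

def parallel_fill(n, num_threads):
    buf = [None] * n
    chunk = (n + num_threads - 1) // num_threads  # ceil-divide into ranges
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        for t in range(num_threads):
            pool.submit(fill_slice, buf, t * chunk, min((t + 1) * chunk, n))
    return buf  # pool shutdown above waits for all tasks

# Identical output regardless of thread count, mirroring the bit-for-bit
# equality the PR verifies with torch.equal.
assert parallel_fill(1000, 1) == parallel_fill(1000, 8)
```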