Add FFT-based DISCO S2 contraction for faster spherical convolutions #158
mcgibbon wants to merge 15 commits into NVIDIA:main
Conversation
… correlation

The existing torch reference implementation iterates over output longitudes (typically 256-512 iterations), rolling the input and performing a sparse matmul at each step. This replaces that loop with an FFT circular cross-correlation: one rfft, one einsum, and one irfft, eliminating the Python loop entirely. Gated behind `use_fft_contraction=True` on DiscreteContinuousConvS2 and DiscreteContinuousConvTransposeS2 so both paths remain available.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
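The equivalence the commit relies on is the circular correlation theorem. A minimal numpy sketch (illustrative names, not the PR's code) showing that the loop-and-roll contraction and the rfft/multiply/irfft path compute the same thing for a single longitude ring:

```python
import numpy as np

def circular_xcorr_loop(x, psi):
    # Reference path: for each output longitude, roll the input and contract.
    nlon = x.shape[-1]
    out = np.empty(nlon)
    for p in range(nlon):
        out[p] = np.dot(np.roll(x, -p), psi)
    return out

def circular_xcorr_fft(x, psi):
    # Same result via the correlation theorem:
    # one rfft, one pointwise multiply with the conjugate spectrum, one irfft.
    nlon = x.shape[-1]
    return np.fft.irfft(np.fft.rfft(x) * np.conj(np.fft.rfft(psi)), n=nlon)
```

In the actual layer the pointwise multiply generalizes to an einsum over kernel and latitude indices, but the per-longitude structure is exactly this one-dimensional circular cross-correlation.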
Introduce a torch_harmonics.benchmark subpackage with:
- Timer infrastructure (CUDATimer, NullTimer, CPUEventPair) for GPU event-based and CPU wall-clock timing
- BenchmarkABC base class with registry via @register_benchmark
- CLI runner (python -m torch_harmonics.benchmark) that saves JSON results
- RealSHT and InverseRealSHT benchmarks at 1-degree resolution

Also add benchmark_results to .gitignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
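The registry-via-decorator pattern mentioned above can be sketched as follows; the names (`BENCHMARK_REGISTRY`, `register_benchmark`) are illustrative assumptions, not the actual torch_harmonics.benchmark API:

```python
# Hypothetical sketch of a decorator-based benchmark registry.
BENCHMARK_REGISTRY = {}

def register_benchmark(name):
    """Class decorator that records a benchmark under a unique name."""
    def wrap(cls):
        if name in BENCHMARK_REGISTRY:
            raise ValueError(f"duplicate benchmark name: {name}")
        BENCHMARK_REGISTRY[name] = cls
        return cls
    return wrap

@register_benchmark("real_sht_1deg")
class RealSHTBenchmark:
    def run(self):
        # Placeholder body; the real benchmark would time an SHT here.
        return "ok"
```

A CLI runner can then iterate over `BENCHMARK_REGISTRY.items()`, run each benchmark, and dump the results as JSON.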
Register a disco_conv_s2_torch_1deg benchmark at 1-degree resolution (B=4, 4 channels, 180x360) using the non-optimized torch contraction path, which does not require the custom CUDA extension. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce hardware.py with a device-name-to-scale-factor lookup table so benchmark batch sizes adapt to different GPUs. Base batch sizes are tuned for Tesla T4 (factor 1.0). Unknown devices default to 1.0 with a warning to add an entry for their hardware. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
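A minimal sketch of the scale-factor lookup described in this commit; the table entries and function name are assumptions for illustration, not the real hardware.py:

```python
import warnings

# Base batch sizes are tuned for Tesla T4, so T4 is the 1.0 baseline.
DEVICE_SCALE_FACTORS = {
    "Tesla T4": 1.0,
    # e.g. "NVIDIA A100-SXM4-80GB": 8.0,  # hypothetical entry
}

def batch_scale_factor(device_name: str) -> float:
    """Return the batch-size multiplier for a device, warning on unknown GPUs."""
    try:
        return DEVICE_SCALE_FACTORS[device_name]
    except KeyError:
        warnings.warn(
            f"unknown device {device_name!r}; defaulting to scale factor 1.0. "
            "Consider adding an entry for your hardware."
        )
        return 1.0
```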
- Register disco_conv_s2_fft_1deg benchmark using use_fft_contraction=True to compare against the torch sparse path benchmark
- Replace torch.allclose/assertTrue with torch.testing.assert_close in test_fft_contraction.py for better error messages on failure

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace torch.fft.rfft/irfft with the custom wrappers from torch_harmonics.fft which ensure correct Hermitian symmetry by zeroing imaginary parts of DC and Nyquist components, avoiding artifacts on certain GPU backends. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
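The Hermitian-symmetry fix the commit describes amounts to forcing the DC bin (and, for even-length signals, the Nyquist bin) of the half-spectrum to be purely real. A numpy sketch of that post-processing step (function name illustrative, not the torch_harmonics.fft wrapper itself):

```python
import numpy as np

def enforce_hermitian(X_f, n):
    """Zero stray imaginary parts of the DC and Nyquist bins of an rfft output.

    For a real signal of length n, these bins must be purely real; numerical
    noise on some backends can leave small imaginary residue there, which
    produces artifacts after further frequency-domain manipulation.
    """
    X_f = X_f.copy()
    X_f[..., 0] = X_f[..., 0].real
    if n % 2 == 0:  # Nyquist bin exists only for even n
        X_f[..., -1] = X_f[..., -1].real
    return X_f
```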
Register disco_conv_s2_cuda_1deg benchmark using the custom CUDA extension at B=64 to compare against FFT and torch sparse paths. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
azrael417
left a comment
Thanks for this MR, especially the benchmarking. I left some comments. In particular, I am not sure how useful the FFT-accelerated DISCO is, since the psi tensor can become very big relatively quickly when not stored as a sparse tensor. Also, please make sure there is feature parity between the distributed and non-distributed routines.
```python
    "disco_kernels::_disco_s2_transpose_contraction_optimized",
    _disco_s2_transpose_contraction_bwd_optimized,
    setup_context=_setup_context_conv_backward)


# FFT-based contraction functions
def _densify_psi(psi_sparse: torch.Tensor, nlat_in: int, nlon_in: int):
```
This one can be pretty big. We decided not to go the FFT route since the device memory will be used up quickly when storing this as dense tensor. For example, for a small test case of 361x721 -> 361x721 that tensor is already 350 MB per basis function, so for basis function sizes of 8 or 9 it is already 2.8 GB.
That's a good point. I wonder if we can sparsify this a bit by restating the nlat_out dim as centered on the nlat_in position, so that we only need a few indices, as past a certain distance all entries will be zero.
```python
        return cls(conv=conv, x=x)


    @final
    def run_instance(self, timer: Timer) -> TensorDict:
```
Generally, I would benchmark the forward and backward pass. Some kernels are much more challenging in the backward pass (for example the CUDA DISCO kernel, since the naive implementation has many atomics). Please add backward benchmarks as well. You can create a random tensor from the forward output and run a backward pass.
Updated: both forward and backward passes are now computed.
tests/test_fft_contraction.py (outdated)

```python
        y_fft = conv_fft(x)

        self.assertEqual(y_ref.shape, y_fft.shape)
        torch.testing.assert_close(y_ref, y_fft, atol=atol, rtol=1e-4)
```
Please use the compare_tensors routine from testutils for comparing tensors, it does some additional checks and also has better verbose printing.
There's also a way to further speed up the optimized CUDA kernel at 1-degree resolution or coarser; by 1/4-degree resolution it doesn't really matter (unless we were to increase the cutoff radius). In the polar latitudes, where the existing kernel is already sufficiently dense (it's 100% dense at the polarmost latitude), you can run the FFT approach instead. This got a 33% forward-pass speed-up in the benchmark I ran. Is this something you'd be interested to see in another PR?
Please note that the tricky part for DISCO is the backward kernel, which has the potential to generate a lot of atomics. We have explored different data layouts as well, such as channels-last (available in another branch), but so far the performance-speedup findings are mixed: some kernels perform better than the standard one on certain grid sizes, but those perform very badly on others. We found that the existing kernel performs well across the board for most cases. If you find a kernel which performs better for a range of grid sizes and K (3 <= K <= 9, 180 < nlat < 721, 360 < nlon < 1440), let me know. Otherwise we would need to build a heuristic lookup table, and that is a bit overkill imo.
Thanks @mcgibbon for proposing this! Does the math check out? Given that this is essentially a different algorithm, I would like to make sure that it is still computing the right thing. I am not sure the convolution decomposes via a 1-d convolution theorem, given the non-Abelian nature of SO(3) group convolutions. Have you tried checking whether the results are still the same? Especially for non-isotropic kernels, there might be a difference between the two.

Finally, the speedups seem to be more significant at lower resolutions. I want to avoid adding more complexity to the code, to guarantee high maintainability. If the speedups are expected to shrink as resolution increases, my preference is to stick to a simpler setup with a single algorithm.
I believe it does, which is reflected in the passing of the equivalence tests against the existing torch implementation.
Even at quarter-degree resolution, the FFT-based code is 4x faster than the existing torch code. At 1/8th-degree resolution, you're probably going to be running the optimized CUDA kernel anyway in some kind of spatial-parallel mode; I run out of memory trying to densify a ~45 GB psi tensor.

Up to you whether you want to incorporate it, but we will need to fork the code for our own purposes if not. The convolutions are currently very slow on CPU, and it's pretty easy to end up with a version of this code that doesn't have the optimized CUDA kernels built. If you agree that the FFT is a more efficient way to do the cross-correlation and the issue is just the memory usage of the precomputed psi, I could update this to replace the torch implementation with the FFT-based version, and have the precomputation of psi be the flag. That should have a lower maintenance burden. Even without the precomputation the new code is significantly faster, I believe, but I'd have to re-check and confirm.
Wow, the banded implementation (only storing the nlat_out-local region of psi that is nonzero) is actually much better, and the memory usage is drastically reduced: I was able to run it at 1/8th-degree resolution. With this change, the update gives even more performance at higher resolutions, not less.
Instead of densifying the full (K, nlat_out, nlat_in, nlon) tensor, store only the contiguous band of input latitudes with nonzero entries per output latitude. This reduces memory by ~nlat_in/band_width and makes the einsum contract over ~20 bands instead of all input lats. Fixes OOM at 1/8-degree resolution and speeds up quarter-degree FFT contraction by ~10x (498ms -> 46ms). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
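A toy sketch of the banded-psi idea from this commit, reduced to the latitude axis only; the helper names and the band construction are illustrative assumptions, not the PR's actual code:

```python
import numpy as np

def to_banded(psi_dense, band):
    """Keep, per output latitude, only a contiguous band of input latitudes.

    Memory drops from nlat_out * nlat_in to nlat_out * band, and the
    contraction touches only the stored band instead of all input lats.
    """
    nlat_out, nlat_in = psi_dense.shape
    starts = np.clip(np.arange(nlat_out) - band // 2, 0, nlat_in - band)
    banded = np.stack([psi_dense[i, s:s + band] for i, s in enumerate(starts)])
    return banded, starts

def banded_matvec(banded, starts, x):
    # Contract each stored band against the matching slice of the input.
    return np.array([row @ x[s:s + row.size] for row, s in zip(banded, starts)])
```

In the real layer there are additional kernel, channel, and longitude-frequency axes, but the saving is the same: the einsum contracts over a band of ~20 latitudes instead of all nlat_in.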
…nchmarks Add MemoryBenchmark context manager that tracks peak GPU memory allocation and reservation during benchmark runs. Update DISCO benchmarks to time forward and backward passes independently using child timers, as requested in review. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Switch from torch.testing.assert_close to the project's compare_tensors utility for better verbose diagnostics on failure. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```python
        return MemoryResult(max_alloc=self._max_alloc, max_reserved=self._max_reserved)


def benchmark_memory() -> MemoryBenchmark:
```
Added memory benchmarking. The FFT pathway still uses ~3x the memory of CUDA, which in turn uses ~2x the memory of torch, but it's significantly less than before.
Support different in/out shapes in benchmark infrastructure and add FFT and CUDA benchmarks for 2x downsampling (360x720 -> 180x360). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… contraction
Avoids a full tensor copy by applying the roll(-1) as a pointwise exp(2*pi*i*k/N) multiply after the rfft. Also removes redundant .to(X_f.dtype) casts on psi_fft_conj (already complex64).

Benchmark results (Tesla T4, 5 iters) show no measurable change at this resolution; the roll was already fast at these tensor sizes:

```
tconv_fft_1deg_to_halfdeg:
  Before: fwd=86.8ms bwd=193.1ms total=280.6ms mem=2719MB
  After:  fwd=86.6ms bwd=190.5ms total=277.8ms mem=2719MB
fft_1deg:
  Before: fwd=38.0ms bwd=43.7ms total=82.1ms mem=1428MB
  After:  fwd=38.3ms bwd=42.2ms total=80.8ms mem=1428MB
```

The benefit should be larger at higher resolutions, where the roll copies more data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
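The trick in this commit is the DFT shift theorem: a circular roll by -1 in signal space is a per-frequency phase multiply after the rfft. A small numpy check (illustrative, not the PR's code):

```python
import numpy as np

def roll_via_phase(x):
    """Compute np.roll(x, -1) via a phase multiply in the frequency domain.

    Shift theorem: if y[m] = x[m+1] (circular), then Y[k] = exp(2*pi*i*k/N) * X[k],
    so no signal-space copy of the rolled tensor is needed.
    """
    n = x.shape[-1]
    k = np.arange(n // 2 + 1)  # rfft frequency bins
    return np.fft.irfft(np.fft.rfft(x) * np.exp(2j * np.pi * k / n), n=n)
```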
There was one more minor speedup in ee2e9d5: avoiding a roll by turning it into a multiplication in the frequency domain. It's very minor; I'd be happy to revert it if you think that would keep the code clearer.
Should be ready for another look.
Please note that we have distributed DISCO convolution as well as SHT support; only the attention kernel is not yet distributed.
I will have a look at the code next week, thanks for contributing this. We are also working on getting a precompiled wheel, but deploying this is a bit tricky because of the combined torch and CUDA dependency. We are currently figuring out how best to approach this.
Sounds good. I think the comment about also implementing this in the distributed version got dropped; I'll look into it.
Add use_fft_contraction parameter to DistributedDiscreteContinuousConvS2 and DistributedDiscreteContinuousConvTransposeS2, mirroring the non-distributed implementation. When enabled, precomputes banded FFT psi representation on local latitude slices and uses FFT-based cross-correlation instead of loop-and-roll or CUDA kernels. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
Benchmark results (Tesla T4)
These results correspond to a 13.4x speed-up of the per-sample time. At equal batch size (B=16), the FFT path is ~10x faster than the torch sparse path (lower because the T4 has lower utilization for the FFT path). With this optimization, the torch-only timings are in the ballpark of the optimized CUDA timings, at least for this T4 GPU at 1-degree resolution.
Test plan