
Add FFT-based DISCO S2 contraction for faster spherical convolutions#158

Open
mcgibbon wants to merge 15 commits into NVIDIA:main from mcgibbon:fft-disco-contraction

Conversation

Contributor

@mcgibbon mcgibbon commented Mar 12, 2026

Summary

  • Replace the loop-and-roll longitude correlation in DISCO S2 convolutions with FFT-based circular cross-correlation, enabled via use_fft_contraction=True on DiscreteContinuousConvS2 and DiscreteContinuousConvTransposeS2
  • Add a benchmark framework (torch_harmonics.benchmark) with a registry pattern, CUDA event timing, hardware-dependent batch size scaling, and a CLI runner (python -m torch_harmonics.benchmark)
  • Register benchmarks for RealSHT, InverseRealSHT, and DISCO convolutions (both torch sparse and FFT paths)
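The core change — replacing the per-longitude roll loop with a single FFT-based circular cross-correlation — can be sketched as follows. Function names here are illustrative, not this PR's actual API; the real code operates on the full (batch, channel, lat, lon) tensors.

```python
import torch

def fft_cross_correlation(x, w):
    # Circular cross-correlation along the last dim via the FFT:
    # out[l] = sum_m w[m] * x[(l + m) % N]  <=>  irfft(conj(rfft(w)) * rfft(x))
    N = x.shape[-1]
    return torch.fft.irfft(torch.conj(torch.fft.rfft(w)) * torch.fft.rfft(x), n=N)

def rolled_reference(x, w):
    # The loop-and-roll reference this PR replaces: one roll per output longitude.
    N = x.shape[-1]
    return torch.stack([(w * torch.roll(x, -l)).sum(-1) for l in range(N)])

x = torch.randn(360, dtype=torch.float64)
w = torch.randn(360, dtype=torch.float64)
torch.testing.assert_close(fft_cross_correlation(x, w), rolled_reference(x, w))
```

The FFT version replaces N rolls and N reductions with one rfft, one pointwise multiply, and one irfft.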

Benchmark results (Tesla T4)

  ┌─────────────────────────────────┬───────────────┐
  │            Benchmark            │ Avg time (ms) │
  ├─────────────────────────────────┼───────────────┤
  │ disco_conv_s2_torch_1deg (B=16) │ ~97           │
  ├─────────────────────────────────┼───────────────┤
  │ disco_conv_s2_cuda_1deg (B=64)  │ ~24           │
  ├─────────────────────────────────┼───────────────┤
  │ disco_conv_s2_fft_1deg (B=64)   │ ~29           │
  └─────────────────────────────────┴───────────────┘

These results correspond to a 13.4x per-sample speed-up of the FFT path over the torch sparse path. At equal batch size (B=16), the FFT path is ~10x faster than the torch sparse path (lower because the FFT path under-utilizes the T4 at the smaller batch). With this optimization, the torch-only timings are in the ballpark of the optimized CUDA timings, at least for this T4 GPU at 1-degree resolution.

Test plan

  • pytest tests/test_fft_contraction.py — verifies FFT and torch paths produce assert_close-compatible results for forward, transpose, and gradients across multiple grid/basis/stride configurations
  • python -m torch_harmonics.benchmark — runs all registered benchmarks and saves JSON results

mcgibbon and others added 8 commits March 11, 2026 22:10
… correlation

The existing torch reference implementation iterates over output longitudes
(typically 256-512 iterations), rolling the input and performing a sparse
matmul each step. This replaces that with an FFT circular cross-correlation:
one rfft, one einsum, and one irfft — eliminating the Python loop entirely.

Gated behind `use_fft_contraction=True` on DiscreteContinuousConvS2 and
DiscreteContinuousConvTransposeS2 so both paths remain available.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce a torch_harmonics.benchmark subpackage with:
- Timer infrastructure (CUDATimer, NullTimer, CPUEventPair) for GPU
  event-based and CPU wall-clock timing
- BenchmarkABC base class with registry via @register_benchmark
- CLI runner (python -m torch_harmonics.benchmark) that saves JSON results
- RealSHT and InverseRealSHT benchmarks at 1-degree resolution

Also add benchmark_results to .gitignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Register a disco_conv_s2_torch_1deg benchmark at 1-degree resolution
(B=4, 4 channels, 180x360) using the non-optimized torch contraction
path, which does not require the custom CUDA extension.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce hardware.py with a device-name-to-scale-factor lookup table
so benchmark batch sizes adapt to different GPUs. Base batch sizes are
tuned for Tesla T4 (factor 1.0). Unknown devices default to 1.0 with
a warning to add an entry for their hardware.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
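The scale-factor lookup described in this commit can be sketched roughly like this. The table contents and names are illustrative placeholders, not the actual hardware.py module:

```python
import warnings

# Illustrative device-name -> batch-size scale factor table.
# Base batch sizes are tuned for Tesla T4, so it maps to 1.0;
# other hardware would get entries with measured factors.
DEVICE_SCALE_FACTORS = {
    "Tesla T4": 1.0,
}

def batch_scale_factor(device_name: str) -> float:
    # Unknown devices default to 1.0 with a warning, as described above.
    scale = DEVICE_SCALE_FACTORS.get(device_name)
    if scale is None:
        warnings.warn(
            f"No scale factor registered for {device_name!r}; defaulting to 1.0. "
            "Consider adding an entry for your hardware."
        )
        return 1.0
    return scale
```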
- Register disco_conv_s2_fft_1deg benchmark using use_fft_contraction=True
  to compare against the torch sparse path benchmark
- Replace torch.allclose/assertTrue with torch.testing.assert_close in
  test_fft_contraction.py for better error messages on failure

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace torch.fft.rfft/irfft with the custom wrappers from
torch_harmonics.fft which ensure correct Hermitian symmetry by zeroing
imaginary parts of DC and Nyquist components, avoiding artifacts on
certain GPU backends.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
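The Hermitian-symmetry fix amounts to forcing the DC bin (and, for even-length signals, the Nyquist bin) of the rfft output to be purely real. A minimal sketch of the idea, not the actual torch_harmonics.fft wrappers:

```python
import torch

def rfft_real_dc_nyquist(x):
    # rfft of a real signal along the last dim. For a real input, the DC
    # bin (k=0) and, when the length is even, the Nyquist bin (k=N/2)
    # must be purely real; zero any stray imaginary part some backends
    # leave behind.
    X = torch.fft.rfft(x)
    X.imag[..., 0] = 0.0
    if x.shape[-1] % 2 == 0:
        X.imag[..., -1] = 0.0
    return X
```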
Register disco_conv_s2_cuda_1deg benchmark using the custom CUDA
extension at B=64 to compare against FFT and torch sparse paths.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Collaborator

@azrael417 azrael417 left a comment


Thanks for this MR, especially the benchmarking. I left some comments. In particular, I am not sure how useful the FFT-accelerated DISCO is, since the psi tensor can become very big relatively quickly when not stored as a sparse tensor. Also, please make sure there is feature parity between the distributed and non-distributed routines.

"disco_kernels::_disco_s2_transpose_contraction_optimized", _disco_s2_transpose_contraction_bwd_optimized, setup_context=_setup_context_conv_backward)

# FFT-based contraction functions
def _densify_psi(psi_sparse: torch.Tensor, nlat_in: int, nlon_in: int):
Collaborator


This one can be pretty big. We decided not to go the FFT route since the device memory will be used up quickly when storing this as dense tensor. For example, for a small test case of 361x721 -> 361x721 that tensor is already 350 MB per basis function, so for basis function sizes of 8 or 9 it is already 2.8 GB.

Contributor Author


That's a good point. I wonder if we can sparsify this a bit by restating the nlat_out dim as centered on the nlat_in position, so that we only need a few indices, as past a certain distance all entries will be zero.

return cls(conv=conv, x=x)

@final
def run_instance(self, timer: Timer) -> TensorDict:
Collaborator


generally, I would benchmark forward and backward pass. Some kernels are much more challenging in backward pass (for example the CUDA disco kernel since the naive implementation has many atomics). Please add backward benchmarks as well. You can create a random tensor from the forward output and run a backward pass.
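The suggested pattern — build a random cotangent from the forward output's shape and time forward and backward separately — looks roughly like this. CPU wall-clock timing is shown for brevity; on GPU you would record torch.cuda.Event pairs and synchronize instead. The helper name is illustrative, not the benchmark subpackage's actual API.

```python
import time
import torch

def time_fwd_bwd(module, x, iters=3):
    # Time forward and backward passes separately (illustrative sketch).
    x = x.clone().requires_grad_(True)
    fwd = bwd = 0.0
    for _ in range(iters):
        t0 = time.perf_counter()
        y = module(x)
        t1 = time.perf_counter()
        y.backward(torch.randn_like(y))  # random cotangent with y's shape
        t2 = time.perf_counter()
        fwd += t1 - t0
        bwd += t2 - t1
        x.grad = None
    return fwd / iters, bwd / iters

fwd_s, bwd_s = time_fwd_bwd(torch.nn.Linear(8, 4), torch.randn(2, 8))
```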

Contributor Author


Updated, now forward and backward pass are computed.

y_fft = conv_fft(x)

self.assertEqual(y_ref.shape, y_fft.shape)
torch.testing.assert_close(y_ref, y_fft, atol=atol, rtol=1e-4)
Collaborator


Please use the compare_tensors routine from testutils for comparing tensors, it does some additional checks and also has better verbose printing.

@mcgibbon
Contributor Author

There's also a way to further speed up the optimized CUDA kernel at 1-degree resolution or coarser; by 1/4-degree resolution it doesn't really matter (unless we were to increase the cutoff radius). In the polar latitudes, where the existing kernel is already sufficiently dense (it's 100% dense at the polarmost latitude), you can run the FFT approach instead. This got a 33% forward-pass speed-up in the benchmark I ran.

Here are the results with the precomputed BMM layout optimization:

  ┌───────────────────────────┬───────────┬─────────────┬─────────┐
  │        Resolution         │ CUDA (ms) │ Hybrid (ms) │ Speedup │
  ├───────────────────────────┼───────────┼─────────────┼─────────┤
  │ 1-degree (180x360)        │ 23.47     │ 15.96       │ 1.47x   │
  ├───────────────────────────┼───────────┼─────────────┼─────────┤
  │ Quarter-degree (720x1440) │ 25.62     │ 24.34       │ 1.05x   │
  └───────────────────────────┴───────────┴─────────────┴─────────┘

  The precomputed BMM optimization made a big difference at 1-degree — hybrid is now 1.47x faster than pure CUDA there. At quarter-degree the hybrid is only marginally faster (1.05x), which makes sense since the mid-latitude sparse CUDA kernel dominates at higher resolution (fewer polar rows proportionally).

Is this something you'd be interested to see in another PR?

@azrael417
Collaborator

Please note that the tricky part for DISCO is the backward kernel, which has the potential to generate a lot of atomics. We have explored different data layouts as well, such as channels-last (available in another branch), but so far the performance speedup findings are mixed: some kernels perform better than the standard one on certain grid sizes but perform very badly on others. We found that the existing kernel performs well across the board for most cases. If you find a kernel which performs better for a range of grid sizes and K (3 <= K <= 9, 180 < nlat < 721, 360 < nlon < 1440), let me know. Otherwise we would need to build a heuristic lookup table, and that is a bit overkill imo.

@bonevbs
Collaborator

bonevbs commented Mar 13, 2026

Thanks @mcgibbon for proposing this!

Does the math check out? Given that this is essentially a different algorithm, I would like to make sure that this is still computing the right thing. I am not sure the convolution decomposes into a 1-d convolution theorem given the non-Abelian nature of SO(3) group convolutions. Have you tried checking if the results are still the same? Especially for non-isotropic kernels, there might be a difference between the two.

Finally, the speedups seem to be more significant for lower resolutions. I want to avoid adding more complexity to a code to guarantee high maintainability. If speedups are expected to become less as the resolution increases, my preference is to stick to a simpler setup with a single algorithm.

@mcgibbon
Contributor Author

> Thanks @mcgibbon for proposing this!
>
> Does the math check out? Given that this is essentially a different algorithm, I would like to make sure that this is still computing the right thing. I am not sure the convolution decomposes into a 1-d convolution theorem given the non-Abelian nature of SO(3) group convolutions. Have you tried checking if the results are still the same? Especially for non-isotropic kernels, there might be a difference between the two.

I believe it does, which is reflected in the passing torch.testing.assert_close tests included here. The current convolutions take a static kernel and apply it at each longitude through a for loop. This is exactly the cross-correlation of the kernel against the data across longitude. Each latitude can be separated into its own 1-D cross-correlation problem, and the per-latitude cross-correlations then summed. Cross-correlation is more efficient to compute through the FFT approach taken here than through a Python for loop, leading to the speedup.
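The separation argument can be checked numerically: each latitude row is an independent 1-D circular cross-correlation over longitude, computed in one batched FFT and summed. A self-contained check, not the PR's actual code:

```python
import torch

nlat, nlon = 45, 90
x = torch.randn(nlat, nlon, dtype=torch.float64)    # data on the grid
psi = torch.randn(nlat, nlon, dtype=torch.float64)  # static kernel slice

# Loop-and-roll reference: apply the kernel at every output longitude.
ref = torch.stack(
    [(psi * torch.roll(x, -l, dims=-1)).sum() for l in range(nlon)]
)

# FFT path: per-latitude 1-D cross-correlations, summed over latitude.
out = torch.fft.irfft(
    (torch.conj(torch.fft.rfft(psi)) * torch.fft.rfft(x)).sum(dim=0), n=nlon
)
torch.testing.assert_close(out, ref)
```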

> Finally, the speedups seem to be more significant for lower resolutions. I want to avoid adding more complexity to a code to guarantee high maintainability. If speedups are expected to become less as the resolution increases, my preference is to stick to a simpler setup with a single algorithm.

Even at quarter-degree resolution, the FFT-based code is 4x faster than the existing torch code. At 1/8th-degree resolution you're probably going to be running the optimized CUDA kernel anyway, in some kind of spatial-parallel mode - I run out of memory trying to densify a ~45GB psi tensor.

Up to you on whether you want to incorporate it, but we will need to fork the code for our own purposes if not. The convolutions are currently very slow on CPU, and it's pretty easy to end up with a version of this code that doesn't have the optimized CUDA kernels built.

If you agree that the FFT is a more efficient way to compute the cross-correlation and the issue is just the memory usage of the precomputed psi, I could update this to replace the torch implementation with the FFT-based version, and make the precomputation of psi the flag. That should have a lower maintenance burden. Even without the pre-computation the new code is significantly faster, I believe, but I'd have to re-check and confirm.

@mcgibbon
Contributor Author

mcgibbon commented Mar 13, 2026

Wow, the banded implementation (only storing the nlat_out-local region of psi that is nonzero) is much better actually:

Results with the banded FFT implementation:

  ┌──────────────────────┬────────────┬─────────────────┬─────────┐
  │      Resolution      │ Torch (ms) │ FFT banded (ms) │ Speedup │
  ├──────────────────────┼────────────┼─────────────────┼─────────┤
  │ 1-degree (B=64)      │ 635.38     │ 33.14           │ 19.2x   │
  ├──────────────────────┼────────────┼─────────────────┼─────────┤
  │ Quarter-degree (B=4) │ 1971.88    │ 45.61           │ 43.2x   │
  └──────────────────────┴────────────┴─────────────────┴─────────┘

  The quarter-degree FFT went from 497ms (old dense) to 45.6ms with banding — a 10.9x improvement on top of the previous FFT speedup, since the einsum now contracts over ~20 latitude bands instead of 720.

And now the memory usage is drastically reduced. I was able to run it at 1/8th degree resolution:

  ┌──────────────────┬────────────┬─────────────────┬─────────┐
  │    Resolution    │ Torch (ms) │ FFT banded (ms) │ Speedup │
  ├──────────────────┼────────────┼─────────────────┼─────────┤
  │ 1/8-degree (B=1) │ 3834.19    │ 85.81           │ 44.7x   │
  └──────────────────┴────────────┴─────────────────┴─────────┘

  Previously this OOM'd with the dense approach. Now it runs fine and is nearly 45x faster than torch.

With this change, it looks like the update gives even more performance at higher resolutions, not less.

mcgibbon and others added 3 commits March 13, 2026 15:44
Instead of densifying the full (K, nlat_out, nlat_in, nlon) tensor,
store only the contiguous band of input latitudes with nonzero entries
per output latitude. This reduces memory by ~nlat_in/band_width and
makes the einsum contract over ~20 bands instead of all input lats.

Fixes OOM at 1/8-degree resolution and speeds up quarter-degree FFT
contraction by ~10x (498ms -> 46ms).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
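A rough sketch of the banded layout described in this commit. Shapes, names, and the window heuristic are illustrative; the actual implementation lives in the FFT contraction code:

```python
import torch

def extract_bands(psi_dense, band_width):
    # psi_dense: (nlat_out, nlat_in, nlon) dense psi slice for one basis fn.
    # Keep only a contiguous window of input latitudes per output latitude,
    # centered on the corresponding input row, instead of all nlat_in rows.
    nlat_out, nlat_in, nlon = psi_dense.shape
    banded = torch.zeros(nlat_out, band_width, nlon, dtype=psi_dense.dtype)
    starts = torch.empty(nlat_out, dtype=torch.long)
    for i in range(nlat_out):
        center = round(i * nlat_in / nlat_out)
        start = min(max(center - band_width // 2, 0), nlat_in - band_width)
        starts[i] = start
        banded[i] = psi_dense[i, start:start + band_width]
    return banded, starts
```

During the contraction, the input would be gathered into the same (nlat_out, band_width, nlon) layout so the einsum runs over band_width rows instead of all input latitudes.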
…nchmarks

Add MemoryBenchmark context manager that tracks peak GPU memory
allocation and reservation during benchmark runs.

Update DISCO benchmarks to time forward and backward passes
independently using child timers, as requested in review.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Switch from torch.testing.assert_close to the project's
compare_tensors utility for better verbose diagnostics on failure.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

return MemoryResult(max_alloc=self._max_alloc, max_reserved=self._max_reserved)


def benchmark_memory() -> MemoryBenchmark:
Contributor Author


Added memory benchmarking. The FFT pathway still uses ~3x the memory of the CUDA path, which in turn uses ~2x the memory of torch, but it's significantly less than before.

mcgibbon and others added 3 commits March 13, 2026 16:03
Support different in/out shapes in benchmark infrastructure and add
FFT and CUDA benchmarks for 2x downsampling (360x720 -> 180x360).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace torch.roll with frequency-domain phase shift in FFT transpose contraction

Avoids a full tensor copy by applying the roll(-1) as a pointwise
exp(2*pi*i*k/N) multiply after the rfft. Also removes redundant
.to(X_f.dtype) casts on psi_fft_conj (already complex64).

Benchmark results (Tesla T4, 5 iters) show no measurable change at
this resolution — the roll was already fast at these tensor sizes:

  tconv_fft_1deg_to_halfdeg:
    Before: fwd=86.8ms bwd=193.1ms total=280.6ms mem=2719MB
    After:  fwd=86.6ms bwd=190.5ms total=277.8ms mem=2719MB

  fft_1deg:
    Before: fwd=38.0ms bwd=43.7ms total=82.1ms mem=1428MB
    After:  fwd=38.3ms bwd=42.2ms total=80.8ms mem=1428MB

The benefit should be larger at higher resolutions where the roll
copies more data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@mcgibbon
Contributor Author

mcgibbon commented Mar 13, 2026

There was one more minor speedup to do in ee2e9d5 to avoid a roll by turning it into a multiplication in the frequency domain. It's very minor - I'd be happy to revert it if you think that would keep the code clearer.
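The equivalence behind this commit is the DFT shift theorem: for a real signal of length N, roll(x, -1) equals irfft(rfft(x) * exp(2j*pi*k/N)), so the copy made by torch.roll becomes a pointwise multiply in the frequency domain. A standalone check:

```python
import math
import torch

N = 360
x = torch.randn(N, dtype=torch.float64)
k = torch.arange(N // 2 + 1, dtype=torch.float64)
phase = torch.exp(2j * math.pi * k / N)  # pointwise multiply after the rfft

shifted = torch.fft.irfft(torch.fft.rfft(x) * phase, n=N)
torch.testing.assert_close(shifted, torch.roll(x, -1))
```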


@mcgibbon
Contributor Author

Should be ready for another look.

@azrael417
Collaborator

Please note that we have distributed DISCO convolution as well as SHT support; only the attention kernel is not yet distributed.

@azrael417
Collaborator

I will have a look at the code next week, thanks for committing this. We are also working on getting a precompiled wheel, but deploying this is a bit tricky because of the combined torch and CUDA dependency. We are currently figuring out how to best approach this.

@mcgibbon
Contributor Author

mcgibbon commented Mar 13, 2026 via email

Add use_fft_contraction parameter to DistributedDiscreteContinuousConvS2
and DistributedDiscreteContinuousConvTransposeS2, mirroring the
non-distributed implementation. When enabled, precomputes banded FFT
psi representation on local latitude slices and uses FFT-based
cross-correlation instead of loop-and-roll or CUDA kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>