diff --git a/dion/newton_schulz_triton.py b/dion/newton_schulz_triton.py index 30e21a8..6298277 100644 --- a/dion/newton_schulz_triton.py +++ b/dion/newton_schulz_triton.py @@ -76,6 +76,7 @@ def ns_line_1_kernel( BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, LOWER_UPPER: tl.constexpr, + INPUT_PRECISION: tl.constexpr = "tf32", ): """ Input A has shape (M, K) @@ -111,7 +112,7 @@ def ns_line_1_kernel( for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) + accumulator = tl.dot(a, at, accumulator, input_precision=INPUT_PRECISION) a_ptrs += BLOCK_SIZE_K * a_stride_c at_ptrs += BLOCK_SIZE_K * a_stride_c @@ -148,6 +149,7 @@ def ns_line_1(A: Tensor, *, out: Tensor = None): batch_size = A.size(0) if A.ndim == 3 else 1 input_batch_stride = A.stride(0) if A.ndim == 3 else 0 output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + input_precision = "ieee" if A.dtype == torch.float32 else "tf32" grid = lambda meta: ( batch_size @@ -165,6 +167,7 @@ def ns_line_1(A: Tensor, *, out: Tensor = None): c_stride_b=output_batch_stride, c_stride_r=out.stride(-2), c_stride_c=out.stride(-1), + INPUT_PRECISION=input_precision, ) return out @@ -192,6 +195,7 @@ def ns_line_2_kernel( BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, LOWER_UPPER: tl.constexpr, + INPUT_PRECISION: tl.constexpr = "tf32", ): """ Input A is square matrix with shape (M, M) @@ -227,7 +231,7 @@ def ns_line_2_kernel( for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) - accumulator = tl.dot(a, at, accumulator) + accumulator = tl.dot(a, at, accumulator, input_precision=INPUT_PRECISION) a_ptrs += BLOCK_SIZE_K * a_stride_c at_ptrs += BLOCK_SIZE_K * a_stride_c @@ -279,6 +283,7 @@ def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None): batch_size = A.size(0) if A.ndim == 3 else 1 input_batch_stride = A.stride(0) if A.ndim == 3 else 0 output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + input_precision = "ieee" if A.dtype == torch.float32 else "tf32" grid = lambda meta: ( batch_size @@ -297,6 +302,7 @@ def ns_line_2(A: Tensor, alpha: float, beta: float, *, out: Tensor = None): c_stride_c=out.stride(-1), alpha=alpha, beta=beta, + INPUT_PRECISION=input_precision, ) return out diff --git a/tests/test_newton_shultz.py b/tests/test_newton_shultz.py deleted file mode 100644 index 80e5903..0000000 --- a/tests/test_newton_shultz.py +++ /dev/null @@ -1,80 +0,0 @@ -# tests/test_newton_schulz.py -import pytest -import torch - -from dion.newton_schulz_triton import ( - ns_line_1, - ns_line_2, - newton_schulz_triton, - zeropower_via_newtonschulz5, -) - -# -----------------------------------------------------------------------------# -# General settings -# -----------------------------------------------------------------------------# - -# Allow a lot of recompiles in Torch-Triton -torch._dynamo.config.cache_size_limit = 100 # noqa: SLF001 - -CUDA_AVAILABLE = torch.cuda.is_available() - -# -----------------------------------------------------------------------------# -# Helper -# -----------------------------------------------------------------------------# - - -def _assert_close(result: torch.Tensor, correct: torch.Tensor, *, tol: float = 5e-2): - """Assert two tensors are close enough for the test to pass.""" - assert ( - result.dtype == correct.dtype - ), f"dtype mismatch — got {result.dtype}, expected {correct.dtype}" - assert ( - result.shape == correct.shape - ), f"shape mismatch — got {result.shape}, expected {correct.shape}" - assert torch.allclose( - result, correct, atol=tol, rtol=tol - ), f"max-abs-diff {torch.abs(result - correct).max().item():.3e} > {tol}" - - -@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") -@pytest.mark.parametrize("m,n", [(256, 256), (256, 1024)]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -def test_ns_line_1(m: int, n: int, dtype: torch.dtype): - """ns_line_1 should compute A @ A^T (batched and unbatched).""" - A = torch.randn(m, n, dtype=dtype, device="cuda") - _assert_close(ns_line_1(A), A @ A.mT) - - A_batched = torch.randn(4, m, n, dtype=dtype, device="cuda") - _assert_close(ns_line_1(A_batched), A_batched @ A_batched.mT) - - -@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") -@pytest.mark.parametrize("m", [256]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -def test_ns_line_2(m: int, dtype: torch.dtype): - """ns_line_2 should compute alpha(A@A^T) + beta*A for symmetric A.""" - alpha, beta = torch.randn(1).item(), torch.randn(1).item() - - A = torch.randn(m, m, dtype=dtype, device="cuda") - A = (A + A.mT) / 2 # ensure symmetry - correct = alpha * (A @ A.mT) + beta * A - _assert_close(ns_line_2(A, alpha=alpha, beta=beta), correct) - - A_batched = torch.randn(4, m, m, dtype=dtype, device="cuda") - A_batched = (A_batched + A_batched.mT) / 2 - correct_batched = alpha * (A_batched @ A_batched.mT) + beta * A_batched - _assert_close(ns_line_2(A_batched, alpha=alpha, beta=beta), correct_batched) - - -@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") -@pytest.mark.parametrize("m,n", [(256, 256), (256, 1024)]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -def test_newton_schulz_triton(m: int, n: int, dtype: torch.dtype): - """Fast Triton implementation should match the reference Newton-Schulz.""" - G = torch.randn(m, n, dtype=dtype, device="cuda") - _assert_close(newton_schulz_triton(G), zeropower_via_newtonschulz5(G)) - - G_batched = torch.randn(4, m, n, dtype=dtype, device="cuda") - _assert_close( - newton_schulz_triton(G_batched), zeropower_via_newtonschulz5(G_batched) - ) diff --git a/tests/test_newton_shulz.py b/tests/test_newton_shulz.py new file mode 100644 index 0000000..40953c7 --- /dev/null +++ b/tests/test_newton_shulz.py @@ -0,0 +1,129 @@ +# tests/test_newton_shulz.py +""" +Accuracy tests for Triton Newton-Schulz kernels against a numpy float64 +CPU reference. Each test asserts that the Triton kernel has similar or +better error (both mean and max) compared to PyTorch's cuBLAS for the +same operation. +""" +import numpy as np +import pytest +import torch + +from dion.newton_schulz_triton import ( + ns_line_1, + ns_line_2, + newton_schulz_triton, + zeropower_via_newtonschulz5, +) + +torch._dynamo.config.cache_size_limit = 100 # noqa: SLF001 + +CUDA_AVAILABLE = torch.cuda.is_available() + +# For bf16/f16, Triton should be at least as accurate as cuBLAS (multiplier=1). +# For f32, Triton's tl.dot uses a less favorable internal reduction tree than +# cuBLAS even with input_precision="ieee", so we allow some slack. +# Empirically (unbatched, 20 runs each): +# mean ratio: up to ~3.6x (shape 256x1024) +# max ratio: up to ~14x (shape 256x1024, outlier-sensitive) +# Batched cases show ratio 1.0 because torch.bmm uses the same reduction +# order as Triton (i.e. both produce bitwise-identical results), unlike +# torch.mm which uses a different cuBLAS algorithm. +# This is a Triton limitation — improving it would require raw CUDA. +_F32_MEAN_ERR_MULTIPLIER = 5 +_F32_MAX_ERR_MULTIPLIER = 15 + + +def _abs_errs(result: torch.Tensor, reference: torch.Tensor) -> tuple[float, float]: + """Return (mean, max) absolute error between a GPU result and a CPU reference.""" + diff = (result.cpu().float() - reference.float()).abs() + return diff.mean().item(), diff.max().item() + + +def _numpy_ref_aat(A: torch.Tensor) -> torch.Tensor: + """Compute A @ A^T in numpy float64, return as float32.""" + a = A.cpu().float().numpy().astype(np.float64) + out = a @ a.T if a.ndim == 2 else a @ np.swapaxes(a, -2, -1) + return torch.from_numpy(out.astype(np.float32)) + + +def _numpy_ref_ns_line_2(A: torch.Tensor, alpha: float, beta: float) -> torch.Tensor: + """Compute alpha * A @ A^T + beta * A in numpy float64.""" + a = A.cpu().float().numpy().astype(np.float64) + aT = a.T if a.ndim == 2 else np.swapaxes(a, -2, -1) + out = alpha * (a @ aT) + beta * a + return torch.from_numpy(out.astype(np.float32)) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") +@pytest.mark.parametrize("m,n", [(256, 256), (256, 1024)]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_ns_line_1_accuracy(m: int, n: int, dtype: torch.dtype): + """Triton ns_line_1 should have similar or better error than cuBLAS for A @ A^T.""" + mean_mul = _F32_MEAN_ERR_MULTIPLIER if dtype == torch.float32 else 1 + max_mul = _F32_MAX_ERR_MULTIPLIER if dtype == torch.float32 else 1 + for A in [ + torch.randn(m, n, dtype=dtype, device="cuda"), + torch.randn(4, m, n, dtype=dtype, device="cuda"), + ]: + ref = _numpy_ref_aat(A) + triton_mean, triton_max = _abs_errs(ns_line_1(A), ref) + cublas_mean, cublas_max = _abs_errs(A @ A.mT, ref) + assert triton_mean <= cublas_mean * mean_mul, ( + f"Triton mean err {triton_mean:.3e} > cuBLAS mean err {cublas_mean:.3e} * {mean_mul} " + f"(shape={tuple(A.shape)}, dtype={A.dtype})" + ) + assert triton_max <= cublas_max * max_mul, ( + f"Triton max err {triton_max:.3e} > cuBLAS max err {cublas_max:.3e} * {max_mul} " + f"(shape={tuple(A.shape)}, dtype={A.dtype})" + ) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") +@pytest.mark.parametrize("m", [256]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_ns_line_2_accuracy(m: int, dtype: torch.dtype): + """Triton ns_line_2 should have similar or better error than cuBLAS.""" + mean_mul = _F32_MEAN_ERR_MULTIPLIER if dtype == torch.float32 else 1 + max_mul = _F32_MAX_ERR_MULTIPLIER if dtype == torch.float32 else 1 + alpha, beta = torch.randn(1).item(), torch.randn(1).item() + + for A in [ + torch.randn(m, m, dtype=dtype, device="cuda"), + torch.randn(4, m, m, dtype=dtype, device="cuda"), + ]: + A = (A + A.mT) / 2 + ref = _numpy_ref_ns_line_2(A, alpha, beta) + triton_mean, triton_max = _abs_errs(ns_line_2(A, alpha=alpha, beta=beta), ref) + cublas_mean, cublas_max = _abs_errs(alpha * (A @ A.mT) + beta * A, ref) + assert triton_mean <= cublas_mean * mean_mul, ( + f"Triton mean err {triton_mean:.3e} > cuBLAS mean err {cublas_mean:.3e} * {mean_mul} " + f"(shape={tuple(A.shape)}, dtype={A.dtype})" + ) + assert triton_max <= cublas_max * max_mul, ( + f"Triton max err {triton_max:.3e} > cuBLAS max err {cublas_max:.3e} * {max_mul} " + f"(shape={tuple(A.shape)}, dtype={A.dtype})" + ) + + +@pytest.mark.skipif(not CUDA_AVAILABLE, reason="CUDA device required") +@pytest.mark.parametrize("m,n", [(256, 256), (256, 1024)]) +def test_newton_schulz_triton_vs_reference(m: int, n: int): + """Triton and reference Newton-Schulz should agree within tolerance. + + Both implementations use the same algorithm (same constants, same + iteration count) and always operate in bf16 internally. Small + differences arise from kernel-level reduction order. + """ + for G in [ + torch.randn(m, n, dtype=torch.float32, device="cuda"), + torch.randn(4, m, n, dtype=torch.float32, device="cuda"), + ]: + triton_out = newton_schulz_triton(G) + ref_out = zeropower_via_newtonschulz5(G) + diff = (triton_out - ref_out).abs().max().item() + # Empirically max diff is ~7.8e-3 across 50 runs; 0.02 gives ~2.5x headroom. + assert diff < 0.02, ( + f"Newton-Schulz implementations diverged: max diff {diff:.3e} " + f"(shape={tuple(G.shape)}, dtype={G.dtype})" + )