diff --git a/cuda_core/cuda/core/_memoryview.pyx b/cuda_core/cuda/core/_memoryview.pyx
index e6ad1dd7e9..c68d33b3fd 100644
--- a/cuda_core/cuda/core/_memoryview.pyx
+++ b/cuda_core/cuda/core/_memoryview.pyx
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # SPDX-License-Identifier: Apache-2.0
 
@@ -27,6 +27,12 @@
 from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
 
 from cuda.core._memory import Buffer
 
+
+try:
+    from ml_dtypes import bfloat16
+except ImportError:
+    bfloat16 = None
+
 
 # TODO(leofang): support NumPy structured dtypes
 
@@ -555,8 +561,13 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
         else:
             raise TypeError(f'{bits}-bit bool is not supported')
     elif dtype.code == kDLBfloat:
-        # TODO(leofang): use ml_dtype.bfloat16?
-        raise NotImplementedError('bfloat is not supported yet')
+        if bfloat16 is not None:
+            np_dtype = numpy.dtype("bfloat16")
+        else:
+            raise NotImplementedError(
+                'Support for bfloat16 within cuda-core requires `ml_dtypes` '
+                'to be installed.'
+            )
     else:
         raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code))
 
diff --git a/cuda_core/pyproject.toml b/cuda_core/pyproject.toml
index 1c295b04ce..e73ca937f0 100644
--- a/cuda_core/pyproject.toml
+++ b/cuda_core/pyproject.toml
@@ -55,7 +55,7 @@ cu12 = ["cuda-bindings[all]==12.*"]
 cu13 = ["cuda-bindings[all]==13.*"]
 
 [dependency-groups]
-test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat"]
+test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat", "ml-dtypes"]
 test-cu12 = ["cuda-core[test]", "cupy-cuda12x; python_version < '3.14'", "cuda-toolkit[cudart]==12.*"] # runtime headers needed by CuPy
 test-cu13 = ["cuda-core[test]", "cupy-cuda13x; python_version < '3.14'", "cuda-toolkit[cudart]==13.*"] # runtime headers needed by CuPy
 # free threaded build, cupy doesn't support free-threaded builds yet, so avoid installing it for now
diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py
index dd9c52e817..3b91981436 100644
--- a/cuda_core/tests/test_utils.py
+++ b/cuda_core/tests/test_utils.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
 
@@ -12,7 +12,12 @@
     from numba import cuda as numba_cuda
 except ImportError:
     numba_cuda = None
+try:
+    import torch
+except ImportError:
+    torch = None
 import cuda.core
+import ml_dtypes
 import numpy as np
 import pytest
 from cuda.core import Device
@@ -524,3 +529,111 @@ def test_from_array_interface_unsupported_strides(init_cuda):
     with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
         # TODO: ideally this would raise on construction
         smv.strides  # noqa: B018
+
+
+class _DLPackOnlyArray:
+    def __init__(self, array):
+        self.array = array
+
+    def __dlpack__(self, stream=None, max_version=None):
+        if max_version is None:
+            return self.array.__dlpack__(stream=stream)
+        return self.array.__dlpack__(stream=stream, max_version=max_version)
+
+    def __dlpack_device__(self):
+        return self.array.__dlpack_device__()
+
+    @property
+    def __cuda_array_interface__(self):
+        raise AssertionError("from_any_interface should prefer DLPack when available")
+
+
+@pytest.mark.parametrize(
+    "slices",
+    [
+        param((slice(None), slice(None)), id="contiguous"),
+        param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
+    ],
+)
+def test_ml_dtypes_bfloat16_dlpack(init_cuda, slices):
+    a = cp.array([1, 2, 3, 4, 5, 6], dtype=ml_dtypes.bfloat16).reshape(2, 3)[slices]
+    smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)
+
+    assert smv.size == a.size
+    assert smv.dtype == np.dtype("bfloat16")
+    assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
+    assert smv.shape == a.shape
+    assert smv.ptr == a.data.ptr
+    assert smv.device_id == init_cuda.device_id
+    assert smv.is_device_accessible is True
+    assert smv.exporting_obj is a
+    assert smv.readonly is a.__cuda_array_interface__["data"][1]
+
+    strides_in_counts = convert_strides_to_counts(a.strides, a.dtype.itemsize)
+    if a.flags["C_CONTIGUOUS"]:
+        assert smv.strides in (None, strides_in_counts)
+    else:
+        assert smv.strides == strides_in_counts
+
+
+def test_ml_dtypes_bfloat16_from_any_interface_prefers_dlpack(init_cuda):
+    a = cp.array([1, 2, 3, 4, 5, 6], dtype="bfloat16")
+    wrapped = _DLPackOnlyArray(a)
+    smv = StridedMemoryView.from_any_interface(wrapped, stream_ptr=0)
+
+    assert smv.dtype == np.dtype("bfloat16")
+    assert smv.shape == a.shape
+    assert smv.ptr == a.data.ptr
+    assert smv.device_id == init_cuda.device_id
+    assert smv.is_device_accessible is True
+    assert smv.exporting_obj is wrapped
+
+
+@pytest.mark.parametrize(
+    "slices",
+    [
+        param((slice(None), slice(None)), id="contiguous"),
+        param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
+    ],
+)
+@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
+def test_ml_dtypes_bfloat16_torch_dlpack(init_cuda, slices):
+    a = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.bfloat16, device="cuda").reshape(2, 3)[slices]
+    smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)
+
+    assert smv.size == a.numel()
+    assert smv.dtype == np.dtype("bfloat16")
+    assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
+    assert smv.shape == tuple(a.shape)
+    assert smv.ptr == a.data_ptr()
+    assert smv.device_id == init_cuda.device_id
+    assert smv.is_device_accessible is True
+    assert smv.exporting_obj is a
+
+    # PyTorch stride() returns strides in elements, convert to bytes first
+    strides_in_bytes = tuple(s * a.element_size() for s in a.stride())
+    strides_in_counts = convert_strides_to_counts(strides_in_bytes, a.element_size())
+    if a.is_contiguous():
+        assert smv.strides in (None, strides_in_counts)
+    else:
+        assert smv.strides == strides_in_counts
+
+
+@pytest.fixture
+def no_ml_dtypes(monkeypatch):
+    monkeypatch.setattr("cuda.core._memoryview.bfloat16", None)
+    yield
+
+
+@pytest.mark.parametrize(
+    "api",
+    [
+        param(StridedMemoryView.from_dlpack, id="from_dlpack"),
+        param(StridedMemoryView.from_any_interface, id="from_any_interface"),
+    ],
+)
+def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, api):
+    a = cp.array([1, 2, 3], dtype="bfloat16")
+    smv = api(a, stream_ptr=0)
+    with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"):
+        smv.dtype  # noqa: B018
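
Below is a minimal usage sketch of what this change enables, mirroring the new tests above. It assumes a CUDA-capable device with cupy and ml_dtypes installed; the import path for StridedMemoryView (cuda.core.utils) is an assumption here, since it does not appear in the diff:

    import cupy as cp
    import ml_dtypes
    import numpy as np
    from cuda.core.utils import StridedMemoryView

    # A small bfloat16 array on the GPU, exported to cuda.core via DLPack.
    a = cp.array([1, 2, 3, 4, 5, 6], dtype=ml_dtypes.bfloat16).reshape(2, 3)
    smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)

    # With ml_dtypes installed, the kDLBfloat type code now maps to the
    # registered bfloat16 NumPy dtype instead of raising NotImplementedError.
    assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
    assert smv.shape == (2, 3)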