17 changes: 14 additions & 3 deletions cuda_core/cuda/core/_memoryview.pyx
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

@@ -27,6 +27,12 @@ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN

from cuda.core._memory import Buffer


try:
    from ml_dtypes import bfloat16
except ImportError:
    bfloat16 = None

# TODO(leofang): support NumPy structured dtypes


@@ -555,8 +561,13 @@ cdef object dtype_dlpack_to_numpy(DLDataType* dtype):
        else:
            raise TypeError(f'{bits}-bit bool is not supported')
    elif dtype.code == kDLBfloat:
        # TODO(leofang): use ml_dtype.bfloat16?
        raise NotImplementedError('bfloat is not supported yet')
        if bfloat16 is not None:
            np_dtype = numpy.dtype("bfloat16")
        else:
            raise NotImplementedError(
                'Support for bfloat16 within cuda-core requires `ml_dtypes` '
                'to be installed.'
            )
    else:
        raise TypeError('Unsupported dtype. dtype code: {}'.format(dtype.code))

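For context on what the guarded import enables: once ml_dtypes has been imported, the "bfloat16" name is registered with NumPy, which is what the numpy.dtype("bfloat16") call above relies on. A minimal sketch, assuming ml_dtypes is installed:

import ml_dtypes
import numpy as np

# Importing ml_dtypes registers bfloat16 with NumPy's dtype registry.
bf16 = np.dtype("bfloat16")
assert bf16 == np.dtype(ml_dtypes.bfloat16)
assert bf16.itemsize == 2  # bfloat16 is a 16-bit (2-byte) type
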
2 changes: 1 addition & 1 deletion cuda_core/pyproject.toml
@@ -55,7 +55,7 @@ cu12 = ["cuda-bindings[all]==12.*"]
cu13 = ["cuda-bindings[all]==13.*"]

[dependency-groups]
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat"]
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "pytest-randomly", "pytest-repeat", "ml-dtypes"]
test-cu12 = ["cuda-core[test]", "cupy-cuda12x; python_version < '3.14'", "cuda-toolkit[cudart]==12.*"] # runtime headers needed by CuPy
test-cu13 = ["cuda-core[test]", "cupy-cuda13x; python_version < '3.14'", "cuda-toolkit[cudart]==13.*"] # runtime headers needed by CuPy
# free threaded build, cupy doesn't support free-threaded builds yet, so avoid installing it for now
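If these tests ever need to run in an environment where ml-dtypes is not guaranteed to be installed, pytest's importorskip is the standard guard; this is only a sketch of an alternative, not what this change does (the change instead adds ml-dtypes to the test dependency group):

import pytest

# Skips the whole test module at collection time when ml_dtypes is absent.
ml_dtypes = pytest.importorskip("ml_dtypes")
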
115 changes: 114 additions & 1 deletion cuda_core/tests/test_utils.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

@@ -12,7 +12,12 @@
    from numba import cuda as numba_cuda
except ImportError:
    numba_cuda = None
try:
    import torch
except ImportError:
    torch = None
import cuda.core
import ml_dtypes
import numpy as np
import pytest
from cuda.core import Device
@@ -524,3 +529,111 @@ def test_from_array_interface_unsupported_strides(init_cuda):
    with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
        # TODO: ideally this would raise on construction
        smv.strides  # noqa: B018


class _DLPackOnlyArray:
    def __init__(self, array):
        self.array = array

    def __dlpack__(self, stream=None, max_version=None):
        if max_version is None:
            return self.array.__dlpack__(stream=stream)
        return self.array.__dlpack__(stream=stream, max_version=max_version)

    def __dlpack_device__(self):
        return self.array.__dlpack_device__()

    @property
    def __cuda_array_interface__(self):
        raise AssertionError("from_any_interface should prefer DLPack when available")

@pytest.mark.parametrize(
    "slices",
    [
        param((slice(None), slice(None)), id="contiguous"),
        param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
    ],
)
def test_ml_dtypes_bfloat16_dlpack(init_cuda, slices):
    a = cp.array([1, 2, 3, 4, 5, 6], dtype=ml_dtypes.bfloat16).reshape(2, 3)[slices]
    smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)

    assert smv.size == a.size
    assert smv.dtype == np.dtype("bfloat16")
    assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
    assert smv.shape == a.shape
    assert smv.ptr == a.data.ptr
    assert smv.device_id == init_cuda.device_id
    assert smv.is_device_accessible is True
    assert smv.exporting_obj is a
    assert smv.readonly is a.__cuda_array_interface__["data"][1]

    strides_in_counts = convert_strides_to_counts(a.strides, a.dtype.itemsize)
    if a.flags["C_CONTIGUOUS"]:
        assert smv.strides in (None, strides_in_counts)
    else:
        assert smv.strides == strides_in_counts


def test_ml_dtypes_bfloat16_from_any_interface_prefers_dlpack(init_cuda):
    a = cp.array([1, 2, 3, 4, 5, 6], dtype="bfloat16")
    wrapped = _DLPackOnlyArray(a)
    smv = StridedMemoryView.from_any_interface(wrapped, stream_ptr=0)

    assert smv.dtype == np.dtype("bfloat16")
    assert smv.shape == a.shape
    assert smv.ptr == a.data.ptr
    assert smv.device_id == init_cuda.device_id
    assert smv.is_device_accessible is True
    assert smv.exporting_obj is wrapped


@pytest.mark.parametrize(
    "slices",
    [
        param((slice(None), slice(None)), id="contiguous"),
        param((slice(None, None, 2), slice(1, None, 2)), id="strided"),
    ],
)
@pytest.mark.skipif(torch is None, reason="PyTorch is not installed")
def test_ml_dtypes_bfloat16_torch_dlpack(init_cuda, slices):
[Inline review comment, Contributor Author]: This tests interop with a PyTorch tensor; it passes for me locally, but it would require adding torch to the test environment. Should we do that?
    a = torch.tensor([1, 2, 3, 4, 5, 6], dtype=torch.bfloat16, device="cuda").reshape(2, 3)[slices]
    smv = StridedMemoryView.from_dlpack(a, stream_ptr=0)

    assert smv.size == a.numel()
    assert smv.dtype == np.dtype("bfloat16")
    assert smv.dtype == np.dtype(ml_dtypes.bfloat16)
    assert smv.shape == tuple(a.shape)
    assert smv.ptr == a.data_ptr()
    assert smv.device_id == init_cuda.device_id
    assert smv.is_device_accessible is True
    assert smv.exporting_obj is a

    # PyTorch stride() returns strides in elements, convert to bytes first
    strides_in_bytes = tuple(s * a.element_size() for s in a.stride())
    strides_in_counts = convert_strides_to_counts(strides_in_bytes, a.element_size())
    if a.is_contiguous():
        assert smv.strides in (None, strides_in_counts)
    else:
        assert smv.strides == strides_in_counts


@pytest.fixture
def no_ml_dtypes(monkeypatch):
    monkeypatch.setattr("cuda.core._memoryview.bfloat16", None)
    yield


@pytest.mark.parametrize(
    "api",
    [
        param(StridedMemoryView.from_dlpack, id="from_dlpack"),
        param(StridedMemoryView.from_any_interface, id="from_any_interface"),
    ],
)
def test_ml_dtypes_bfloat16_dlpack_requires_ml_dtypes(init_cuda, no_ml_dtypes, api):
    a = cp.array([1, 2, 3], dtype="bfloat16")
    smv = api(a, stream_ptr=0)
    with pytest.raises(NotImplementedError, match=r"requires `ml_dtypes`"):
        smv.dtype  # noqa: B018