Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion cuda_core/cuda/core/_memory/_managed_memory_resource.pyx
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from cuda.bindings cimport cydriver

from cuda.core._memory._memory_pool cimport _MemPool, _MemPoolOptions
from cuda.core._utils.cuda_utils cimport (
HANDLE_RETURN,
check_or_create_options,
)

from dataclasses import dataclass
import threading
import warnings

__all__ = ['ManagedMemoryResource', 'ManagedMemoryResourceOptions']

Expand Down Expand Up @@ -91,6 +95,7 @@ cdef class ManagedMemoryResource(_MemPool):
opts_base._type = cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED

super().__init__(device_id, opts_base)
_check_concurrent_managed_access()
ELSE:
raise RuntimeError("ManagedMemoryResource requires CUDA 13.0 or later")

Expand All @@ -103,3 +108,47 @@ cdef class ManagedMemoryResource(_MemPool):
def is_host_accessible(self) -> bool:
"""Return True. This memory resource provides host-accessible buffers."""
return True


cdef bint _concurrent_access_warned = False
cdef object _concurrent_access_lock = threading.Lock()


cdef inline _check_concurrent_managed_access():
"""Warn once if the platform lacks concurrent managed memory access."""
global _concurrent_access_warned
if _concurrent_access_warned:
return

cdef int c_concurrent = 0
with _concurrent_access_lock:
if _concurrent_access_warned:
return

# concurrent_managed_access is a system-level attribute for sm_60 and
# later, so any device will do.
with nogil:
HANDLE_RETURN(cydriver.cuDeviceGetAttribute(
&c_concurrent,
cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS,
0))
if not c_concurrent:
warnings.warn(
"This platform does not support concurrent managed memory access "
"(Device.properties.concurrent_managed_access is False). Host access to any managed "
"allocation is forbidden while any GPU kernel is in flight, even "
"if the kernel does not touch that allocation. Failing to "
"synchronize before host access will cause a segfault. "
"See: https://docs.nvidia.com/cuda/cuda-c-programming-guide/"
"index.html#gpu-exclusive-access-to-managed-memory",
UserWarning,
stacklevel=3
)

_concurrent_access_warned = True


def reset_concurrent_access_warning():
"""Reset the concurrent access warning flag for testing purposes."""
global _concurrent_access_warned
_concurrent_access_warned = False
3 changes: 2 additions & 1 deletion cuda_core/tests/test_build_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@

import pytest

# build_hooks.py imports Cython at the top level, so skip if not available
# build_hooks.py imports Cython and setuptools at the top level, so skip if not available
pytest.importorskip("Cython")
pytest.importorskip("setuptools")


def _load_build_hooks():
Expand Down
83 changes: 83 additions & 0 deletions cuda_core/tests/test_managed_memory_warning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Test that a warning is emitted when ManagedMemoryResource is created on a
platform without concurrent managed memory access.

These tests only run on affected platforms (concurrent_managed_access is False).
"""

import warnings

import cuda.bindings
import pytest
from cuda.core import Device, ManagedMemoryResource, ManagedMemoryResourceOptions
from cuda.core._memory._managed_memory_resource import reset_concurrent_access_warning

_cuda_major = int(cuda.bindings.__version__.split(".")[0])

requires_cuda_13 = pytest.mark.skipif(
_cuda_major < 13,
reason="ManagedMemoryResource requires CUDA 13.0 or later",
)


def _make_managed_mr(device_id):
"""Create a ManagedMemoryResource with an explicit device preference."""
return ManagedMemoryResource(options=ManagedMemoryResourceOptions(preferred_location=device_id))


@pytest.fixture
def device_without_concurrent_managed_access(init_cuda):
"""Return a device that lacks concurrent managed access, or skip."""
device = Device()
device.set_current()

if not device.properties.memory_pools_supported:
pytest.skip("Device does not support memory pools")

if device.properties.concurrent_managed_access:
pytest.skip("Device supports concurrent managed access; warning not applicable")

return device


@requires_cuda_13
def test_warning_emitted(device_without_concurrent_managed_access):
"""ManagedMemoryResource emits a warning when concurrent managed access is unsupported."""
dev_id = device_without_concurrent_managed_access.device_id
reset_concurrent_access_warning()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
mr = _make_managed_mr(dev_id)

concurrent_warnings = [
warning for warning in w if "concurrent managed memory access" in str(warning.message).lower()
]
assert len(concurrent_warnings) == 1
assert concurrent_warnings[0].category is UserWarning
assert "segfault" in str(concurrent_warnings[0].message).lower()

mr.close()


@requires_cuda_13
def test_warning_emitted_only_once(device_without_concurrent_managed_access):
"""Warning fires only once even when multiple ManagedMemoryResources are created."""
dev_id = device_without_concurrent_managed_access.device_id
reset_concurrent_access_warning()

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
mr1 = _make_managed_mr(dev_id)
mr2 = _make_managed_mr(dev_id)

concurrent_warnings = [
warning for warning in w if "concurrent managed memory access" in str(warning.message).lower()
]
assert len(concurrent_warnings) == 1

mr1.close()
mr2.close()
Loading