diff --git a/gpustack_runtime/detector/nvidia.py b/gpustack_runtime/detector/nvidia.py index 3e4ed71..a896102 100644 --- a/gpustack_runtime/detector/nvidia.py +++ b/gpustack_runtime/detector/nvidia.py @@ -4,6 +4,7 @@ import logging import math import re +import threading import time from _ctypes import byref from functools import lru_cache @@ -75,10 +76,29 @@ def detect_pci_devices() -> dict[str, PCIDevice]: return {} return {dev.address: dev for dev in pci_devs} + _detect_lock = threading.Lock() + def __init__(self): super().__init__(ManufacturerEnum.NVIDIA) - def detect(self) -> Devices | None: # noqa: PLR0915 + def detect(self) -> Devices | None: + """ + Detect NVIDIA GPUs using pynvml with thread-safe locking. + + Returns: + A list of detected NVIDIA GPU devices, + or None if not supported. + + Raises: + If there is an error during detection. + + """ + with self._detect_lock: + result = self._detect_impl() + + return result + + def _detect_impl(self) -> Devices | None: # noqa: PLR0915 """ Detect NVIDIA GPUs using pynvml. @@ -97,7 +117,6 @@ def detect(self) -> Devices | None: # noqa: PLR0915 try: pci_devs = NVIDIADetector.detect_pci_devices() - pynvml.nvmlInit() if not envs.GPUSTACK_RUNTIME_DETECT_NO_TOOLKIT_CALL: try: