diff --git a/cuda_core/tests/example_tests/test_basic_examples.py b/cuda_core/tests/example_tests/test_basic_examples.py index 640b53c2fc..413e8a1a61 100644 --- a/cuda_core/tests/example_tests/test_basic_examples.py +++ b/cuda_core/tests/example_tests/test_basic_examples.py @@ -3,21 +3,22 @@ # If we have subcategories of examples in the future, this file can be split along those lines -import glob -import os +from pathlib import Path import pytest -from cuda.core import Device from .utils import run_example -samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples") -sample_files = glob.glob(samples_path + "**/*.py", recursive=True) +# not dividing, but navigating into the "examples" directory. +EXAMPLES_DIR = Path(__file__).resolve().parent.parent.parent / "examples" +# recursively glob for test files in examples directory, sort for deterministic +# test runs. Relative paths offer cleaner output when tests fail. +SAMPLE_FILES = sorted([str(p.relative_to(EXAMPLES_DIR)) for p in EXAMPLES_DIR.glob("**/*.py")]) -@pytest.mark.parametrize("example", sample_files) + +@pytest.mark.parametrize("example_rel_path", SAMPLE_FILES) class TestExamples: - def test_example(self, example, deinit_cuda): - run_example(samples_path, example) - if Device().device_id != 0: - Device(0).set_current() + # deinit_cuda is defined in conftest.py and pops the cuda context automatically. + def test_example(self, example_rel_path: str, deinit_cuda) -> None: + run_example(str(EXAMPLES_DIR), example_rel_path) diff --git a/cuda_core/tests/example_tests/utils.py b/cuda_core/tests/example_tests/utils.py index 9b5dc57e5f..1f2b35d551 100644 --- a/cuda_core/tests/example_tests/utils.py +++ b/cuda_core/tests/example_tests/utils.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import gc -import os +import importlib.util import sys +from pathlib import Path import pytest @@ -12,24 +13,38 @@ class SampleTestError(Exception): pass -def parse_python_script(filepath): - if not filepath.endswith(".py"): - raise ValueError(f"{filepath} not supported") - with open(filepath, encoding="utf-8") as f: - script = f.read() - return script +def run_example(parent_dir: str, rel_path_to_example: str, env=None) -> None: + fullpath = Path(parent_dir) / rel_path_to_example + module_name = fullpath.stem + old_sys_path = sys.path.copy() + old_argv = sys.argv -def run_example(samples_path, filename, env=None): - fullpath = os.path.join(samples_path, filename) - script = parse_python_script(fullpath) try: - old_argv = sys.argv - sys.argv = [fullpath] - old_sys_path = sys.path.copy() - sys.path.append(samples_path) - # TODO: Refactor the examples to give them a common callable `main()` to avoid needing to use exec here? - exec(script, env if env else {}) # noqa: S102 + sys.path.append(parent_dir) + sys.argv = [str(fullpath)] + + # Collect metadata for file 'module_name' located at 'fullpath'. + # CASE: file does not exist -> spec is none. + # CASE: file is not .py -> spec is none. + # CASE: file does not have proper loader (module.spec.__loader__) -> spec.loader is none. + spec = importlib.util.spec_from_file_location(module_name, fullpath) + + if spec is None or spec.loader is None: + raise ImportError(f"Failed to load spec for {rel_path_to_example}") + + # Otherwise convert the spec to a module, then run the module. + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + # This runs top-level code. + # CASE: exec() -> top-level code is implicitly run. + spec.loader.exec_module(module) + + # CASE: main() -> we find main and call it below. + if hasattr(module, "main"): + module.main() + except ImportError as e: # for samples requiring any of optional dependencies for m in ("cupy", "torch"): @@ -40,14 +55,16 @@ def run_example(samples_path, filename, env=None): raise except SystemExit: # for samples that early return due to any missing requirements - pytest.skip(f"skip {filename}") + pytest.skip(f"skip {rel_path_to_example}") except Exception as e: msg = "\n" - msg += f"Got error ({filename}):\n" + msg += f"Got error ({rel_path_to_example}):\n" msg += str(e) raise SampleTestError(msg) from e finally: sys.path = old_sys_path sys.argv = old_argv + # further reduce the memory watermark + sys.modules.pop(module_name, None) gc.collect()