Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cuda_core/tests/example_tests/test_basic_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@

# If we have subcategories of examples in the future, this file can be split along those lines

import glob
import os
from pathlib import Path

import pytest
from cuda.core import Device

from .utils import run_example

samples_path = os.path.join(os.path.dirname(__file__), "..", "..", "examples")
sample_files = glob.glob(samples_path + "**/*.py", recursive=True)
# not dividing, but navigating into the "examples" directory.
EXAMPLES_DIR = Path(__file__).resolve().parent.parent.parent / "examples"

# recursively glob for test files in examples directory, sort for deterministic
# test runs. Relative paths offer cleaner output when tests fail.
SAMPLE_FILES = sorted([str(p.relative_to(EXAMPLES_DIR)) for p in EXAMPLES_DIR.glob("**/*.py")])

@pytest.mark.parametrize("example", sample_files)

@pytest.mark.parametrize("example_rel_path", SAMPLE_FILES)
class TestExamples:
def test_example(self, example, deinit_cuda):
run_example(samples_path, example)
if Device().device_id != 0:
Device(0).set_current()
# deinit_cuda is defined in conftest.py and pops the cuda context automatically.
def test_example(self, example_rel_path: str, deinit_cuda) -> None:
run_example(str(EXAMPLES_DIR), example_rel_path)
53 changes: 35 additions & 18 deletions cuda_core/tests/example_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import gc
import os
import importlib.util
import sys
from pathlib import Path

import pytest

Expand All @@ -12,24 +13,38 @@ class SampleTestError(Exception):
pass


def parse_python_script(filepath):
if not filepath.endswith(".py"):
raise ValueError(f"{filepath} not supported")
with open(filepath, encoding="utf-8") as f:
script = f.read()
return script
def run_example(parent_dir: str, rel_path_to_example: str, env=None) -> None:
fullpath = Path(parent_dir) / rel_path_to_example
module_name = fullpath.stem

old_sys_path = sys.path.copy()
old_argv = sys.argv

def run_example(samples_path, filename, env=None):
fullpath = os.path.join(samples_path, filename)
script = parse_python_script(fullpath)
try:
old_argv = sys.argv
sys.argv = [fullpath]
old_sys_path = sys.path.copy()
sys.path.append(samples_path)
# TODO: Refactor the examples to give them a common callable `main()` to avoid needing to use exec here?
exec(script, env if env else {}) # noqa: S102
sys.path.append(parent_dir)
sys.argv = [str(fullpath)]

# Collect metadata for file 'module_name' located at 'fullpath'.
# CASE: file does not exist -> spec is none.
# CASE: file is not .py -> spec is none.
# CASE: file does not have proper loader (module.spec.__loader__) -> spec.loader is none.
spec = importlib.util.spec_from_file_location(module_name, fullpath)

if spec is None or spec.loader is None:
raise ImportError(f"Failed to load spec for {rel_path_to_example}")

# Otherwise convert the spec to a module, then run the module.
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module

# This runs top-level code.
# CASE: exec() -> top-level code is implicitly run.
spec.loader.exec_module(module)

# CASE: main() -> we find main and call it below.
if hasattr(module, "main"):
module.main()

except ImportError as e:
# for samples requiring any of optional dependencies
for m in ("cupy", "torch"):
Expand All @@ -40,14 +55,16 @@ def run_example(samples_path, filename, env=None):
raise
except SystemExit:
# for samples that early return due to any missing requirements
pytest.skip(f"skip {filename}")
pytest.skip(f"skip {rel_path_to_example}")
except Exception as e:
msg = "\n"
msg += f"Got error ({filename}):\n"
msg += f"Got error ({rel_path_to_example}):\n"
msg += str(e)
raise SampleTestError(msg) from e
finally:
sys.path = old_sys_path
sys.argv = old_argv

# further reduce the memory watermark
sys.modules.pop(module_name, None)
gc.collect()