From 086b1a76a0fceb86a26a8cd607d2f849c70e2d72 Mon Sep 17 00:00:00 2001 From: Grain Team Date: Wed, 8 Apr 2026 16:54:29 -0700 Subject: [PATCH] Google internal changes only. PiperOrigin-RevId: 896773621 --- grain/_src/python/BUILD | 1 + grain/_src/python/data_loader_test.py | 10 ++++++++-- .../python/dataset/transformations/process_prefetch.py | 3 +-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 4fe4f017b..14dec1f15 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -175,6 +175,7 @@ py_test( ":operations", ":options", ":samplers", + "//grain", "//grain/_src/core:sharding", "//grain/_src/core:transforms", "//grain/_src/python/dataset", diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index 7598ebb96..7a947e7f5 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -15,9 +15,11 @@ from collections.abc import Sequence import functools +from multiprocessing import shared_memory import pathlib import platform import sys +import threading from typing import Any, Union from unittest import mock @@ -30,17 +32,17 @@ from grain._src.python import data_loader as data_loader_lib from grain._src.python import options from grain._src.python import samplers -# pylint: disable=g-importing-member from grain._src.python.data_sources import ArrayRecordDataSource from grain._src.python.data_sources import RangeDataSource from grain._src.python.data_sources import SharedMemoryDataSource from grain._src.python.dataset.transformations import batch +from grain._src.python.dataset.transformations import process_prefetch +from grain._src.python.dataset.transformations import source from grain._src.python.ipc import shared_memory_array from grain._src.python.operations import BatchOperation from grain._src.python.operations import FilterOperation from grain._src.python.operations import MapOperation from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint -# pylint: enable=g-importing-member import numpy as np import parameterized @@ -186,6 +188,10 @@ def test_copy_skipped_flags_c_contiguous(self): {"num_threads_per_worker": 15}, ]) class DataLoaderTest(absl_parameterized.TestCase): + + def tearDown(self): + super().tearDown() + # Number of prefetch threads for each Grain worker num_threads_per_worker: int | None diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index 3b033efcc..8663e313d 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -28,7 +28,6 @@ from absl import flags import cloudpickle -from grain._src.core import monitoring as grain_monitoring from grain._src.core.config import config import multiprocessing as mp from grain._src.python import grain_logging @@ -379,7 +378,7 @@ def _stats(self): ) # pytype: enable=attribute-error - # pylint: enable=protected-access + # pylint: disable=protected-access def start_prefetch(self) -> None: """Starts prefetching elements in background.