From d63f5e6538271f0a7f7cb3b248e8d92c116dd9ce Mon Sep 17 00:00:00 2001 From: Sagun Bajra Date: Wed, 8 Apr 2026 09:24:57 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 896542394 --- .../_src/python/dataset/transformations/BUILD | 1 - .../dataset/transformations/interleave.py | 2 ++ .../transformations/interleave_test.py | 24 ++++++++++++++++++- .../dataset/transformations/prefetch.py | 5 +++- .../dataset/transformations/prefetch_test.py | 7 +++++- 5 files changed, 35 insertions(+), 4 deletions(-) diff --git a/grain/_src/python/dataset/transformations/BUILD b/grain/_src/python/dataset/transformations/BUILD index aa832cc05..6cfa17703 100644 --- a/grain/_src/python/dataset/transformations/BUILD +++ b/grain/_src/python/dataset/transformations/BUILD @@ -64,7 +64,6 @@ py_test( shard_count = 50, srcs_version = "PY3", deps = [ - "//grain/_src/core:config", "//grain/_src/core:transforms", "//grain/_src/python:options", "//grain/_src/python/dataset", diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py index a3a299d12..a87ede04b 100644 --- a/grain/_src/python/dataset/transformations/interleave.py +++ b/grain/_src/python/dataset/transformations/interleave.py @@ -369,9 +369,11 @@ def _add_prefetch_and_make_iterator( ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access ) iterator = iter_dataset.__iter__() + # Propagate options applied after InterleaveIterDataset to the iterators that # are being interleaved. iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access + if start_prefetch: iterator.start_prefetch() return iterator diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py index a5ed001a8..27eeb8929 100644 --- a/grain/_src/python/dataset/transformations/interleave_test.py +++ b/grain/_src/python/dataset/transformations/interleave_test.py @@ -90,7 +90,11 @@ def __iter__(self) -> dataset.DatasetIterator: return _IteratorIdDatasetIterator(self._parent.__iter__()) -class InterleaveIterDatasetTest(parameterized.TestCase): +@absltest.skipThisClass("Base class") +class _InterleaveIterDatasetTestBase(parameterized.TestCase): + + def _maybe_wrap_ds(self, ds): + return ds @parameterized.named_parameters(*_INTERLEAVE_TEST_CASES) def test_interleaved_mix(self, to_mix, cycle_length, expected): @@ -99,6 +103,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected): for elements in to_mix ] ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length) + ds = self._maybe_wrap_ds(ds) self.assertEqual(list(ds), expected) # Sanity check. flat_inputs = [] @@ -113,6 +118,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected): for elements in to_mix ] ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() checkpoints = {} for i in range(len(expected)): @@ -138,6 +144,7 @@ def test_checkpoint_with_extra_threads_creating_iterators( num_make_iter_threads=10, make_iter_buffer_size=10, ) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() checkpoints = {} for i in range(len(expected)): @@ -158,6 +165,7 @@ def make_dummy_source(filename): filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"]) sources = filenames.shuffle(seed=42).map(make_dummy_source) ds = interleave.InterleaveIterDataset(sources, cycle_length=2) + ds = self._maybe_wrap_ds(ds) self.assertEqual( list(ds), ["1", "2", "1", "3", "4", "6", "5", "7", "8", "9", "9", "9", "9"], @@ -168,6 +176,7 @@ def test_with_mp_prefetch(self): lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset() ) ds = interleave.InterleaveIterDataset(ds, cycle_length=5) + ds = self._maybe_wrap_ds(ds) ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3)) self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5]) @@ -176,6 +185,7 @@ def test_options_propagated(self): ds1 = ds1.filter(lambda x: False) ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset() ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1) + ds = self._maybe_wrap_ds(ds) ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1) ds = dataset.WithOptionsIterDataset(ds, ds_options) with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"): @@ -187,6 +197,7 @@ def test_checkpointing_comprehensive(self): for i in range(1, 6) ] ds = interleave.InterleaveIterDataset(ds, cycle_length=5) + ds = self._maybe_wrap_ds(ds) assert_equal_output_after_checkpoint(ds) def test_set_state_does_not_recreate_iterators_if_not_needed(self): @@ -196,6 +207,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self): ds = interleave.InterleaveIterDataset( [ds] * cycle_length, cycle_length=cycle_length ) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() iter_ids1 = [] for _ in range(cycle_length): @@ -211,6 +223,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self): def test_element_spec(self): ds = dataset.MapDataset.range(3).to_iter_dataset() ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2) + ds = self._maybe_wrap_ds(ds) spec = dataset.get_element_spec(ds) self.assertEqual(spec.dtype, np.int64) self.assertEqual(spec.shape, ()) @@ -255,6 +268,7 @@ def test_interleave_stats_with_mismatched_dataset_structures(self): def test_get_next_index(self): ds = dataset.MapDataset.range(10).to_iter_dataset() ds = interleave.InterleaveIterDataset([ds], cycle_length=1) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() self.assertEqual(dataset.get_next_index(ds_iter), 0) for i in range(10): @@ -264,6 +278,7 @@ def test_get_next_index(self): def test_set_next_index(self): ds = dataset.MapDataset.range(10).to_iter_dataset() ds = interleave.InterleaveIterDataset([ds], cycle_length=1) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() for i in reversed(range(10)): dataset.set_next_index(ds_iter, i) @@ -272,6 +287,7 @@ def test_set_next_index(self): def test_get_next_index_with_multiple_datasets(self): ds = dataset.MapDataset.range(10).to_iter_dataset() ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() with self.assertRaisesRegex( NotImplementedError, @@ -283,6 +299,7 @@ def test_get_next_index_with_multiple_datasets(self): def test_set_next_index_with_multiple_datasets(self): ds = dataset.MapDataset.range(10).to_iter_dataset() ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() with self.assertRaisesRegex( NotImplementedError, @@ -297,6 +314,7 @@ def test_future_states(self): dataset.MapDataset.source([3, 4]).to_iter_dataset(), ] ds = interleave.InterleaveIterDataset(datasets, cycle_length=1) + ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() # Initialize the first iterator and get state. @@ -323,5 +341,9 @@ def test_future_states(self): next(ds_iter) +class InterleaveIterDatasetTest(_InterleaveIterDatasetTestBase): + """Runs tests without prefetch.""" + + if __name__ == "__main__": absltest.main() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index a9cde4843..0344e3cb7 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -534,6 +534,7 @@ def start_prefetch(self): return self._prefetch_should_stop.clear() + self._prefetch_thread = threading.Thread( target=functools.partial( _put_iterator_elements_in_buffer, @@ -595,8 +596,10 @@ def _stop_prefetch(self): # is shutting down. Attempting to join can lead to hanging in Python # 3.13 as daemon threads can hang during interpreter shutdown. See # https://github.com/python/cpython/issues/123940#issuecomment-2976446309 - self._prefetch_thread.join() + if self._prefetch_thread is not None: + self._prefetch_thread.join() self._prefetch_thread = None + # Clear the buffer again in case the prefetch loop added more elements on # exit. self._clear_buffer() diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 80ae0a723..4f48fb18c 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -448,7 +448,8 @@ def test_set_next_index(self): self.assertEqual(next(ds_iter), i) -class ThreadPrefetchIterDatasetTest(parameterized.TestCase): +@absltest.skipThisClass('Base class') +class _ThreadPrefetchIterDatasetTestBase(parameterized.TestCase): def setUp(self): super().setUp() @@ -768,6 +769,10 @@ def new_get_state(self): self.assertEqual(get_state_counter.call_count - get_state_count, 1) +class ThreadPrefetchIterDatasetTest(_ThreadPrefetchIterDatasetTestBase): + """Runs tests without provided executor.""" + + class _MpContextCheckIterDataset(dataset.IterDataset[_T]): def __iter__(self) -> dataset.DatasetIterator[_T]: