diff --git a/grain/_src/python/dataset/transformations/interleave.py b/grain/_src/python/dataset/transformations/interleave.py
index a87ede04b..d7f786f29 100644
--- a/grain/_src/python/dataset/transformations/interleave.py
+++ b/grain/_src/python/dataset/transformations/interleave.py
@@ -13,11 +13,14 @@
 # limitations under the License.
 """Implements dataset interleaving."""
 
+import collections
 from collections.abc import Sequence
+from concurrent import futures
+import copy
 import functools
 from typing import Any, TypeVar
 import weakref
 
 from grain._src.python import options as grain_options
 from grain._src.python.dataset import base
 from grain._src.python.dataset import dataset
@@ -296,6 +299,11 @@ def close(self) -> None:
   def _initialize_stats(
       self, execution_tracking_mode: base.ExecutionTrackingMode
   ) -> stats.Stats:
+    # We pass an empty list of parents to `stats.make_stats` below. The
+    # parents of InterleaveDatasetIterator are the iterators of the
+    # datasets being interleaved. These are dynamically created and tracked
+    # in `self._iterators_in_use`. When an iterator produces an element in
+    # `__next__`, its Stats object is added to `self._stats._parents`.
     config = stats.StatsConfig(
         name=str(self),
         transform_mutates_spec=self._MUTATES_ELEMENT_SPEC,
@@ -337,8 +345,9 @@ def _get_iterator_start_state(self, index: int) -> dict[str, Any]:
 
 
 def _add_prefetch_and_make_iterator(
     ds: dataset.IterDataset[T] | dataset.MapDataset[T],
-    interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]],
+    interleave_iterator: weakref.ref[dataset.DatasetIterator[T]],
     start_prefetch: bool,
+    starting_state: dict[str, Any] | None = None,
 ) -> dataset.DatasetIterator[T]:
   """Adds prefetching to an IterDataset and returns an iterator.
@@ -350,6 +359,7 @@
     ds: The dataset to create an iterator from.
     interleave_iterator: The `InterleaveDatasetIterator` instance.
     start_prefetch: Whether to start the prefetching on iterator creation.
+    starting_state: If set, the iterator state to restore after creation.
 
   Returns:
     A `dataset.DatasetIterator` for the given dataset, with prefetching
@@ -358,24 +368,37 @@
   Raises:
     RuntimeError: If the interleave_iterator has been garbage collected.
   """
+  # pylint: disable=protected-access
   interleave_iterator_obj = interleave_iterator()
   if interleave_iterator_obj is None:
     raise RuntimeError("InterleaveDatasetIterator has been garbage collected.")
+  assert isinstance(interleave_iterator_obj, InterleaveDatasetIterator)
+  iter_buffer_size = interleave_iterator_obj._iter_buffer_size
+  ctx = interleave_iterator_obj._ctx
+  # Release the strong reference before potentially slow iterator creation.
+  # This prevents the worker thread from triggering parent destruction.
+  del interleave_iterator_obj
+
   if isinstance(ds, dataset.MapDataset):
     # Prefetch is automatically added in `MapDataset.__iter__`.
     iter_dataset = ds.to_iter_dataset()
   else:
     iter_dataset = prefetch.ThreadPrefetchIterDataset(
-        ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size  # pylint: disable=protected-access
+        ds, prefetch_buffer_size=iter_buffer_size
     )
   iterator = iter_dataset.__iter__()
   # Propagate options applied after InterleaveIterDataset to the iterators that
   # are being interleaved.
-  iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options)  # pylint: disable=protected-access
+  iterator._ctx.dataset_options = ctx.dataset_options.merge(
+      iterator._ctx.dataset_options
+  )
   if start_prefetch:
     iterator.start_prefetch()
+  if starting_state is not None:
+    iterator.set_state(starting_state)
+  # pylint: enable=protected-access
   return iterator
@@ -409,7 +432,7 @@ def __init__(
       self,
       datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
       *,
-      cycle_length: int,
+      cycle_length: int | grain_options.AutotuneParameter,
       num_make_iter_threads: int = 1,
       make_iter_buffer_size: int = 1,
       iter_buffer_size: int = 1,
diff --git a/grain/_src/python/dataset/transformations/interleave_test.py b/grain/_src/python/dataset/transformations/interleave_test.py
index 27eeb8929..f2b49ae2e 100644
--- a/grain/_src/python/dataset/transformations/interleave_test.py
+++ b/grain/_src/python/dataset/transformations/interleave_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
 from absl.testing import absltest
 from absl.testing import flagsaver
 from absl.testing import parameterized
@@ -93,6 +94,9 @@ def __iter__(self) -> dataset.DatasetIterator:
 
 @absltest.skipThisClass("Base class")
 class _InterleaveIterDatasetTestBase(parameterized.TestCase):
 
+  def _create_dataset(self, *args, **kwargs):
+    return interleave.InterleaveIterDataset(*args, **kwargs)
+
   def _maybe_wrap_ds(self, ds):
     return ds
@@ -102,7 +106,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected):
         dataset.MapDataset.source(elements).to_iter_dataset()
         for elements in to_mix
     ]
-    ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
+    ds = self._create_dataset(datasets, cycle_length=cycle_length)
     ds = self._maybe_wrap_ds(ds)
     self.assertEqual(list(ds), expected)
     # Sanity check.
@@ -117,7 +121,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected):
         dataset.MapDataset.source(elements).to_iter_dataset()
         for elements in to_mix
     ]
-    ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
+    ds = self._create_dataset(datasets, cycle_length=cycle_length)
     ds = self._maybe_wrap_ds(ds)
     ds_iter = ds.__iter__()
     checkpoints = {}
@@ -138,7 +142,7 @@ def test_checkpoint_with_extra_threads_creating_iterators(
         dataset.MapDataset.source(elements).to_iter_dataset()
         for elements in to_mix
     ]
-    ds = interleave.InterleaveIterDataset(
+    ds = self._create_dataset(
         datasets,
         cycle_length=cycle_length,
         num_make_iter_threads=10,
@@ -164,7 +168,7 @@ def make_dummy_source(filename):
 
     filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
     sources = filenames.shuffle(seed=42).map(make_dummy_source)
-    ds = interleave.InterleaveIterDataset(sources, cycle_length=2)
+    ds = self._create_dataset(sources, cycle_length=2)
     ds = self._maybe_wrap_ds(ds)
     self.assertEqual(
         list(ds),
@@ -175,7 +179,7 @@ def test_with_mp_prefetch(self):
     ds = dataset.MapDataset.range(1, 6).map(
         lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
     )
-    ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
+    ds = self._create_dataset(ds, cycle_length=5)
     ds = self._maybe_wrap_ds(ds)
     ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
     self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])
@@ -184,7 +188,7 @@ def test_options_propagated(self):
     ds1 = dataset.MapDataset.source([1]).repeat(1000).to_iter_dataset()
     ds1 = ds1.filter(lambda x: False)
     ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset()
-    ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1)
+    ds = self._create_dataset([ds1, ds2], cycle_length=1)
     ds = self._maybe_wrap_ds(ds)
     ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
     ds = dataset.WithOptionsIterDataset(ds, ds_options)
@@ -196,7 +200,7 @@ def test_checkpointing_comprehensive(self):
         dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
         for i in range(1, 6)
     ]
-    ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
+    ds = self._create_dataset(ds, cycle_length=5)
     ds = self._maybe_wrap_ds(ds)
     assert_equal_output_after_checkpoint(ds)
 
@@ -204,9 +208,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
     cycle_length = 5
     ds = dataset.MapDataset.range(100).to_iter_dataset()
     ds = _IteratorIdIterDataset(ds)
-    ds = interleave.InterleaveIterDataset(
-        [ds] * cycle_length, cycle_length=cycle_length
-    )
+    ds = self._create_dataset([ds] * cycle_length, cycle_length=cycle_length)
     ds = self._maybe_wrap_ds(ds)
     ds_iter = ds.__iter__()
     iter_ids1 = []
@@ -222,7 +224,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
 
   def test_element_spec(self):
     ds = dataset.MapDataset.range(3).to_iter_dataset()
-    ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
+    ds = self._create_dataset([ds, ds], cycle_length=2)
     ds = self._maybe_wrap_ds(ds)
     spec = dataset.get_element_spec(ds)
     self.assertEqual(spec.dtype, np.int64)
@@ -232,7 +234,7 @@ def test_interleave_stats(self):
 
     ds = dataset.MapDataset.range(10000).map(lambda x: x + 1)
     ds = ds.to_iter_dataset()
-    ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
+    ds = self._create_dataset([ds, ds], cycle_length=2)
     it = ds.__iter__()
     next(it)
     next(it)
@@ -243,12 +245,11 @@ def test_interleave_stats(self):
         "MapMapDataset",
         "PrefetchDatasetIterator",
"ThreadPrefetchDatasetIterator", - "InterleaveDatasetIterator", + "Interleave", ] for expected_node in expected_nodes: self.assertTrue(any(expected_node in name for name in node_names)) self.assertLen(node_names, len(expected_nodes)) - print(summary) @flagsaver.flagsaver(grain_py_debug_mode=True) def test_interleave_stats_with_mismatched_dataset_structures(self): @@ -256,14 +257,14 @@ def test_interleave_stats_with_mismatched_dataset_structures(self): ds1 = ds1.to_iter_dataset() ds2 = dataset.MapDataset.range(10000).map(lambda x: x + 1).map(lambda x: x) ds2 = ds2.to_iter_dataset() - ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=2) + ds = self._create_dataset([ds1, ds2], cycle_length=2) it = ds.__iter__() next(it) next(it) summary = dataset.get_execution_summary(it) node_names = [node.name for node in summary.nodes.values()] self.assertLen(node_names, 1) - self.assertIn("InterleaveDatasetIterator", node_names[0]) + self.assertIn("Interleave", node_names[0]) def test_get_next_index(self): ds = dataset.MapDataset.range(10).to_iter_dataset() @@ -291,8 +292,7 @@ def test_get_next_index_with_multiple_datasets(self): ds_iter = ds.__iter__() with self.assertRaisesRegex( NotImplementedError, - "get_next_index is not supported for InterleaveDatasetIterator with" - " more than one dataset.", + "get_next_index is not supported for .*Interleave", ): dataset.get_next_index(ds_iter) @@ -303,8 +303,7 @@ def test_set_next_index_with_multiple_datasets(self): ds_iter = ds.__iter__() with self.assertRaisesRegex( NotImplementedError, - "set_next_index is not supported for InterleaveDatasetIterator with" - " more than one dataset.", + "set_next_index is not supported for .*Interleave", ): dataset.set_next_index(ds_iter, 0) @@ -313,7 +312,7 @@ def test_future_states(self): dataset.MapDataset.source([1, 2]).to_iter_dataset(), dataset.MapDataset.source([3, 4]).to_iter_dataset(), ] - ds = interleave.InterleaveIterDataset(datasets, cycle_length=1) + ds = self._create_dataset(datasets, cycle_length=1) ds = self._maybe_wrap_ds(ds) ds_iter = ds.__iter__() diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 1940b437a..660517898 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -412,7 +412,7 @@ def __init__( self, parent: dataset.IterDataset[T], *, - prefetch_buffer_size: int | bindings.AutotuneParameter, + prefetch_buffer_size: int | grain_options.AutotuneParameter, ): super().__init__(parent) target_prefetch_buffer_size = prefetch_buffer_size @@ -480,7 +480,7 @@ class ThreadPrefetchDatasetIterator(dataset.DatasetIterator[T]): def __init__( self, parent: CheckpointableIterator[T], - prefetch_buffer_size: int | bindings.AutotuneParameter, + prefetch_buffer_size: int | grain_options.AutotuneParameter, ): if isinstance(parent, dataset.DatasetIterator): super().__init__(parent)