1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/BUILD
@@ -269,6 +269,7 @@ py_test(
"//grain/_src/python:options",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"//grain/_src/python/experimental/autotune/python:bindings",
"//grain/_src/python/testing:experimental",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:flagsaver",
35 changes: 31 additions & 4 deletions grain/_src/python/dataset/transformations/interleave.py
@@ -13,11 +13,14 @@
# limitations under the License.
"""Implements dataset interleaving."""

import collections
from collections.abc import Sequence
import copy
import functools
from typing import Any, TypeVar
import weakref

from concurrent import futures
from grain._src.python import options as grain_options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
@@ -296,6 +299,11 @@ def close(self) -> None:
def _initialize_stats(
self, execution_tracking_mode: base.ExecutionTrackingMode
) -> stats.Stats:
# We pass an empty list of parents to `stats.make_stats` below. The
# parents of InterleaveDatasetIterator are the iterators of the
# datasets being interleaved. These are dynamically created and tracked
# in `self._iterators_in_use`. When an iterator produces an element in
# `__next__`, its Stats object is added to `self._stats._parents`.
config = stats.StatsConfig(
name=str(self),
transform_mutates_spec=self._MUTATES_ELEMENT_SPEC,
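
The comment above describes lazy parent wiring: the interleave node starts with no stats parents and attaches each child iterator's stats node the first time that child yields an element. A self-contained sketch of that idea follows; these are toy classes, not the real grain stats API, and all names are illustrative assumptions.

class _ToyStatsNode:
  """Toy stand-in for a stats node with a mutable parent list."""

  def __init__(self, name: str, parents=None):
    self.name = name
    self.parents = list(parents or [])


class _ToyInterleaveStats:
  """Starts with an empty parent list, mirroring the empty-parents config above."""

  def __init__(self):
    self.stats = _ToyStatsNode("Interleave", parents=[])

  def record_element_from(self, child: _ToyStatsNode) -> None:
    # Attach the child's stats node the first time it produces an element.
    if child not in self.stats.parents:
      self.stats.parents.append(child)


children = [_ToyStatsNode(f"child_{i}") for i in range(3)]
interleave_stats = _ToyInterleaveStats()
interleave_stats.record_element_from(children[1])
assert [p.name for p in interleave_stats.stats.parents] == ["child_1"]
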
@@ -337,8 +345,9 @@ def _get_iterator_start_state(self, index: int) -> dict[str, Any]:

def _add_prefetch_and_make_iterator(
ds: dataset.IterDataset[T] | dataset.MapDataset[T],
interleave_iterator: weakref.ref[InterleaveDatasetIterator[T]],
interleave_iterator: weakref.ref[dataset.DatasetIterator[T]],
start_prefetch: bool,
starting_state: dict[str, Any] | None = None,
) -> dataset.DatasetIterator[T]:
"""Adds prefetching to an IterDataset and returns an iterator.

@@ -350,6 +359,7 @@ def _add_prefetch_and_make_iterator(
ds: The dataset to create an iterator from.
interleave_iterator: A weak reference to the `InterleaveDatasetIterator` instance.
start_prefetch: Whether to start the prefetching on iterator creation.
starting_state: The state of the iterator to set.

Returns:
A `dataset.DatasetIterator` for the given dataset, with prefetching
@@ -358,24 +368,41 @@ def _add_prefetch_and_make_iterator(
Raises:
RuntimeError: If the interleave_iterator has been garbage collected.
"""
# pylint: disable=protected-access
interleave_iterator_obj = interleave_iterator()
# Check for garbage collection before asserting on the concrete type.
if interleave_iterator_obj is None:
  raise RuntimeError("InterleaveDatasetIterator has been garbage collected.")
assert isinstance(
    interleave_iterator_obj,
    (InterleaveDatasetIterator, TunableInterleaveDatasetIterator),
)
iter_buffer_size = interleave_iterator_obj._iter_buffer_size
ctx = interleave_iterator_obj._ctx
# Release the strong reference before potentially slow iterator creation.
# This prevents the worker thread from triggering parent destruction.
del interleave_iterator_obj

if isinstance(ds, dataset.MapDataset):
# Prefetch is automatically added in `MapDataset.__iter__`.
iter_dataset = ds.to_iter_dataset()
else:
iter_dataset = prefetch.ThreadPrefetchIterDataset(
ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access
ds, prefetch_buffer_size=iter_buffer_size
)
iterator = iter_dataset.__iter__()

# Propagate options applied after InterleaveIterDataset to the iterators that
# are being interleaved.
iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access
iterator._ctx.dataset_options = ctx.dataset_options.merge(
iterator._ctx.dataset_options
)

if start_prefetch:
iterator.start_prefetch()
if starting_state is not None:
iterator.set_state(starting_state)
# pylint: enable=protected-access
return iterator
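
The body above follows a common weakref pattern: briefly take a strong reference to the parent iterator, copy out the fields it needs (`_iter_buffer_size`, `_ctx`), and release the strong reference before the potentially slow iterator construction, so a worker thread can never be the last holder of the parent. A standalone sketch of that pattern follows; the names, the buffer field, and the sleep are illustrative assumptions, not code from this PR.

import time
import weakref


class _Parent:
  """Toy stand-in for the interleave iterator that owns configuration."""

  def __init__(self):
    self.buffer_size = 8


def _make_child(parent_ref: weakref.ref) -> list[int]:
  parent = parent_ref()  # Temporarily take a strong reference.
  if parent is None:
    raise RuntimeError("Parent has been garbage collected.")
  buffer_size = parent.buffer_size  # Copy out only what is needed.
  del parent  # Drop the strong reference before the slow part.
  time.sleep(0.01)  # Stand-in for slow iterator construction.
  return list(range(buffer_size))


parent = _Parent()
child = _make_child(weakref.ref(parent))
assert len(child) == 8
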


@@ -409,7 +436,7 @@ def __init__(
self,
datasets: Sequence[dataset.IterDataset[T] | dataset.MapDataset[T]],
*,
cycle_length: int,
cycle_length: int | bindings.AutotuneParameter,
num_make_iter_threads: int = 1,
make_iter_buffer_size: int = 1,
iter_buffer_size: int = 1,
41 changes: 20 additions & 21 deletions grain/_src/python/dataset/transformations/interleave_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import cast
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
@@ -93,6 +94,9 @@ def __iter__(self) -> dataset.DatasetIterator:
@absltest.skipThisClass("Base class")
class _InterleaveIterDatasetTestBase(parameterized.TestCase):

def _create_dataset(self, *args, **kwargs):
return interleave.InterleaveIterDataset(*args, **kwargs)

def _maybe_wrap_ds(self, ds):
return ds
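
The new `_create_dataset` hook turns this base class into a reusable suite: a subclass overrides only the factory and inherits every test against its own dataset variant (presumably the tunable interleave dataset introduced elsewhere in this PR). A minimal self-contained sketch of the factory-hook pattern itself follows; these are toy classes, not the actual grain tests.

import unittest


class _BaseSuite(unittest.TestCase):
  """Base tests call the factory hook instead of a concrete constructor."""

  def _create_dataset(self, elements):
    return list(elements)

  def test_roundtrip(self):
    self.assertEqual(list(self._create_dataset([1, 2, 3])), [1, 2, 3])


class _ReversedTwiceVariant(_BaseSuite):
  """Overrides only the factory; every inherited test now covers it."""

  def _create_dataset(self, elements):
    return list(reversed(list(reversed(elements))))


if __name__ == "__main__":
  unittest.main()
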

Expand All @@ -102,7 +106,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected):
dataset.MapDataset.source(elements).to_iter_dataset()
for elements in to_mix
]
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
ds = self._create_dataset(datasets, cycle_length=cycle_length)
ds = self._maybe_wrap_ds(ds)
self.assertEqual(list(ds), expected)
# Sanity check.
Expand All @@ -117,7 +121,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected):
dataset.MapDataset.source(elements).to_iter_dataset()
for elements in to_mix
]
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
ds = self._create_dataset(datasets, cycle_length=cycle_length)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
checkpoints = {}
Expand All @@ -138,7 +142,7 @@ def test_checkpoint_with_extra_threads_creating_iterators(
dataset.MapDataset.source(elements).to_iter_dataset()
for elements in to_mix
]
ds = interleave.InterleaveIterDataset(
ds = self._create_dataset(
datasets,
cycle_length=cycle_length,
num_make_iter_threads=10,
Expand All @@ -164,7 +168,7 @@ def make_dummy_source(filename):

filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
sources = filenames.shuffle(seed=42).map(make_dummy_source)
ds = interleave.InterleaveIterDataset(sources, cycle_length=2)
ds = self._create_dataset(sources, cycle_length=2)
ds = self._maybe_wrap_ds(ds)
self.assertEqual(
list(ds),
Expand All @@ -175,7 +179,7 @@ def test_with_mp_prefetch(self):
ds = dataset.MapDataset.range(1, 6).map(
lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
)
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
ds = self._create_dataset(ds, cycle_length=5)
ds = self._maybe_wrap_ds(ds)
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])
Expand All @@ -184,7 +188,7 @@ def test_options_propagated(self):
ds1 = dataset.MapDataset.source([1]).repeat(1000).to_iter_dataset()
ds1 = ds1.filter(lambda x: False)
ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1)
ds = self._create_dataset([ds1, ds2], cycle_length=1)
ds = self._maybe_wrap_ds(ds)
ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
ds = dataset.WithOptionsIterDataset(ds, ds_options)
Expand All @@ -196,17 +200,15 @@ def test_checkpointing_comprehensive(self):
dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
for i in range(1, 6)
]
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
ds = self._create_dataset(ds, cycle_length=5)
ds = self._maybe_wrap_ds(ds)
assert_equal_output_after_checkpoint(ds)

def test_set_state_does_not_recreate_iterators_if_not_needed(self):
cycle_length = 5
ds = dataset.MapDataset.range(100).to_iter_dataset()
ds = _IteratorIdIterDataset(ds)
ds = interleave.InterleaveIterDataset(
[ds] * cycle_length, cycle_length=cycle_length
)
ds = self._create_dataset([ds] * cycle_length, cycle_length=cycle_length)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
iter_ids1 = []
@@ -222,7 +224,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):

def test_element_spec(self):
ds = dataset.MapDataset.range(3).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds = self._create_dataset([ds, ds], cycle_length=2)
ds = self._maybe_wrap_ds(ds)
spec = dataset.get_element_spec(ds)
self.assertEqual(spec.dtype, np.int64)
@@ -232,7 +234,7 @@ def test_element_spec(self):
def test_interleave_stats(self):
ds = dataset.MapDataset.range(10000).map(lambda x: x + 1)
ds = ds.to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds = self._create_dataset([ds, ds], cycle_length=2)
it = ds.__iter__()
next(it)
next(it)
@@ -243,27 +245,26 @@
"MapMapDataset",
"PrefetchDatasetIterator",
"ThreadPrefetchDatasetIterator",
"InterleaveDatasetIterator",
"Interleave",
]
for expected_node in expected_nodes:
self.assertTrue(any(expected_node in name for name in node_names))
self.assertLen(node_names, len(expected_nodes))
print(summary)

@flagsaver.flagsaver(grain_py_debug_mode=True)
def test_interleave_stats_with_mismatched_dataset_structures(self):
ds1 = dataset.MapDataset.range(10000).map(lambda x: x + 1)
ds1 = ds1.to_iter_dataset()
ds2 = dataset.MapDataset.range(10000).map(lambda x: x + 1).map(lambda x: x)
ds2 = ds2.to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=2)
ds = self._create_dataset([ds1, ds2], cycle_length=2)
it = ds.__iter__()
next(it)
next(it)
summary = dataset.get_execution_summary(it)
node_names = [node.name for node in summary.nodes.values()]
self.assertLen(node_names, 1)
self.assertIn("InterleaveDatasetIterator", node_names[0])
self.assertIn("Interleave", node_names[0])

def test_get_next_index(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
@@ -291,8 +292,7 @@ def test_get_next_index_with_multiple_datasets(self):
ds_iter = ds.__iter__()
with self.assertRaisesRegex(
NotImplementedError,
"get_next_index is not supported for InterleaveDatasetIterator with"
" more than one dataset.",
"get_next_index is not supported for .*Interleave",
):
dataset.get_next_index(ds_iter)

@@ -303,8 +303,7 @@ def test_set_next_index_with_multiple_datasets(self):
ds_iter = ds.__iter__()
with self.assertRaisesRegex(
NotImplementedError,
"set_next_index is not supported for InterleaveDatasetIterator with"
" more than one dataset.",
"set_next_index is not supported for .*Interleave",
):
dataset.set_next_index(ds_iter, 0)

@@ -313,7 +312,7 @@ def test_future_states(self):
dataset.MapDataset.source([1, 2]).to_iter_dataset(),
dataset.MapDataset.source([3, 4]).to_iter_dataset(),
]
ds = interleave.InterleaveIterDataset(datasets, cycle_length=1)
ds = self._create_dataset(datasets, cycle_length=1)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()

1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
@@ -28,6 +28,7 @@
from grain._src.python import options
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import autotune
from grain._src.python.dataset.transformations import filter as filter_lazy_dataset
from grain._src.python.dataset.transformations import prefetch
import numpy as np