Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ py_test(
shard_count = 50,
srcs_version = "PY3",
deps = [
"//grain/_src/core:config",
"//grain/_src/core:transforms",
"//grain/_src/python:options",
"//grain/_src/python/dataset",
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,11 @@ def _add_prefetch_and_make_iterator(
ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access
)
iterator = iter_dataset.__iter__()

# Propagate options applied after InterleaveIterDataset to the iterators that
# are being interleaved.
iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access

if start_prefetch:
iterator.start_prefetch()
return iterator
Expand Down
24 changes: 23 additions & 1 deletion grain/_src/python/dataset/transformations/interleave_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def __iter__(self) -> dataset.DatasetIterator:
return _IteratorIdDatasetIterator(self._parent.__iter__())


class InterleaveIterDatasetTest(parameterized.TestCase):
@absltest.skipThisClass("Base class")
class _InterleaveIterDatasetTestBase(parameterized.TestCase):

  def _maybe_wrap_ds(self, ds):
    """Hook for subclasses to wrap the dataset under test.

    The base class (skipped via `skipThisClass`) returns `ds` unchanged;
    subclasses can override this to apply an extra transformation (e.g. a
    prefetch wrapper) so the whole test suite runs against the wrapped
    dataset.

    Args:
      ds: The dataset constructed by a test case.

    Returns:
      The dataset to use in the test; unchanged here.
    """
    return ds

@parameterized.named_parameters(*_INTERLEAVE_TEST_CASES)
def test_interleaved_mix(self, to_mix, cycle_length, expected):
Expand All @@ -99,6 +103,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected):
for elements in to_mix
]
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
ds = self._maybe_wrap_ds(ds)
self.assertEqual(list(ds), expected)
# Sanity check.
flat_inputs = []
Expand All @@ -113,6 +118,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected):
for elements in to_mix
]
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
checkpoints = {}
for i in range(len(expected)):
Expand All @@ -138,6 +144,7 @@ def test_checkpoint_with_extra_threads_creating_iterators(
num_make_iter_threads=10,
make_iter_buffer_size=10,
)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
checkpoints = {}
for i in range(len(expected)):
Expand All @@ -158,6 +165,7 @@ def make_dummy_source(filename):
filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
sources = filenames.shuffle(seed=42).map(make_dummy_source)
ds = interleave.InterleaveIterDataset(sources, cycle_length=2)
ds = self._maybe_wrap_ds(ds)
self.assertEqual(
list(ds),
["1", "2", "1", "3", "4", "6", "5", "7", "8", "9", "9", "9", "9"],
Expand All @@ -168,6 +176,7 @@ def test_with_mp_prefetch(self):
lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
)
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
ds = self._maybe_wrap_ds(ds)
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])

Expand All @@ -176,6 +185,7 @@ def test_options_propagated(self):
ds1 = ds1.filter(lambda x: False)
ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1)
ds = self._maybe_wrap_ds(ds)
ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
ds = dataset.WithOptionsIterDataset(ds, ds_options)
with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"):
Expand All @@ -187,6 +197,7 @@ def test_checkpointing_comprehensive(self):
for i in range(1, 6)
]
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
ds = self._maybe_wrap_ds(ds)
assert_equal_output_after_checkpoint(ds)

def test_set_state_does_not_recreate_iterators_if_not_needed(self):
Expand All @@ -196,6 +207,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
ds = interleave.InterleaveIterDataset(
[ds] * cycle_length, cycle_length=cycle_length
)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
iter_ids1 = []
for _ in range(cycle_length):
Expand All @@ -211,6 +223,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
def test_element_spec(self):
ds = dataset.MapDataset.range(3).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds = self._maybe_wrap_ds(ds)
spec = dataset.get_element_spec(ds)
self.assertEqual(spec.dtype, np.int64)
self.assertEqual(spec.shape, ())
Expand Down Expand Up @@ -255,6 +268,7 @@ def test_interleave_stats_with_mismatched_dataset_structures(self):
def test_get_next_index(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
self.assertEqual(dataset.get_next_index(ds_iter), 0)
for i in range(10):
Expand All @@ -264,6 +278,7 @@ def test_get_next_index(self):
def test_set_next_index(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
for i in reversed(range(10)):
dataset.set_next_index(ds_iter, i)
Expand All @@ -272,6 +287,7 @@ def test_set_next_index(self):
def test_get_next_index_with_multiple_datasets(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
with self.assertRaisesRegex(
NotImplementedError,
Expand All @@ -283,6 +299,7 @@ def test_get_next_index_with_multiple_datasets(self):
def test_set_next_index_with_multiple_datasets(self):
ds = dataset.MapDataset.range(10).to_iter_dataset()
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()
with self.assertRaisesRegex(
NotImplementedError,
Expand All @@ -297,6 +314,7 @@ def test_future_states(self):
dataset.MapDataset.source([3, 4]).to_iter_dataset(),
]
ds = interleave.InterleaveIterDataset(datasets, cycle_length=1)
ds = self._maybe_wrap_ds(ds)
ds_iter = ds.__iter__()

# Initialize the first iterator and get state.
Expand All @@ -323,5 +341,9 @@ def test_future_states(self):
next(ds_iter)


class InterleaveIterDatasetTest(_InterleaveIterDatasetTestBase):
  """Runs the interleave test suite without any extra dataset wrapping (no prefetch)."""

if __name__ == "__main__":
absltest.main()
5 changes: 4 additions & 1 deletion grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def start_prefetch(self):
return

self._prefetch_should_stop.clear()

self._prefetch_thread = threading.Thread(
target=functools.partial(
_put_iterator_elements_in_buffer,
Expand Down Expand Up @@ -595,8 +596,10 @@ def _stop_prefetch(self):
# is shutting down. Attempting to join can lead to hanging in Python
# 3.13 as daemon threads can hang during interpreter shutdown. See
# https://github.com/python/cpython/issues/123940#issuecomment-2976446309
self._prefetch_thread.join()
if self._prefetch_thread is not None:
self._prefetch_thread.join()
self._prefetch_thread = None

# Clear the buffer again in case the prefetch loop added more elements on
# exit.
self._clear_buffer()
Expand Down
7 changes: 6 additions & 1 deletion grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,8 @@ def test_set_next_index(self):
self.assertEqual(next(ds_iter), i)


class ThreadPrefetchIterDatasetTest(parameterized.TestCase):
@absltest.skipThisClass('Base class')
class _ThreadPrefetchIterDatasetTestBase(parameterized.TestCase):

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -768,6 +769,10 @@ def new_get_state(self):
self.assertEqual(get_state_counter.call_count - get_state_count, 1)


class ThreadPrefetchIterDatasetTest(_ThreadPrefetchIterDatasetTestBase):
  """Runs the thread-prefetch test suite without a provided executor."""

class _MpContextCheckIterDataset(dataset.IterDataset[_T]):

def __iter__(self) -> dataset.DatasetIterator[_T]:
Expand Down
Loading