class ElasticIterDatasetIterator(dataset.DatasetIterator):
  """Elastic iterator for InterleaveIterDatasets.

  Wraps an `InterleaveIterDataset` (optionally batched) and re-keys the
  underlying interleave checkpoint by a global index,
  `local_dataset_index * shard_count + shard_index`. `set_state` selects the
  entries whose global index maps onto this iterator's shard and rebuilds an
  interleave state from them.
  """

  def __init__(
      self,
      parent: interleave.InterleaveIterDataset,
      shard_options: sharding.ShardOptions,
      global_batch_size: int,
      drop_remainder: bool,
      read_options: options.ReadOptions,
      multiprocessing_options: options.MultiprocessingOptions | None = None,
  ):
    """Initializes the iterator; the underlying iterator is built lazily.

    Args:
      parent: Interleave dataset to iterate over. Its `_cycle_length` is read
        here and bounds the number of in-use iterators restored by
        `set_state`.
      shard_options: Shard index/count used to compute global state keys.
      global_batch_size: Batch size applied on top of `parent`; batching is
        skipped when this is not positive.
      drop_remainder: Forwarded to `batch`.
      read_options: Stored on the instance; not otherwise used by this class.
      multiprocessing_options: Stored on the instance; not otherwise used by
        this class.
    """
    super().__init__()
    self._ds = parent
    self._cycle_length = self._ds._cycle_length  # pylint: disable=protected-access

    self._global_batch_size = global_batch_size
    self._drop_remainder = drop_remainder
    self._shard_options = shard_options
    self._read_options = read_options
    self._multiprocessing_options = multiprocessing_options

    # These are updated when the underlying iterator is first created.
    self._iterator_started = False
    self._is_batched = False
    self._closed = False

  @property
  def shard_options(self) -> sharding.ShardOptions:
    """Shard options this iterator was created with."""
    return self._shard_options

  def close(self):
    """Closes the underlying iterator if it was created; idempotent."""
    if self._closed:
      return
    self._closed = True
    # Look in `__dict__` so that closing does not instantiate the cached
    # property just to close it.
    if "_iterator" in self.__dict__:
      self._iterator.close()

  @functools.cached_property
  def _iterator(self) -> dataset.DatasetIterator:
    """Lazily builds the (optionally batched) underlying iterator."""
    ds = self._ds
    self._iterator_started = True
    if self._global_batch_size > 0:
      ds = ds.batch(
          self._global_batch_size, drop_remainder=self._drop_remainder
      )
      self._is_batched = True

    return ds.__iter__()

  def __next__(self) -> Any:
    return next(self._iterator)

  def _global_index(self, local_index: int) -> int:
    """Maps a position in this shard's dataset list to its checkpoint key."""
    return (
        local_index * self._shard_options.shard_count
        + self._shard_options.shard_index
    )

  def get_state(self):
    """Returns per-dataset iterator states keyed by global index.

    Returns:
      `{"ds_iterator_states": {global_index: {"exhausted": ..., "state":
      ...}}}` with one entry for every dataset of the underlying interleave
      iterator.
    """
    inner_state = self._iterator.get_state()
    indices = inner_state["iterators_in_use_indices"]
    states = inner_state["iterators_in_use_states"]
    exhausted = inner_state["exhausted"]
    next_index_in_datasets = inner_state["next_index_in_datasets"]

    # `_is_batched` is set while building `_iterator`, which the `get_state()`
    # call above has already forced.
    if self._is_batched:
      interleave_iter = cast(
          interleave.InterleaveDatasetIterator,
          self._iterator._parent,  # pylint: disable=protected-access
      )
    else:
      interleave_iter = cast(
          interleave.InterleaveDatasetIterator, self._iterator
      )

    in_use = set(indices)
    ds_iterator_states = {}
    for i in range(len(interleave_iter._datasets)):  # pylint: disable=protected-access
      if i >= next_index_in_datasets:
        # This dataset has not been started yet.
        is_exhausted = 0
      elif i not in in_use:
        # Already exhausted; still recorded to maintain static state spec
        # shapes.
        is_exhausted = 1
      else:
        # Currently in use; handled by the loop below.
        continue
      ds_iterator_states[self._global_index(i)] = {
          "exhausted": is_exhausted,
          "state": interleave_iter._get_iterator_start_state(i),  # pylint: disable=protected-access
      }

    # NOTE(review): `set_state` pads unused cycle slots with index
    # `next_index_in_datasets` and `exhausted=1`. If the interleave iterator
    # reports those slots back unchanged, the loop below would overwrite the
    # "not started" entry written above for that index -- confirm against
    # `InterleaveDatasetIterator.get_state`.
    for index, iterator_state, is_exhausted in zip(indices, states, exhausted):
      # These datasets are currently being iterated on.
      ds_iterator_states[self._global_index(index)] = {
          "exhausted": is_exhausted,
          "state": iterator_state,
      }

    return {
        "ds_iterator_states": ds_iterator_states,
    }

  def set_state(self, state):
    """Sets state by reconstructing the state for the underlying interleave.

    Entries of `state["ds_iterator_states"]` whose global index belongs to this
    shard are visited in ascending order. The first `cycle_length` non-exhausted
    ones become the in-use iterators; the remaining non-exhausted ones with a
    truthy state are passed on as `future_states`.

    Args:
      state: A dict in the format produced by `get_state`.
    """
    ds_iterator_states = state["ds_iterator_states"]
    shard_index = self._shard_options.shard_index
    shard_count = self._shard_options.shard_count

    active_states = []
    # Highest local dataset index seen for this shard, exhausted or not. Used
    # as the fallback for `next_index_in_datasets` when nothing is active.
    highest_owned_index = -1
    for global_index, shard_state in sorted(ds_iterator_states.items()):
      # Skip states that belong to other shards.
      if (global_index - shard_index) % shard_count != 0:
        continue
      local_index = global_index // shard_count
      highest_owned_index = max(highest_owned_index, local_index)
      if not shard_state["exhausted"]:
        active_states.append((local_index, shard_state["state"]))

    in_use = active_states[: self._cycle_length]
    overflow = active_states[self._cycle_length :]
    iterators_in_use_indices = [index for index, _ in in_use]
    iterators_in_use_states = [iter_state for _, iter_state in in_use]
    exhausted = [0] * len(in_use)
    # Only keep overflow entries that actually carry a state.
    future_states = {
        index: iter_state for index, iter_state in overflow if iter_state
    }

    # An unguarded `max()` raises ValueError on an empty sequence, which
    # happens when every dataset owned by this shard is exhausted or the
    # checkpoint holds no entry for this shard. Fall back to one past the
    # highest owned index in that case.
    next_index_in_datasets = (
        max(iterators_in_use_indices, default=highest_owned_index) + 1
    )

    # Pad the remaining cycle slots with exhausted placeholders.
    padding = max(self._cycle_length - len(in_use), 0)
    iterators_in_use_indices.extend([next_index_in_datasets] * padding)
    iterators_in_use_states.extend([None] * padding)
    exhausted.extend([1] * padding)

    new_state = {
        "next_index_in_cycle": 0,
        "next_index_in_datasets": next_index_in_datasets,
        "iterators_in_use_indices": iterators_in_use_indices,
        "iterators_in_use_states": iterators_in_use_states,
        "exhausted": exhausted,
        "future_states": future_states,
    }
    # Drop any previously created iterator so that a fresh one is built and
    # restored below.
    if "_iterator" in self.__dict__:
      self.__dict__["_iterator"].close()
    self.__dict__.pop("_iterator", None)
    self._iterator.set_state(new_state)
class ElasticIterator(dataset.IterDataset):
  """Iterator supporting recovery from a checkpoint after changes in sharding.

  The input dataset is expected to be unbatched and unsharded. In order to
  provide elasticity guarantee this iterator includes both, batching and
  sharding. The iterator supports elastic re-configuration by having each
  shard produce the same exact checkpoint (while producing different data) as
  long as they are advanced the same number of steps.

  State of any shard can be used to restore the state of all of the shards after
  changes in sharding and global batch size.

  This iterator explicitly disallows many-to-one transformations without
  a fixed ratio, like `filter` and generic `IterDataset` transformations.
  """

  def __init__(
      self,
      parent: dataset.MapDataset | dataset.IterDataset,
      global_batch_size: int,
      shard_options: sharding.ShardOptions,
      *,
      read_options: options.ReadOptions = options.ReadOptions(),
      multiprocessing_options: options.MultiprocessingOptions | None = None,
      drop_remainder: bool = False,
  ):
    """Validates `parent` and prepares the dataset to iterate over.

    Args:
      parent: Unbatched, unsharded dataset. An `IterDataset` is deep-copied
        and sliced for this shard; any other dataset is used as is.
      global_batch_size: Forwarded to the created iterator.
      shard_options: Shard index/count of this host.
      read_options: For `IterDataset` parents, `num_threads` decides into how
        many interleaved slices the dataset is split.
      multiprocessing_options: For `IterDataset` parents, enables
        multiprocess prefetching when `num_workers > 0`.
      drop_remainder: Forwarded to the created iterator.

    Raises:
      ValueError: If `parent` or any of its ancestors is a `filter`
        transformation.
    """
    super().__init__()
    self._raise_if_filtered(parent)

    self._shard_options = shard_options
    self._global_batch_size = global_batch_size
    self._drop_remainder = drop_remainder
    self._read_options = read_options
    self._multiprocessing_options = multiprocessing_options

    if isinstance(parent, dataset.IterDataset):
      self._ds = self._prepare_iter_dataset(parent)
    else:
      self._ds = parent

  @staticmethod
  def _raise_if_filtered(root):
    """Walks the dataset graph from `root` and rejects `filter` nodes."""
    unsupported = (
        filter_dataset.FilterMapDataset,
        filter_dataset.FilterIterDataset,
    )
    to_check = [root]
    while to_check:
      next_ds = to_check.pop()
      if isinstance(next_ds, unsupported):
        raise ValueError(
            "ElasticIterator does not support `filter` transformation."
        )
      to_check.extend(next_ds.parents)

  @staticmethod
  def _sliced_copy(ds, ds_slice):
    """Returns a deep copy of `ds` with `ds_slice` set on it."""
    sliced = copy.deepcopy(ds)
    prefetch._set_slice_iter_dataset(sliced, ds_slice)  # pylint: disable=protected-access
    return sliced

  def _prepare_iter_dataset(self, parent):
    """Shards `parent` and wraps it for prefetching and interleaving."""
    # The slice must be set on (a copy of) the original dataset so that the
    # interleave iterator is created with the correct (sliced) datasets.
    ds = self._sliced_copy(
        parent,
        slice(
            self._shard_options.shard_index,
            None,
            self._shard_options.shard_count,
        ),
    )
    mp_options = self._multiprocessing_options
    if mp_options and mp_options.num_workers > 0:
      ds = process_prefetch.multiprocess_prefetch(
          ds,
          num_workers=mp_options.num_workers,
          buffer_size=mp_options.per_worker_buffer_size,
      )
    if isinstance(ds, dataset.WithOptionsIterDataset):
      # Unwrap the options node and continue with its parent.
      ds = ds._parents[0]  # pylint: disable=protected-access
    # NOTE(review): when `ds` is not an InterleaveIterDataset and
    # `num_threads <= 1` it is returned unwrapped, yet
    # `ElasticIterDatasetIterator` reads `parent._cycle_length` -- confirm
    # this configuration is supported.
    if (
        not isinstance(ds, interleave.InterleaveIterDataset)
        and self._read_options.num_threads > 1
    ):
      num_threads = self._read_options.num_threads
      per_thread = [
          self._sliced_copy(ds, slice(i, None, num_threads))
          for i in range(num_threads)
      ]
      ds = interleave.InterleaveIterDataset(
          per_thread, cycle_length=len(per_thread)
      )
    return ds

  @property
  def shard_options(self) -> sharding.ShardOptions:
    """Shard options this dataset was created with."""
    return self._shard_options

  def __iter__(self) -> dataset.DatasetIterator:
    """Creates the elastic iterator matching the type of the dataset."""
    # Both iterator classes take the same positional arguments.
    iterator_cls = (
        ElasticIterDatasetIterator
        if isinstance(self._ds, dataset.IterDataset)
        else _ElasticMapDatasetIterator
    )
    return iterator_cls(
        self._ds,
        self._shard_options,
        self._global_batch_size,
        self._drop_remainder,
        self._read_options,
        self._multiprocessing_options,
    )
a/grain/_src/python/dataset/elastic_iterator_test.py b/grain/_src/python/dataset/elastic_iterator_test.py index 1c4261f09..aba3abe22 100644 --- a/grain/_src/python/dataset/elastic_iterator_test.py +++ b/grain/_src/python/dataset/elastic_iterator_test.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import platform from absl.testing import absltest from absl.testing import parameterized @@ -20,12 +19,13 @@ from grain._src.python import options from grain._src.python.dataset import dataset from grain._src.python.dataset import elastic_iterator +from grain._src.python.dataset.transformations import interleave import grain._src.python.testing.experimental as test_util import numpy as np @absltest.skipIf(platform.system() == "Windows", "Skipped under bazel.") -class ElasticIteratorTest(parameterized.TestCase): +class ElasticMapDataset(parameterized.TestCase): @parameterized.parameters( dict( @@ -63,7 +63,7 @@ def test_produces_correct_elements( global_batch_size, shard_options, multiprocessing_options=multiprocessing_options, - ) + ).__iter__() ) np.testing.assert_equal( actual, expected, err_msg=f"actual: {actual}, expected: {expected}" @@ -71,7 +71,9 @@ def test_produces_correct_elements( def test_checkpointing(self): ds = dataset.MapDataset.range(100).map(lambda x: x * 2).shuffle(42) - it = elastic_iterator.ElasticIterator(ds, 5, sharding.NoSharding()) + it = elastic_iterator.ElasticIterator( + ds, 5, sharding.NoSharding() + ).__iter__() test_util.assert_equal_output_after_checkpoint(it) def test_checkpointing_with_multiprocessing(self): @@ -81,7 +83,7 @@ def test_checkpointing_with_multiprocessing(self): 2, sharding.NoSharding(), multiprocessing_options=options.MultiprocessingOptions(2), - ) + ).__iter__() test_util.assert_equal_output_after_checkpoint(it) def _elastic_resize_test_base( @@ -120,7 +122,7 @@ def 
class ElasticIteratorTest(parameterized.TestCase):
  """Tests `ElasticIterator` on top of `InterleaveIterDataset` inputs."""

  def _assert_produces(self, iterator, expected):
    """Drains `iterator` and compares it with `expected` element by element."""
    produced = list(iterator)
    self.assertLen(produced, len(expected))
    for produced_batch, expected_batch in zip(produced, expected):
      np.testing.assert_equal(produced_batch, expected_batch)

  @parameterized.parameters(
      {
          "shard_options": sharding.NoSharding(),
          "global_batch_size": 1,
          "expected": list(range(15)),
      },
      {
          "shard_options": sharding.ShardOptions(shard_index=0, shard_count=1),
          "global_batch_size": 1,
          "expected": list(range(15)),
      },
      {
          "shard_options": sharding.NoSharding(),
          "global_batch_size": 3,
          # With cycle length 3 each batch takes one element per source.
          "expected": [
              [0, 5, 10],
              [1, 6, 11],
              [2, 7, 12],
              [3, 8, 13],
              [4, 9, 14],
          ],
      },
  )
  def test_no_sharding_produces_correct_elements(
      self, shard_options, global_batch_size, expected
  ):
    # Three sources holding five consecutive integers each.
    sources = []
    for i in range(3):
      source = dataset.MapDataset.range(i * 5, (i + 1) * 5)
      sources.append(source.to_iter_dataset())
    interleaved = interleave.InterleaveIterDataset(
        sources, cycle_length=global_batch_size
    )
    elastic_ds = elastic_iterator.ElasticIterator(
        interleaved,
        shard_options=shard_options,
        global_batch_size=global_batch_size,
    )
    self._assert_produces(elastic_ds.__iter__(), expected)

  @parameterized.parameters(
      {
          "shard_options": sharding.ShardOptions(shard_index=0, shard_count=2),
          "global_batch_size": 1,
          "expected": [0, 2, 4, 6, 8],
      },
      {
          "shard_options": sharding.ShardOptions(shard_index=1, shard_count=2),
          "global_batch_size": 1,
          "expected": [1, 3, 5, 7, 9],
      },
      {
          "shard_options": sharding.ShardOptions(shard_index=0, shard_count=2),
          "global_batch_size": 2,
          "expected": [[0, 2], [4, 6], [8]],
      },
  )
  def test_sharding_produces_correct_elements(
      self, shard_options, global_batch_size, expected
  ):
    # Four sources: 0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7].
    sources = []
    for i in range(4):
      sources.append(dataset.MapDataset.range(i, 10, 4).to_iter_dataset())
    # The cycle length is fixed at 2 regardless of the batch size.
    interleaved = interleave.InterleaveIterDataset(sources, cycle_length=2)
    elastic_ds = elastic_iterator.ElasticIterator(
        interleaved,
        shard_options=shard_options,
        global_batch_size=global_batch_size,
    )
    self._assert_produces(elastic_ds.__iter__(), expected)

  def test_checkpointing_no_change(self):
    global_batch_size = 2
    sources = []
    for i in range(25):
      sources.append(dataset.MapDataset.range(i, 100, 25).to_iter_dataset())
    interleaved = interleave.InterleaveIterDataset(
        sources, cycle_length=global_batch_size
    )
    elastic_ds = elastic_iterator.ElasticIterator(
        interleaved,
        shard_options=sharding.ShardOptions(shard_index=2, shard_count=4),
        global_batch_size=global_batch_size,
    )
    test_util.assert_equal_output_after_checkpoint(elastic_ds.__iter__())

  def test_checkpointing_with_multiprocessing_iter_dataset(self):
    doubled = dataset.MapDataset.range(10).map(lambda x: x * 2)
    elastic_ds = elastic_iterator.ElasticIterator(
        doubled.to_iter_dataset(),
        2,
        sharding.NoSharding(),
        multiprocessing_options=options.MultiprocessingOptions(2),
    )
    test_util.assert_equal_output_after_checkpoint(elastic_ds.__iter__())