diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1d253d6e2..08c34abc9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -344,7 +344,7 @@ jobs: pip install -e .[testing,gcs] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html pip uninstall -y orbax pip install gcsfs - pip install portpicker pytest chex pyyaml + pip install portpicker pytest chex pyyaml pathwaysutils if [ "${{ matrix.jax-version }}" = "newest" ]; then pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html elif [ "${{ matrix.jax-version }}" = "nightly" ]; then @@ -352,9 +352,34 @@ jobs: else pip install "jax[tpu]==${{ matrix.jax-version }}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html fi - - name: Run multiprocess tests + - name: Run pathways tests + env: + JAX_DEFAULT_BACKEND: pathways + JAX_PLATFORMS: tpu + # Configures JAX to target a subslice within the TPU allocation. + JAX_BACKEND_TARGET: subslice + # Enables IFRT in Pathways. + PATHWAYS_IFRT: true + # Allows JAX to run even if some TPUs are not utilized. 
+ JAX_ALLOW_UNUSED_TPUS: true + run: | + python -c "import pathwaysutils; pathwaysutils.initialize(); print('Pathways initialized'); import jax; print(jax.devices());" && python orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --pathways=1 + - name: Run colocated pathways tests + env: + JAX_PLATFORMS: pathways + JAX_BACKEND_TARGET: subslice + PATHWAYS_IFRT: true + JAX_ALLOW_UNUSED_TPUS: true + PATHWAYS_EXPECTED_INSTANCES: df=1x1,df=1x1,df=1x1,df=1x1 + USE_COLOCATED_PYTHON: true + run: | + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --pathways=1 + - name: Run 2 multiprocess tests + run: | + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=2 --tpu_chips_per_process=4 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=2 + - name: Run 4 multiprocess tests run: | - python orbax/checkpoint/_src/testing/oss/run_multihost.py orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4 + python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=4 --tpu_chips_per_process=2 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=4 - name: Run single process tests run: | python orbax/checkpoint/_src/testing/oss/run_multihost.py --num_processes=1 --tpu_chips_per_process=8 orbax/checkpoint/_src/testing/oss/run_tests.py --filename=orbax/checkpoint/_src/testing/oss/tagged_tests.yaml --processes=1 diff --git a/checkpoint/orbax/checkpoint/_src/multihost/multihost_test.py b/checkpoint/orbax/checkpoint/_src/multihost/multihost_test.py new file mode 100644 index 000000000..0b950dec1 --- /dev/null +++ 
b/checkpoint/orbax/checkpoint/_src/multihost/multihost_test.py @@ -0,0 +1,212 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from unittest import mock + +from absl.testing import flagsaver +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.testing import multiprocess_test + + +class MultihostUtilsTestBase: + + class Test(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.assertEqual(jax.device_count(), 8) + self.assertEqual(jax.process_count(), 4) + self.assertEqual(jax.local_device_count(), 2) + + if not multihost.is_runtime_to_distributed_ids_initialized(): + multihost.initialize_runtime_to_distributed_ids() + + self.tmpdir = epath.Path( + self.create_tempdir(name='multihost_test').full_path + ) + test_utils.sync_global_processes('setUp') + + def tearDown(self): + test_utils.sync_global_processes('tearDown') + super().tearDown() + + def test_process_errors(self): + if multihost.process_index() == 1: + with self.assertRaises(ValueError): + multihost.sync_global_processes( + 'test_process_errors_1', processes={0} + ) + + def test_sync_global_processes(self): + if multihost.process_index() == 0: + time.sleep(2) + (self.tmpdir / 'dummy').mkdir(parents=False, exist_ok=False) + 
multihost.sync_global_processes('test_sync_global_processes') + self.assertTrue((self.tmpdir / 'dummy').exists()) + + def test_sync_global_processes_partial(self): + participating_processes = {0, 2} + primary_process = 0 + non_primary_process = 1 + + directory = self.tmpdir / 'testdir' + if multihost.process_index() == primary_process: + directory.mkdir(parents=False, exist_ok=False) + test_utils.sync_global_processes( + 'test_sync_global_processes_partial_setup' + ) + + if multihost.process_index() == primary_process: + time.sleep(2) + (directory / 'dummy').mkdir(parents=False, exist_ok=False) + if multihost.process_index() in participating_processes: + multihost.sync_global_processes( + 'test_sync_global_processes_partial', + processes=participating_processes, + ) + if multihost.process_index() in participating_processes: + self.assertTrue((directory / 'dummy').exists()) + else: + self.assertFalse((directory / 'dummy').exists()) + + if multihost.process_index() == primary_process: + time.sleep(2) + (directory / 'foo').mkdir(parents=False, exist_ok=False) + if multihost.process_index() in participating_processes: + multihost.sync_global_processes( + 'test_sync_global_processes_partial_second', + processes=participating_processes, + ) + if multihost.process_index() in participating_processes: + self.assertTrue((directory / 'foo').exists()) + else: + self.assertFalse((directory / 'foo').exists()) + + multihost.sync_global_processes('test_sync_global_processes_partial_all') + # If non-primary processes get past the above barrier without waiting for + # all, then an error would happen for the primary process when trying to + # create subdirectories. 
+ if multihost.process_index() == non_primary_process: + directory.rmtree() + + def test_different_barriers(self): + slice1 = {0, 2} + slice2 = {1, 3} + primary_processes = [0, 1] + + if multihost.process_index() in primary_processes: + # Don't sleep for slice1, but do sleep for slice2, so that when slice1 + # finishes waiting at the barrier, one file exists but the other does + # not. + time.sleep(3 * multihost.process_index()) + (self.tmpdir / f'dummy_{multihost.process_index()}').mkdir( + parents=False, exist_ok=False + ) + + if multihost.process_index() in slice1: + multihost.sync_global_processes( + 'test_different_barriers_slice1', + processes=slice1, + ) + else: + multihost.sync_global_processes( + 'test_different_barriers_slice2', + processes=slice2, + ) + if multihost.process_index() in slice1: + self.assertTrue((self.tmpdir / 'dummy_0').exists()) + self.assertFalse((self.tmpdir / 'dummy_1').exists()) + else: + self.assertTrue((self.tmpdir / 'dummy_0').exists()) + self.assertTrue((self.tmpdir / 'dummy_1').exists()) + + def test_broadcast_one_to_all(self): + if multihost.process_index() == 0: + tree = {'bar': [5, 12]} + else: + tree = {'bar': [0, 0]} + result = multihost.broadcast_one_to_all(tree) + + expected = { + 'bar': [np.asarray(5, dtype=np.int32), np.asarray(12, dtype=np.int32)] + } + test_utils.assert_tree_equal(self, expected, result) + + + def test_sync_global_processes_with_distributed_barrier(self): + with flagsaver.flagsaver( + experimental_orbax_use_distributed_barrier=True + ), mock.patch.object( + multihost.multihost_utils, 'sync_global_devices', autospec=True + ) as mock_sync_global_devices, mock.patch.object( + multihost, 'get_barrier_sync_fn', autospec=True + ) as mock_get_barrier_sync_fn, mock.patch.object( + multihost, 'should_skip_process_sync', return_value=False + ): + multihost.sync_global_processes('test_barrier') + + mock_sync_global_devices.assert_not_called() + mock_get_barrier_sync_fn.assert_called_once_with(processes=None) + 
mock_get_barrier_sync_fn.return_value.assert_called_once_with( + key='test_barrier', timeout_ms=300000 + ) + + def test_sync_global_processes_without_distributed_barrier(self): + with flagsaver.flagsaver( + experimental_orbax_use_distributed_barrier=False + ), mock.patch.object( + multihost.multihost_utils, 'sync_global_devices', autospec=True + ) as mock_sync_global_devices, mock.patch.object( + multihost, 'get_barrier_sync_fn', autospec=True + ) as mock_get_barrier_sync_fn, mock.patch.object( + multihost, 'should_skip_process_sync', return_value=False + ): + multihost.sync_global_processes('test_barrier') + + mock_sync_global_devices.assert_called_once() + mock_get_barrier_sync_fn.assert_not_called() + + +class MultihostUtilsTestStandard(MultihostUtilsTestBase.Test): + + def setUp(self): + self.enter_context( + flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=False) + ) + super().setUp() + + def test_sync_global_processes_partial(self): + self.skipTest('Fix this scenario.') + + def test_different_barriers(self): + self.skipTest('Fix this scenario.') + + +class MultihostUtilsTestDistributedId(MultihostUtilsTestBase.Test): + + def setUp(self): + self.enter_context( + flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=True) + ) + super().setUp() + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/local_type_handlers_test.py b/checkpoint/orbax/checkpoint/_src/serialization/local_type_handlers_test.py new file mode 100644 index 000000000..9a84815ba --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/serialization/local_type_handlers_test.py @@ -0,0 +1,214 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +import unittest + +from absl import flags +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.serialization import ocdbt_utils +from orbax.checkpoint._src.serialization import serialization +from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils +from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.testing import multiprocess_test +from orbax.checkpoint.testing import local_path as local_path_test_lib +import tensorstore as ts + + +mock = unittest.mock +PyTree = Any +ParamInfo = type_handlers.ParamInfo +ArrayRestoreArgs = type_handlers.ArrayRestoreArgs +PLACEHOLDER = type_handlers.PLACEHOLDER + + +FLAGS = flags.FLAGS + +jax.config.update('jax_enable_x64', True) + + +class LocalTypeHandlersTest( + unittest.IsolatedAsyncioTestCase, + parameterized.TestCase, +): + """Captures aspects of serialization relevant to type_handlers.""" + + def setUp(self): + super().setUp() + self.base_directory = local_path_test_lib.create_local_path_base(self) + + test_utils.set_tensorstore_driver_for_test() + self.validate_topology() + + test_utils.sync_global_processes('LocalTypeHandlersTest:setup_complete') + + def tearDown(self): + test_utils.sync_global_processes('LocalTypeHandlersTest:tests_complete') + super().tearDown() + + def validate_topology(self): + self.assertEqual(jax.device_count(), 8) + 
self.assertGreater(jax.process_count(), 1) + + def get_array_handler(self): + return type_handlers.ArrayHandler( + primary_host=None, replica_id=None, use_replica_parallel=False + ) + + @property + def local_directory(self) -> epath.Path: + return local_path_test_lib.LocalPath(self.base_directory) + + def validate_paths(self): + # Array files should not exist at the global path level. + self.assertFalse((self.base_directory / 'manifest.ocdbt').exists()) + self.assertTrue(self.local_directory.exists()) + + def get_param_info( + self, + name: str, + path: epath.Path, + is_ocdbt: bool | None = False, + ts_context: ts.Context | None = None, + raise_array_data_missing_error: bool = True, + ) -> ParamInfo: + return ParamInfo( + name=name, + parent_dir=path, + is_ocdbt_checkpoint=is_ocdbt, + ts_context=ts_context, + raise_array_data_missing_error=raise_array_data_missing_error, + ) + + async def finalize_save( + self, *, ts_context: ts.Context, use_zarr3: bool, use_ocdbt: bool + ): + if use_ocdbt: + await ocdbt_utils.merge_ocdbt_per_process_files( + self.local_directory, ts_context=ts_context, use_zarr3=use_zarr3 + ) + test_utils.sync_global_processes( + 'local_serialization:merge_ocdbt_complete' + ) + + @parameterized.product( + use_ocdbt=(True, False), + use_zarr3=(True, False), + ) + async def test_local_serialization(self, use_ocdbt, use_zarr3): + handler = self.get_array_handler() + sharding = jax.sharding.NamedSharding( + mesh=jax.sharding.Mesh( + devices=np.asarray(jax.devices()), + axis_names=('x',), + ), + spec=jax.sharding.PartitionSpec('x'), + ) + # 8 shards, each of length 4. 
+ arr = jax.device_put(np.arange(32, dtype=np.int32), sharding) + ts_context = ts_utils.get_ts_context(use_ocdbt=use_ocdbt) + info = self.get_param_info( + 'a', + self.local_directory, + is_ocdbt=use_ocdbt, + ts_context=ts_context, + ) + futures = await handler.serialize([arr], [info]) + for f in futures: + f.result() + test_utils.sync_global_processes('test_array_serialization:serialized') + await self.finalize_save( + ts_context=ts_context, use_zarr3=use_zarr3, use_ocdbt=use_ocdbt + ) + + restore_arg = ArrayRestoreArgs( + global_shape=arr.shape, dtype=arr.dtype, sharding=sharding + ) + test_utils.print_directory(self.base_directory) + restored = await handler.deserialize([info], [restore_arg]) + test_utils.assert_array_equal(self, arr, restored[0]) + + @parameterized.product( + use_ocdbt=(True, False), + raise_array_data_missing_error=(True, False), + use_zarr3=(True, False), + ) + async def test_local_serialization_shuffled_devices( + self, use_ocdbt, raise_array_data_missing_error, use_zarr3 + ): + if multihost.is_pathways_backend(): + self.skipTest('Pathways does not support shuffling devices.') + handler = self.get_array_handler() + sharding = jax.sharding.NamedSharding( + mesh=jax.sharding.Mesh( + devices=np.asarray(jax.devices()), + axis_names=('x',), + ), + spec=jax.sharding.PartitionSpec('x'), + ) + # 8 shards, each of length 4. 
+ arr = jax.device_put(np.arange(32, dtype=np.int32), sharding) + zeros_arr = jax.device_put(np.zeros((32,), dtype=np.int32), sharding) + ts_context = ts_utils.get_ts_context(use_ocdbt=use_ocdbt) + info = self.get_param_info( + 'a', + self.local_directory, + is_ocdbt=use_ocdbt, + ts_context=ts_context, + raise_array_data_missing_error=raise_array_data_missing_error, + ) + futures = await handler.serialize([arr], [info]) + for f in futures: + f.result() + test_utils.sync_global_processes('test_array_serialization:serialized') + await self.finalize_save( + ts_context=ts_context, use_zarr3=use_zarr3, use_ocdbt=use_ocdbt + ) + + restore_arg = ArrayRestoreArgs( + global_shape=arr.shape, dtype=arr.dtype, sharding=sharding + ) + + orig_get_device_to_index_map = serialization._get_device_to_index_map + + def shuffled_get_device_to_index_map(global_shape, sharding): + device_to_index_map = orig_get_device_to_index_map(global_shape, sharding) + processes = [d.process_index for d in device_to_index_map.keys()] + assert processes == sorted(processes) + devices = list(device_to_index_map.keys()) + devices.reverse() + return dict(zip(devices, device_to_index_map.values())) + + with mock.patch.object( + serialization, + '_get_device_to_index_map', + new=shuffled_get_device_to_index_map, + ): + if raise_array_data_missing_error: + with self.assertRaisesRegex( + Exception, 'Encountered error while reading array index' + ): + await handler.deserialize([info], [restore_arg]) + else: + restored = await handler.deserialize([info], [restore_arg]) + test_utils.assert_array_equal(self, zeros_arr, restored[0]) + + +if __name__ == '__main__': + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/pathways_local_type_handlers_test.py b/checkpoint/orbax/checkpoint/_src/serialization/pathways_local_type_handlers_test.py new file mode 100644 index 000000000..b98d6efe0 --- /dev/null +++ 
b/checkpoint/orbax/checkpoint/_src/serialization/pathways_local_type_handlers_test.py @@ -0,0 +1,113 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for Pathways local type handlers.""" + +from typing import Any, Sequence + +from absl import flags +import jax +from orbax.checkpoint import test_utils +from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib +from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.multihost import pathways as multihost_pathways +from orbax.checkpoint._src.serialization import local_type_handlers_test +from orbax.checkpoint._src.serialization import ocdbt_utils +from orbax.checkpoint._src.serialization import pathways_handler_registry +from orbax.checkpoint._src.serialization import type_handler_registry +from orbax.checkpoint._src.testing import multiprocess_test +import tensorstore as ts + + +USE_COLOCATED_PYTHON = flags.DEFINE_boolean( + 'use_colocated_python', + False, + 'Whether to use colocated Python.', +) + + +class FakeArrayMetadataStore(array_metadata_store_lib.Store): + """A fake in-memory store that mimics the real checkpoint store API.""" + + def __init__(self): + self._data = {} + + async def write( + self, + checkpoint_dir: Any, + array_metadatas: Sequence[Any], + process_index: int, + ) -> None: + """Simulates writing metadata to storage.""" + if checkpoint_dir not in self._data: + self._data[checkpoint_dir] = [] 
+ self._data[checkpoint_dir].extend(array_metadatas) + + async def read( + self, checkpoint_dir: Any, process_index: int | None = None + ) -> Any: + """Simulates reading metadata from storage.""" + return {0: self._data.get(checkpoint_dir, [])} + + +class PathwaysLocalTypeHandlersTest( + local_type_handlers_test.LocalTypeHandlersTest, +): + + def setUp(self): + super().setUp() + self.assertTrue(multihost.is_pathways_backend()) + + def validate_topology(self): + self.assertEqual(jax.device_count(), 8) + self.assertGreater(multihost_pathways.worker_count(None), 1) + + def get_array_handler(self): + pathways_handler_registry.register_pathways_handlers( + checkpointing_impl=pathways_handler_registry.CheckpointingImpl.from_options( + use_colocated_python=USE_COLOCATED_PYTHON.value, + use_remote_python=True, # Fallback + ), + primary_host=None, + replica_id=None, + use_replica_parallel=False, + thinmint_testing=True, + array_metadata_store=FakeArrayMetadataStore(), + ) + return type_handler_registry.get_type_handler(jax.Array) + + def validate_paths(self): + # Array files should not exist at the global path level. 
+ self.assertFalse((self.base_directory / 'manifest.ocdbt').exists()) + for worker_id in range(multihost_pathways.worker_count(None)): + self.assertTrue((self.base_directory / f'local_{worker_id}').exists()) + + async def finalize_save( + self, *, ts_context: ts.Context, use_zarr3: bool, use_ocdbt: bool + ): + if use_ocdbt: + for worker_id in range(multihost_pathways.worker_count(None)): + await ocdbt_utils.merge_ocdbt_per_process_files( + self.base_directory / f'local_{worker_id}', + ts_context=ts_context, + use_zarr3=use_zarr3, + ) + test_utils.sync_global_processes( + 'local_serialization:merge_ocdbt_complete' + ) + + +if __name__ == '__main__': + jax.config.parse_flags_with_absl() + multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/_src/serialization/pathways_memory_usage_test.py b/checkpoint/orbax/checkpoint/_src/serialization/pathways_memory_usage_test.py new file mode 100644 index 000000000..7c5e518fd --- /dev/null +++ b/checkpoint/orbax/checkpoint/_src/serialization/pathways_memory_usage_test.py @@ -0,0 +1,396 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import time +from unittest import mock + +from absl import flags +from absl.testing import parameterized +from etils import epath +import jax +import numpy as np +from orbax.checkpoint import test_utils +from orbax.checkpoint import utils +from orbax.checkpoint._src.handlers import pytree_checkpoint_handler +from orbax.checkpoint._src.multihost import dispatchers +from orbax.checkpoint._src.serialization import jax_array_handlers +from orbax.checkpoint._src.serialization import pathways_handler_registry +from orbax.checkpoint._src.serialization import type_handler_registry +from orbax.checkpoint._src.serialization import types as serialization_types +from orbax.checkpoint._src.serialization import worker_memory_utils +from orbax.checkpoint._src.tree import utils as tree_utils + +from .learning.deepmind.jax.ocean.remote_python import rp +from .pyglib.contrib.g3_multiprocessing import g3_multiprocessing +from absl.testing import absltest + +USE_COLOCATED_PYTHON = flags.DEFINE_boolean( + 'use_colocated_python', + False, + 'Whether to use colocated Python.', +) + +FLAGS = flags.FLAGS +PyTreeCheckpointHandler = test_utils.PyTreeCheckpointHandler +PyTreeSaveArgs = pytree_checkpoint_handler.PyTreeSaveArgs +PyTreeRestoreArgs = pytree_checkpoint_handler.PyTreeRestoreArgs +ArrayRestoreArgs = pytree_checkpoint_handler.ArrayRestoreArgs +ParamInfo = serialization_types.ParamInfo +SaveArgs = serialization_types.SaveArgs + + +def _get_dispatcher(use_colocated_python: bool): + return ( + dispatchers.ColocatedPythonDispatcher() + if use_colocated_python + else dispatchers.RemotePythonDispatcher() + ) + + +def _get_actual_worker_memory_usage( + arr: jax.Array, use_colocated_python: bool +) -> dict[int, int]: + dispatcher = _get_dispatcher(use_colocated_python=use_colocated_python) + device_count = jax.device_count() + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec( + 'x', + ), + ) + + def 
_get_actual_worker_memory_usage_impl( + a: jax.Array, sharding: jax.sharding.Sharding, device_count: int + ) -> jax.Array: + bytes_size = a.itemsize * sum( + [shard.data.size for shard in a.addressable_shards] + ) + return jax.make_array_from_callback( + (device_count,), + sharding, + lambda _: np.array(bytes_size).reshape( + 1, + ), + dtype=np.int32, + ) + + result_specs = jax.ShapeDtypeStruct( + (device_count,), dtype=np.int32, sharding=sharding + ) + actual_worker_memory_usage_by_device = dispatcher.dispatch( + _get_actual_worker_memory_usage_impl, + input_arrays=arr, + result_specs=result_specs, + func_kwargs={ + 'sharding': sharding, + 'device_count': device_count, + }, + ) + jax.block_until_ready(actual_worker_memory_usage_by_device) + + device_to_worker_ids = worker_memory_utils._device_to_worker_ids(dispatcher) + actual_worker_memory_usage = {} + for shard in actual_worker_memory_usage_by_device.addressable_shards: + worker_id = device_to_worker_ids[shard.device.id] + memory_usage = np.asarray(shard.data) + assert memory_usage.shape == (1,) + memory_usage = int(memory_usage[0]) + if worker_id in actual_worker_memory_usage: + assert actual_worker_memory_usage[worker_id] == memory_usage + else: + actual_worker_memory_usage[worker_id] = memory_usage + + return actual_worker_memory_usage + + +def _create_array( + array_shape: tuple[int, ...], + mesh_shape: tuple[int, ...], + mesh_axes: tuple[str, ...], + partition_axes: tuple[str | None, ...], +) -> jax.Array: + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh( + np.asarray(jax.devices()).reshape(mesh_shape), mesh_axes + ), + jax.sharding.PartitionSpec(*partition_axes), + ) + return jax.device_put( + np.arange(np.prod(array_shape), dtype=np.float32).reshape(array_shape), + device=sharding, + ) + + +class PathwaysMemoryUsageTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self._use_colocated_python = USE_COLOCATED_PYTHON.value + 
pathways_handler_registry.register_pathways_handlers( + checkpointing_impl=pathways_handler_registry.CheckpointingImpl.from_options( + use_colocated_python=self._use_colocated_python, + use_remote_python=True, # Fallback + ), + thinmint_testing=True, + ) + PyTreeCheckpointHandler() + self.assertTrue(utils.is_pathways_backend()) + self.assertTrue(rp.available()) + self.assertIsInstance( + type_handler_registry.get_type_handler(jax.Array), + jax_array_handlers.ArrayHandler, + ) + + self.directory = epath.Path( + self.create_tempdir(name='checkpointing_test').full_path + ) + test_utils.set_tensorstore_driver_for_test() + + test_utils.sync_global_processes('PathwaysMemoryUsageTest:setup_complete') + + def tearDown(self): + test_utils.sync_global_processes('PathwaysMemoryUsageTest:tests_complete') + super().tearDown() + + @parameterized.parameters( + ((8,), ('x',), (None,), (1,)), + ((8,), ('x',), (None,), (64,)), + ((8,), ('x',), ('x',), (64,)), + ((4, 2), ('x', 'y'), (None, None), (64, 16)), + ((4, 2), ('x', 'y'), ('x', None), (64, 16)), + ((4, 2), ('x', 'y'), ('x', 'y'), (64, 16)), + ) + def test_worker_memory_usage_calculation( + self, + mesh_shape, + mesh_axes, + partition_axes, + array_shape, + ): + device_to_worker_ids = worker_memory_utils._device_to_worker_ids( + _get_dispatcher(use_colocated_python=self._use_colocated_python) + ) + arr = _create_array(array_shape, mesh_shape, mesh_axes, partition_axes) + actual_worker_memory_usage = _get_actual_worker_memory_usage( + arr, self._use_colocated_python + ) + estimated_worker_memory_usage = ( + worker_memory_utils._estimate_worker_memory_usage( + arr, replica_id=None, device_to_worker_ids_map=device_to_worker_ids + ) + ) + self.assertDictEqual( + actual_worker_memory_usage, estimated_worker_memory_usage + ) + + @parameterized.parameters( + (1, 200, ('x', 'y'), [1]), + (1, 300, ('x', 'y'), [1]), + (2, 200, ('x', 'y'), [1, 1]), + (2, 300, ('x', 'y'), [1, 1]), + (2, 513, ('x', 'y'), [2]), + (3, 513, ('x', 'y'), [2, 
1]), + (2, 513, (None, None), [1, 1]), + ) + def test_batching( + self, + num_arrays, + device_host_max_bytes, + partition_axes, + expected_batch_sizes, + ): + mesh_shape = (4, 2) + mesh_axes = ('x', 'y') + # Array total size = 16 * 16 * 4 bytes = 1024 bytes. + # Shard size = 1024 / (8 devices) = 128 bytes (fully replicated) + # Per host, per array size = 128 * (2 workers_per_device) = 256 bytes. + array_shape = (16, 16) + + values = [ + _create_array(array_shape, mesh_shape, mesh_axes, partition_axes) + for _ in range(num_arrays) + ] + infos = [mock.Mock(spec=ParamInfo)() for _ in values] + args = [SaveArgs() for _ in values] + + batch_idx = 0 + for batch in worker_memory_utils.next_memory_budgeted_batch( + list(zip(values, infos, args)), + device_host_max_bytes, + replica_id=0, + dispatcher=_get_dispatcher( + use_colocated_python=self._use_colocated_python + ), + ): + self.assertLen(batch, expected_batch_sizes[batch_idx]) + batch_idx += 1 + self.assertLen(expected_batch_sizes, batch_idx) + + def test_save_restore(self): + arrays = [] + arr_size = 2**26 + num_arrays = 10 + for _ in range(num_arrays): + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec( + 'x', + ), + ) + # ~268 MB per array + arr = jax.device_put( + np.arange(arr_size, dtype=np.float32), device=sharding + ) + arrays.append(arr) + handler = PyTreeCheckpointHandler( + use_ocdbt=False, is_prioritized_key_fn=lambda _: False + ) + arr_bytes = arrays[0].itemsize * arr_size + unique_shards = jax.device_count() + shards_per_worker = 2 + arrays_per_batch = 4 + handler._handler_impl._save_device_host_concurrent_bytes = ( + arr_bytes // unique_shards * shards_per_worker * arrays_per_batch + 1000 + ) + handler.save(self.directory, args=PyTreeSaveArgs(arrays)) + + # Verify that individual param mtimes increase with each successive + # batch. This verifies that batch `i` completes before batch `i+1` starts. 
+ param_mtimes = [0] * num_arrays + for param_dir in self.directory.iterdir(): + if param_dir.is_dir() and param_dir.name.isdigit(): + mtime = param_dir.stat().mtime + param_mtimes[int(param_dir.name)] = mtime + prev_greatest_mtime = -1 + for i in range(0, num_arrays, arrays_per_batch): + cur_greatest_mtime = max(param_mtimes[i : i + arrays_per_batch]) + self.assertGreaterEqual(cur_greatest_mtime, prev_greatest_mtime) + prev_greatest_mtime = cur_greatest_mtime + + # Verify restore correctness. + restore_args = jax.tree.map( + lambda x: ArrayRestoreArgs(sharding=x.sharding), arrays + ) + restored = handler.restore( + self.directory, args=PyTreeRestoreArgs(restore_args=restore_args) + ) + test_utils.assert_tree_equal(self, arrays, restored) + handler.close() + + def test_save_restore_no_memory_limiting(self): + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec( + 'x', + ), + ) + arr_size = 256 + arrays = [ + jax.device_put(np.arange(arr_size, dtype=np.float32), device=sharding) + ] + + handler = PyTreeCheckpointHandler() + original_get_deprioritized_batches_to_serialize = ( + jax_array_handlers._get_deprioritized_batches_to_serialize + ) + with mock.patch.object( + jax_array_handlers, + '_get_deprioritized_batches_to_serialize', + wraps=original_get_deprioritized_batches_to_serialize, + ) as mock_get_deprioritized_batches_to_serialize: + handler.save(self.directory, args=PyTreeSaveArgs(arrays)) + + # Assert that _get_deprioritized_batches_to_serialize was not called + mock_get_deprioritized_batches_to_serialize.assert_not_called() + + # Verify restore correctness. + restore_args = jax.tree.map( + lambda x: ArrayRestoreArgs(sharding=x.sharding), arrays + ) + restored = handler.restore( + self.directory, args=PyTreeRestoreArgs(restore_args=restore_args) + ) + test_utils.assert_tree_equal(self, arrays, restored) + handler.close() + + # TODO(cpgaffney): Test with an async D2H time that is artificially long. 
+ # Currently there is not a good way to guarantee this. + def test_save_restore_with_prioritized_params(self): + arrays = [] + arr_size = 2**26 + num_arrays = 10 + for _ in range(num_arrays): + sharding = jax.sharding.NamedSharding( + jax.sharding.Mesh(jax.devices(), 'x'), + jax.sharding.PartitionSpec( + 'x', + ), + ) + # ~268 MB per array + arr = jax.device_put( + np.arange(arr_size, dtype=np.float32), device=sharding + ) + arrays.append(arr) + prioritized_keys_for_saving = [ + ('0',), + ('2',), + ('4',), + ('6',), + ('8',), + ] + handler = PyTreeCheckpointHandler( + use_ocdbt=False, + is_prioritized_key_fn=lambda key: tree_utils.str_keypath(key) + in prioritized_keys_for_saving, + ) + arr_bytes = arrays[0].itemsize * arr_size + unique_shards = jax.device_count() + shards_per_worker = 2 + arrays_per_batch = 4 + handler._handler_impl._save_device_host_concurrent_bytes = ( + arr_bytes // unique_shards * shards_per_worker * arrays_per_batch + 1000 + ) + start = time.time() + handler.save(self.directory, args=PyTreeSaveArgs(arrays)) + end = time.time() + logging.info('Time taken: %s seconds', end - start) + + # Verify all even params complete before odd params are started. + param_ctimes = [0] * num_arrays + param_mtimes = [0] * num_arrays + for param_dir in self.directory.iterdir(): + if param_dir.is_dir() and param_dir.name.isdigit(): + ctime = os.stat(param_dir).st_ctime + mtime = os.stat(param_dir).st_mtime + param_mtimes[int(param_dir.name)] = mtime + param_ctimes[int(param_dir.name)] = ctime + + self.assertLessEqual(max(param_mtimes[::2]), min(param_ctimes[1::2])) + + # Verify restore correctness. 
+ restore_args = jax.tree.map( + lambda x: ArrayRestoreArgs(sharding=x.sharding), arrays + ) + restored = handler.restore( + self.directory, args=PyTreeRestoreArgs(restore_args=restore_args) + ) + test_utils.assert_tree_equal(self, arrays, restored) + handler.close() + + +if __name__ == '__main__': + jax.config.parse_flags_with_absl() + g3_multiprocessing.handle_test_main(absltest.main) diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py b/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py index b1275067f..94e0acfaf 100644 --- a/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/generate_multiprocess_test.py @@ -15,6 +15,7 @@ """Script to generate YAML file with test targets based on tags.""" import ast +import collections import os import sys @@ -29,7 +30,11 @@ 'pytype_strict_contrib_test', ] EXCLUDED_PATHS = [ - 'orbax/checkpoint/experimental', + 'orbax/checkpoint/experimental/model_surgery', + 'orbax/checkpoint/experimental/v1', + 'orbax/checkpoint/experimental/emergency/p2p', + 'orbax/checkpoint/experimental/emergency/checkpoint_manager_test.py', + 'orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py', 'orbax/checkpoint/google', ] @@ -60,53 +65,17 @@ def get_str_val(node): return None -def inherits_from_multiprocess_test(test_file_path): - """Checks if test file inherits from MultiProcessTest.""" - try: - with open(test_file_path, 'r') as f: - content = f.read() - except FileNotFoundError: - return False - try: - tree = ast.parse(content, filename=test_file_path) - except SyntaxError: - return False +def get_num_processes(args): + """Returns num_processes from args.""" + for arg in args: + if arg.startswith('--num_processes='): + try: + return int(arg.split('=', 1)[1]) + except ValueError: + return None + return None - imported_as_name = None # if imported as `from ... 
import MultiProcessTest` - imported_as_module = [] # if imported as `from ... import multiprocess_test` - for node in tree.body: - if isinstance(node, ast.ImportFrom): - if node.module == 'orbax.checkpoint._src.testing.multiprocess_test': - for alias in node.names: - if alias.name == 'MultiProcessTest': - imported_as_name = alias.asname or alias.name - elif node.module == 'orbax.checkpoint._src.testing': - for alias in node.names: - if alias.name == 'multiprocess_test': - imported_as_module.append(alias.asname or alias.name) - - if not imported_as_name and not imported_as_module: - return False - - for node in tree.body: - if isinstance(node, ast.ClassDef): - for base in node.bases: - if ( - imported_as_name - and isinstance(base, ast.Name) - and base.id == imported_as_name - ): - return True - if ( - imported_as_module - and isinstance(base, ast.Attribute) - and isinstance(base.value, ast.Name) - and base.value.id in imported_as_module - and base.attr == 'MultiProcessTest' - ): - return True - return False def get_build_targets(build_file_path): @@ -135,18 +104,18 @@ def get_build_targets(build_file_path): if rule_name in TEST_RULES: kwargs = get_kwargs(call) - if 'name' in kwargs and 'tags' in kwargs: + if 'name' in kwargs: name = get_str_val(kwargs['name']) - tags = get_list_val(kwargs['tags']) + tags = get_list_val(kwargs['tags']) if 'tags' in kwargs else [] srcs = get_list_val(kwargs['srcs']) if 'srcs' in kwargs else [] - if name and tags: - yield name, tags, srcs + args = get_list_val(kwargs['args']) if 'args' in kwargs else [] + if name: + yield name, tags, srcs, args def run(root_dir, output_file): """Runs the script to generate tagged tests file.""" - tests_by_tag = {tag: [] for tag in TAG_MAPPING.values()} - tests_by_tag['processes:1'] = [] + tests_by_tag = collections.defaultdict(list) count = 0 for dirpath, dirnames, filenames in os.walk(root_dir): @@ -166,35 +135,33 @@ def run(root_dir, output_file): count += 1 build_file = os.path.join(dirpath, 
'BUILD') package_path = dirpath.removeprefix('third_party/py/') - for name, tags, srcs in get_build_targets(build_file): + for name, tags, srcs, args in get_build_targets(build_file): + if not any(tag in TAG_MAPPING for tag in tags): + continue if srcs and any( os.path.join(dirpath, srcs[0]).startswith(p) for p in EXCLUDED_PATHS ): continue - is_multiprocess = False - if srcs: - is_multiprocess = inherits_from_multiprocess_test( - os.path.join(dirpath, srcs[0]) - ) target_path = f'{package_path}:{name}' - if not is_multiprocess: - tests_by_tag['processes:1'].append(target_path) + num_processes = get_num_processes(args) + if num_processes and num_processes > 1: + tag = f'processes:{num_processes}' + tests_by_tag[tag].append(target_path) else: - for tag in tags: - if tag in TAG_MAPPING: - tests_by_tag[TAG_MAPPING[tag]].append(target_path) + tests_by_tag['processes:1'].append(target_path) print(f'Processed {count} BUILD files.') + result_dict = {} for tag in tests_by_tag: - tests_by_tag[tag] = sorted(list(set(tests_by_tag[tag]))) + result_dict[tag] = sorted(list(set(tests_by_tag[tag]))) header = """# DO NOT EDIT! 
""" os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: f.write(header) - yaml.dump(tests_by_tag, f, default_flow_style=False) + yaml.dump(result_dict, f, default_flow_style=False) print(f'Output written to {output_file}') diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py b/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py index 8e66f33f0..e2fba6ebe 100755 --- a/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/run_tests.py @@ -36,6 +36,11 @@ None, 'Number of processes to select test list from yaml file.', ) +flags.DEFINE_integer( + 'pathways', + None, + 'Number of pathways to select test list from yaml file.', +) def install_deps(): @@ -84,6 +89,13 @@ def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') + if FLAGS.processes is None and FLAGS.pathways is None: + raise app.UsageError('Either --processes or --pathways must be specified.') + if FLAGS.processes is not None and FLAGS.pathways is not None: + raise app.UsageError( + 'Only one of --processes or --pathways can be specified.' + ) + install_deps() try: @@ -97,13 +109,15 @@ def main(argv: Sequence[str]) -> None: logging.error('YAML file not found: %s', FLAGS.filename) sys.exit(1) - key = f'processes:{FLAGS.processes}' + if FLAGS.processes is not None: + key = f'processes:{FLAGS.processes}' + else: + key = f'pathways:{FLAGS.pathways}' + if key not in tests_by_process_count: logging.error( - 'key=%s (from processes=%d) not found as a key in %s. Available' - ' keys: %s', + 'key=%s not found as a key in %s. 
Available keys: %s', key, - FLAGS.processes, FLAGS.filename, list(tests_by_process_count.keys()), ) @@ -112,8 +126,8 @@ def main(argv: Sequence[str]) -> None: test_files = tests_by_process_count[key] if not test_files: logging.warning( - 'No test files found for processes=%d in %s.', - FLAGS.processes, + 'No test files found for key=%s in %s.', + key, FLAGS.filename, ) return @@ -157,5 +171,4 @@ def main(argv: Sequence[str]) -> None: if __name__ == '__main__': flags.mark_flag_as_required('filename') - flags.mark_flag_as_required('processes') app.run(main) diff --git a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml index c82405d3b..4f04cd371 100755 --- a/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml +++ b/checkpoint/orbax/checkpoint/_src/testing/oss/tagged_tests.yaml @@ -1,27 +1,31 @@ # DO NOT EDIT! -processes:1: -- orbax/checkpoint/_src/metadata:sharding_tpu_test -- orbax/checkpoint/_src/multihost:multihost_test -- orbax/checkpoint/_src/multihost:multislice_ghostfish_test -- orbax/checkpoint/_src/multihost:multislice_test +pathways:1: - orbax/checkpoint/_src/serialization:colocated_pathways_local_type_handlers_test - orbax/checkpoint/_src/serialization:colocated_pathways_memory_usage_test -- orbax/checkpoint/_src/serialization:local_type_handlers_test - orbax/checkpoint/_src/serialization:pathways_local_type_handlers_test - orbax/checkpoint/_src/serialization:pathways_memory_usage_test +processes:1: +- orbax/checkpoint/_src/futures:future_test +- orbax/checkpoint/_src/metadata:sharding_tpu_test +- orbax/checkpoint/_src/multihost:multislice_test - orbax/checkpoint/_src/serialization:replica_slices_test - orbax/checkpoint/_src/serialization:serialization_test - orbax/checkpoint:single_host_test -processes:4: +processes:2: - orbax/checkpoint/_src/checkpointers:async_checkpointer_test - orbax/checkpoint/_src/checkpointers:checkpointer_test -- 
orbax/checkpoint/_src/futures:future_test - orbax/checkpoint/_src/handlers:array_checkpoint_handler_test - orbax/checkpoint/_src/handlers:pytree_checkpoint_handler_test - orbax/checkpoint/_src/handlers:standard_checkpoint_handler_test +- orbax/checkpoint/_src/serialization:local_type_handlers_test - orbax/checkpoint/_src/serialization:type_handlers_test +processes:4: +- orbax/checkpoint/_src/multihost:multihost_test - orbax/checkpoint/_src/testing/tree_verity:checkpoint_manager_test +- orbax/checkpoint/experimental/emergency:local_checkpoint_data_debugging_test +- orbax/checkpoint/experimental/emergency:local_checkpoint_manager_test +- orbax/checkpoint/experimental/emergency:process_metadata_checkpoint_handler_test +- orbax/checkpoint/experimental/emergency:single_slice_checkpoint_manager_test - orbax/checkpoint/testing:local_path_test - orbax/checkpoint:checkpoint_manager_slice_test - orbax/checkpoint:checkpoint_manager_test -processes:8: [] diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py deleted file mode 100644 index 223066301..000000000 --- a/checkpoint/orbax/checkpoint/experimental/emergency/local_checkpoint_data_debugging_test.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2026 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest -from absl.testing import flagsaver -from absl.testing import parameterized -from etils import epath -import jax -import numpy as np -from orbax.checkpoint import test_utils -from orbax.checkpoint._src.arrays import types -from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint._src.serialization import serialization -from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils -from orbax.checkpoint._src.testing import multiprocess_test -from orbax.checkpoint.experimental.emergency import local_checkpoint_data_debugging - - -Index = types.Index -Shape = types.Shape -ChunkId = tuple[int, ...] - -index_to_chunk_id = local_checkpoint_data_debugging.index_to_chunk_id -get_chunk_ids_from_tensorstore = ( - local_checkpoint_data_debugging.get_chunk_ids_from_tensorstore -) -open_tensorstore = local_checkpoint_data_debugging.open_tensorstore -get_present_and_missing_chunks = ( - local_checkpoint_data_debugging.get_present_and_missing_chunks -) - - -class LocalCheckpointDataValidatorTest( - unittest.IsolatedAsyncioTestCase, - parameterized.TestCase, - multiprocess_test.MultiProcessTest, -): - - def make_global_mesh(self) -> jax.sharding.Mesh: - self.assertEqual(jax.device_count(), 8) - self.assertEqual(jax.process_count(), 4) - self.assertEqual(jax.local_device_count(), 2) - return jax.sharding.Mesh(jax.devices(), ('data',)) - - def setUp(self): - super().setUp() - self.enter_context( - flagsaver.flagsaver(experimental_orbax_use_distributed_process_id=True) - ) - if not multihost.is_runtime_to_distributed_ids_initialized(): - multihost.initialize_runtime_to_distributed_ids() - - self.global_mesh = self.make_global_mesh() - - # make sure each process is working on different directories - self.local_directory = epath.Path( - self.create_tempdir( - name=self._local_directory_for_process(multihost.process_index()) - ).full_path - ) - test_utils.set_tensorstore_driver_for_test() - 
test_utils.sync_global_processes('CheckpointManagerTest:setup_complete') - - def tearDown(self): - super().tearDown() - test_utils.sync_global_processes('CheckpointManagerTest:teardown_complete') - - def _local_directory_for_process(self, process_index: int) -> epath.Path: - return f'local_checkpointing_test_pid_{process_index}' - - async def _write_array( - self, - array: jax.Array, - param_name: str, - *, - use_ocdbt: bool, - use_zarr3: bool, - ): - tspec = ts_utils.ArrayWriteSpec( - self.local_directory.as_posix(), - param_name, - global_shape=array.shape, - write_shape=array.sharding.shard_shape(array.shape), - dtype=array.dtype, - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - ).json - replica_id = array.addressable_shards[0].replica_id - await serialization.async_serialize( - array, - tspec, - context=ts_utils.get_ts_context(use_ocdbt=use_ocdbt), - primary_host=None, - replica_id=replica_id, - ) - - @parameterized.product(use_ocdbt=[False, True], use_zarr3=[False, True]) - async def test_main(self, use_ocdbt: bool, use_zarr3: bool): - self.assertEqual(multihost.process_count(), 4) - param_name = 'array' - array = test_utils.create_sharded_array( - np.arange(16), - self.global_mesh, - jax.sharding.PartitionSpec('data'), - ) - - await self._write_array( - array, param_name, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3 - ) - test_utils.sync_global_processes('sync_after_write_array') - - # Rearrange two local directories to simulate restart. 
- if multihost.process_index() == 0: - process_0_directory = self.local_directory - process_1_directory = ( - self.local_directory.parent / self._local_directory_for_process(1) - ) - tmp_directory = self.local_directory.parent / 'tmp' - process_0_directory.rename(tmp_directory) - process_1_directory.rename(process_0_directory) - tmp_directory.rename(process_1_directory) - test_utils.sync_global_processes('sync_after_local_dir_rearrange') - - present_chunk_ids, missing_chunk_ids = await get_present_and_missing_chunks( - self.local_directory, - param_name, - array, - use_ocdbt=use_ocdbt, - use_zarr3=use_zarr3, - ) - - if multihost.process_index() == 0: - self.assertSameElements(present_chunk_ids, ((4,), (5,))) - self.assertSameElements(missing_chunk_ids, ((0,), (1,))) - elif multihost.process_index() == 1: - self.assertSameElements(present_chunk_ids, ((0,), (1,))) - self.assertSameElements(missing_chunk_ids, ((4,), (5,))) - elif multihost.process_index() == 2: - self.assertSameElements(present_chunk_ids, ((2,), (3,))) - self.assertEmpty(missing_chunk_ids) - elif multihost.process_index() == 3: - self.assertSameElements(present_chunk_ids, ((6,), (7,))) - self.assertEmpty(missing_chunk_ids) - - -if __name__ == '__main__': - multiprocess_test.main() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py index 122b93d06..c2660ef5f 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/replicator_checkpoint_manager_test.py @@ -31,12 +31,12 @@ from orbax.checkpoint._src.multihost import multislice from orbax.checkpoint._src.path import gcs_utils from orbax.checkpoint._src.serialization import type_handlers +from orbax.checkpoint._src.testing import multiprocess_test from orbax.checkpoint.experimental.emergency import mesh_consistency 
from orbax.checkpoint.experimental.emergency import replicator_checkpoint_manager from orbax.checkpoint.experimental.emergency.test_utils import dataset_iterator_checkpoint_handler from orbax.checkpoint.experimental.emergency.test_utils import test_base as emergency_test_utils from orbax.checkpoint.path import atomicity -from .learning.brain.research.jax.tests.multiprocess import multiprocess_test PyTree = Any