From 2c8593c7ac6e183f9208be1940dac123373bf9bb Mon Sep 17 00:00:00 2001 From: Pathways-on-Cloud Team Date: Sun, 15 Mar 2026 20:26:36 -0700 Subject: [PATCH] Use abstract types for return type hints in pathwaysutils. This change replaces concrete types like `dict`, `set`, and `list` with their abstract counterparts `Mapping`, `Set`, and `Sequence` in function signatures and class attributes across `pathwaysutils`. This improves type hint flexibility and adheres to Python best practices. PiperOrigin-RevId: 884182804 --- pathwaysutils/elastic/elastic.py | 8 ++++---- pathwaysutils/elastic/manager.py | 4 ++-- pathwaysutils/experimental/profiling.py | 5 +++-- pathwaysutils/experimental/reshard.py | 21 ++++++++++++--------- pathwaysutils/persistence/helper.py | 8 ++++---- pathwaysutils/persistence/orbax_handler.py | 4 ++-- pathwaysutils/profiling.py | 12 ++++++------ 7 files changed, 33 insertions(+), 29 deletions(-) diff --git a/pathwaysutils/elastic/elastic.py b/pathwaysutils/elastic/elastic.py index 81a42f6..9eeb4c6 100644 --- a/pathwaysutils/elastic/elastic.py +++ b/pathwaysutils/elastic/elastic.py @@ -19,7 +19,7 @@ """ import collections -from collections.abc import Mapping, Sequence +from collections.abc import Mapping, Sequence, Set import logging import time @@ -83,7 +83,7 @@ def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array: def get_slice_to_devices( devices: Sequence[jax.Device], -) -> dict[int, Sequence[jax.Device]]: +) -> Mapping[int, Sequence[jax.Device]]: """Returns the mapping from slice index to devices.""" slice_to_devices = collections.defaultdict(list) for d in devices: @@ -94,7 +94,7 @@ def get_slice_to_devices( @timing.timeit def get_active_slice_indices( slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, -) -> set[int]: +) -> Set[int]: """Returns the set of active slices indices. Args: @@ -153,7 +153,7 @@ def wait_for_slices( poll_interval: float | int = 10, timeout: float | int | None = None, slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None, -) -> set[int]: +) -> Set[int]: """Waits until after at least `slice_count` slices become active. Args: diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index c166f90..9fb017b 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -18,7 +18,7 @@ events. It also provides a utility for waiting for slices to become active. """ -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence, Set import functools import logging from typing import Any, TypeVar @@ -58,7 +58,7 @@ class Manager: _total_slice_count: int | None = None slice_to_devices: Mapping[int, Sequence[jax.Device]] - active_slice_indices: set[int] + active_slice_indices: Set[int] def __init__(self, devices: Sequence[jax.Device] | None = None) -> None: """Initializes the manager. diff --git a/pathwaysutils/experimental/profiling.py b/pathwaysutils/experimental/profiling.py index 9b4a71a..0816302 100644 --- a/pathwaysutils/experimental/profiling.py +++ b/pathwaysutils/experimental/profiling.py @@ -13,13 +13,14 @@ # limitations under the License. """Experimental profiling utilites.""" +from collections.abc import Mapping from typing import Any from pathwaysutils import profiling def start_trace( - profile_request: dict[str, Any], + profile_request: Mapping[str, Any], *, create_perfetto_link: bool = False, create_perfetto_trace: bool = False, @@ -33,7 +34,7 @@ def start_trace( Use `jax.profiler.stop_trace` to end profiling. Args: - profile_request: A dictionary containing the profile request options. + profile_request: A mapping containing the profile request options. create_perfetto_link: A boolean which, if true, creates and prints link to the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will block until the link is opened and Perfetto loads the trace. This feature diff --git a/pathwaysutils/experimental/reshard.py b/pathwaysutils/experimental/reshard.py index 9ba4f0b..cb58d50 100644 --- a/pathwaysutils/experimental/reshard.py +++ b/pathwaysutils/experimental/reshard.py @@ -54,9 +54,10 @@ def __init__( destination_shardings: Sequence[jax.sharding.Sharding], donate: bool, ): + def ifrt_hlo_sharding( aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding - ) -> dict[str, Any]: + ) -> Mapping[str, Any]: result = { "devices": { "device_ids": [ @@ -190,7 +191,9 @@ class NoIntermediateShardingNeededError(NoIntermediateShardingError): """Raised when no intermediate sharding is needed for optimization.""" -def _get_sharding_spec_dims(sharding: jax.sharding.NamedSharding) -> list[int]: +def _get_sharding_spec_dims( + sharding: jax.sharding.NamedSharding, +) -> Sequence[int]: """Gets the sharding dimension sizes from a NamedSharding.""" mesh = sharding.mesh dims = [] @@ -244,7 +247,7 @@ def _get_split_candidates( src_dims: Sequence[int], dst_dims: Sequence[int], gcd_shards: Sequence[int], -) -> list[tuple[int, str]]: +) -> Sequence[tuple[int, str]]: """Finds dimensions that are candidates for splitting.""" split_candidates = [] for i, spec in enumerate(in_sharding.spec): @@ -271,8 +274,8 @@ def _build_intermediate_mesh_and_spec( in_spec: jax.sharding.PartitionSpec, src_dims: Sequence[int], dst_dims: Sequence[int], - split_candidates: list[tuple[int, str]], -) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, list[str]]: + split_candidates: Sequence[tuple[int, str]], +) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, Sequence[str]]: """Builds the intermediate Mesh and PartitionSpec.""" # Build a map of mesh axis to split information: (dim_idx, replicas) mesh_axis_to_split_info = {} @@ -321,7 +324,7 @@ def _build_intermediate_mesh_and_spec( def find_intermediate_sharding( in_sharding: jax.sharding.Sharding, out_sharding: jax.sharding.Sharding -) -> tuple[jax.sharding.NamedSharding, list[str]]: +) -> tuple[jax.sharding.NamedSharding, Sequence[str]]: """Finds an intermediate sharding to reshard to before target sharding. This function tries to find an intermediate sharding that can be used to @@ -343,9 +346,9 @@ def find_intermediate_sharding( out_sharding: The target sharding. Returns: - A tuple containing: - - An intermediate sharding. - - A list of axis names that are replicated in the intermediate sharding. + A tuple (intermediate_sharding, replicated_axes), where + replicated_axes is a sequence of axis names that are replicated in the + intermediate sharding. Raises: NoIntermediateShardingError: If no intermediate sharding is found. diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 5d8b535..c9b24a1 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -14,7 +14,7 @@ """Helper functions for persistence.""" import base64 -from collections.abc import Sequence +from collections.abc import Mapping, Sequence import concurrent.futures import datetime import json @@ -94,7 +94,7 @@ def get_hlo_sharding_string( def get_shape_info( dtype: np.dtype, dimensions: Sequence[int], -) -> dict[str, Sequence[int] | str]: +) -> Mapping[str, Sequence[int] | str]: """Returns shape info in the format expected by read requests.""" return { "xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype), @@ -108,7 +108,7 @@ def get_write_request( jax_array: jax.Array, timeout: datetime.timedelta, return_dict: bool = False, -) -> str | dict[str, Any]: +) -> str | Mapping[str, Any]: """Returns a string representation of the plugin program which writes the given jax_array to the given location.""" sharding = jax_array.sharding assert isinstance(sharding, jax.sharding.Sharding), sharding @@ -172,7 +172,7 @@ def get_read_request( devices: Sequence[jax.Device], timeout: datetime.timedelta, return_dict: bool = False, -) -> str | dict[str, Any]: +) -> str | Mapping[str, Any]: """Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding.""" if not isinstance(devices, np.ndarray): devices = np.array(devices) diff --git a/pathwaysutils/persistence/orbax_handler.py b/pathwaysutils/persistence/orbax_handler.py index 539f834..0a992b9 100644 --- a/pathwaysutils/persistence/orbax_handler.py +++ b/pathwaysutils/persistence/orbax_handler.py @@ -96,7 +96,7 @@ async def serialize( values: Sequence[jax.Array], infos: Sequence[ParamInfo], args: Sequence[SaveArgs] | None = None, - ) -> list[future.Future]: + ) -> Sequence[future.Future]: """Uses Pathways Persistence API to serialize a jax array.""" type_handlers.check_input_arguments(values, infos, args) @@ -158,7 +158,7 @@ async def deserialize( self, infos: Sequence[ParamInfo], args: Sequence[RestoreArgs] | None = None, - ) -> list[jax.Array]: + ) -> Sequence[jax.Array]: """Uses Pathways Persistence API to deserialize a jax array.""" if args is None: raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.") diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index f918696..34b315e 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -19,7 +19,7 @@ import logging import os import threading -from typing import Any +from typing import Any, Mapping import urllib.parse import fastapi @@ -59,8 +59,8 @@ def toy_computation() -> None: def _create_profile_request( log_dir: os.PathLike[str] | str, -) -> dict[str, Any]: - """Creates a profile request dictionary from the given options.""" +) -> Mapping[str, Any]: + """Creates a profile request mapping from the given options.""" profile_request = {} profile_request["traceLocation"] = str(log_dir) @@ -68,14 +68,14 @@ def _create_profile_request( def _start_pathways_trace_from_profile_request( - profile_request: dict[str, Any], + profile_request: Mapping[str, Any], ) -> None: """Starts a profiler trace on Pathways components from a profile request. This will only profile the Pathways components and not the JAX client code. Args: - profile_request: A dictionary containing the profile request options. + profile_request: A mapping containing the profile request options. """ with _profile_state.lock: global _first_profile_start @@ -191,7 +191,7 @@ class ProfilingConfig: repository_path: str @app.post("/profiling") - async def profiling(pc: ProfilingConfig) -> dict[str, str]: + async def profiling(pc: ProfilingConfig) -> Mapping[str, str]: _logger.debug("Capturing profiling data for %s ms", pc.duration_ms) _logger.debug("Writing profiling data to %s", pc.repository_path) await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)