Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pathwaysutils/elastic/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""

import collections
from collections.abc import Mapping, Sequence
from collections.abc import Mapping, Sequence, Set
import logging
import time

Expand Down Expand Up @@ -83,7 +83,7 @@ def _simple_execution(devices: Sequence[jax.Device]) -> jax.Array:

def get_slice_to_devices(
devices: Sequence[jax.Device],
) -> dict[int, Sequence[jax.Device]]:
) -> Mapping[int, Sequence[jax.Device]]:
"""Returns the mapping from slice index to devices."""
slice_to_devices = collections.defaultdict(list)
for d in devices:
Expand All @@ -94,7 +94,7 @@ def get_slice_to_devices(
@timing.timeit
def get_active_slice_indices(
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
) -> set[int]:
) -> Set[int]:
"""Returns the set of active slice indices.

Args:
Expand Down Expand Up @@ -153,7 +153,7 @@ def wait_for_slices(
poll_interval: float | int = 10,
timeout: float | int | None = None,
slice_to_devices: Mapping[int, Sequence[jax.Device]] | None = None,
) -> set[int]:
) -> Set[int]:
"""Waits until at least `slice_count` slices become active.

Args:
Expand Down
4 changes: 2 additions & 2 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
events. It also provides a utility for waiting for slices to become active.
"""

from collections.abc import Callable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence, Set
import functools
import logging
from typing import Any, TypeVar
Expand Down Expand Up @@ -58,7 +58,7 @@ class Manager:

_total_slice_count: int | None = None
slice_to_devices: Mapping[int, Sequence[jax.Device]]
active_slice_indices: set[int]
active_slice_indices: Set[int]

def __init__(self, devices: Sequence[jax.Device] | None = None) -> None:
"""Initializes the manager.
Expand Down
5 changes: 3 additions & 2 deletions pathwaysutils/experimental/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.
"""Experimental profiling utilities."""

from collections.abc import Mapping
from typing import Any

from pathwaysutils import profiling


def start_trace(
profile_request: dict[str, Any],
profile_request: Mapping[str, Any],
*,
create_perfetto_link: bool = False,
create_perfetto_trace: bool = False,
Expand All @@ -33,7 +34,7 @@ def start_trace(
Use `jax.profiler.stop_trace` to end profiling.

Args:
profile_request: A dictionary containing the profile request options.
profile_request: A mapping containing the profile request options.
create_perfetto_link: A boolean which, if true, creates and prints link to
the Perfetto trace viewer UI (https://ui.perfetto.dev). The program will
block until the link is opened and Perfetto loads the trace. This feature
Expand Down
21 changes: 12 additions & 9 deletions pathwaysutils/experimental/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ def __init__(
destination_shardings: Sequence[jax.sharding.Sharding],
donate: bool,
):

def ifrt_hlo_sharding(
aval: jax.core.ShapedArray, sharding: jax.sharding.Sharding
) -> dict[str, Any]:
) -> Mapping[str, Any]:
result = {
"devices": {
"device_ids": [
Expand Down Expand Up @@ -190,7 +191,9 @@ class NoIntermediateShardingNeededError(NoIntermediateShardingError):
"""Raised when no intermediate sharding is needed for optimization."""


def _get_sharding_spec_dims(sharding: jax.sharding.NamedSharding) -> list[int]:
def _get_sharding_spec_dims(
sharding: jax.sharding.NamedSharding,
) -> Sequence[int]:
"""Gets the sharding dimension sizes from a NamedSharding."""
mesh = sharding.mesh
dims = []
Expand Down Expand Up @@ -244,7 +247,7 @@ def _get_split_candidates(
src_dims: Sequence[int],
dst_dims: Sequence[int],
gcd_shards: Sequence[int],
) -> list[tuple[int, str]]:
) -> Sequence[tuple[int, str]]:
"""Finds dimensions that are candidates for splitting."""
split_candidates = []
for i, spec in enumerate(in_sharding.spec):
Expand All @@ -271,8 +274,8 @@ def _build_intermediate_mesh_and_spec(
in_spec: jax.sharding.PartitionSpec,
src_dims: Sequence[int],
dst_dims: Sequence[int],
split_candidates: list[tuple[int, str]],
) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, list[str]]:
split_candidates: Sequence[tuple[int, str]],
) -> tuple[jax.sharding.Mesh, jax.sharding.PartitionSpec, Sequence[str]]:
"""Builds the intermediate Mesh and PartitionSpec."""
# Build a map of mesh axis to split information: (dim_idx, replicas)
mesh_axis_to_split_info = {}
Expand Down Expand Up @@ -321,7 +324,7 @@ def _build_intermediate_mesh_and_spec(

def find_intermediate_sharding(
in_sharding: jax.sharding.Sharding, out_sharding: jax.sharding.Sharding
) -> tuple[jax.sharding.NamedSharding, list[str]]:
) -> tuple[jax.sharding.NamedSharding, Sequence[str]]:
"""Finds an intermediate sharding to reshard to before target sharding.

This function tries to find an intermediate sharding that can be used to
Expand All @@ -343,9 +346,9 @@ def find_intermediate_sharding(
out_sharding: The target sharding.

Returns:
A tuple containing:
- An intermediate sharding.
- A list of axis names that are replicated in the intermediate sharding.
A tuple (intermediate_sharding, replicated_axes), where
replicated_axes is a sequence of axis names that are replicated in the
intermediate sharding.

Raises:
NoIntermediateShardingError: If no intermediate sharding is found.
Expand Down
8 changes: 4 additions & 4 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Helper functions for persistence."""

import base64
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
import concurrent.futures
import datetime
import json
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_hlo_sharding_string(
def get_shape_info(
dtype: np.dtype,
dimensions: Sequence[int],
) -> dict[str, Sequence[int] | str]:
) -> Mapping[str, Sequence[int] | str]:
"""Returns shape info in the format expected by read requests."""
return {
"xla_primitive_type_str": dtype_to_xla_primitive_type_str(dtype),
Expand All @@ -108,7 +108,7 @@ def get_write_request(
jax_array: jax.Array,
timeout: datetime.timedelta,
return_dict: bool = False,
) -> str | dict[str, Any]:
) -> str | Mapping[str, Any]:
"""Returns a string representation of the plugin program which writes the given jax_array to the given location."""
sharding = jax_array.sharding
assert isinstance(sharding, jax.sharding.Sharding), sharding
Expand Down Expand Up @@ -172,7 +172,7 @@ def get_read_request(
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
return_dict: bool = False,
) -> str | dict[str, Any]:
) -> str | Mapping[str, Any]:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
Expand Down
4 changes: 2 additions & 2 deletions pathwaysutils/persistence/orbax_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def serialize(
values: Sequence[jax.Array],
infos: Sequence[ParamInfo],
args: Sequence[SaveArgs] | None = None,
) -> list[future.Future]:
) -> Sequence[future.Future]:
"""Uses Pathways Persistence API to serialize a jax array."""
type_handlers.check_input_arguments(values, infos, args)

Expand Down Expand Up @@ -158,7 +158,7 @@ async def deserialize(
self,
infos: Sequence[ParamInfo],
args: Sequence[RestoreArgs] | None = None,
) -> list[jax.Array]:
) -> Sequence[jax.Array]:
"""Uses Pathways Persistence API to deserialize a jax array."""
if args is None:
raise ValueError("Must provide ArrayRestoreArgs to restore as jax.Array.")
Expand Down
12 changes: 6 additions & 6 deletions pathwaysutils/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import os
import threading
from typing import Any
from typing import Any, Mapping
import urllib.parse

import fastapi
Expand Down Expand Up @@ -59,23 +59,23 @@ def toy_computation() -> None:

def _create_profile_request(
log_dir: os.PathLike[str] | str,
) -> dict[str, Any]:
"""Creates a profile request dictionary from the given options."""
) -> Mapping[str, Any]:
"""Creates a profile request mapping from the given options."""
profile_request = {}
profile_request["traceLocation"] = str(log_dir)

return profile_request


def _start_pathways_trace_from_profile_request(
profile_request: dict[str, Any],
profile_request: Mapping[str, Any],
) -> None:
"""Starts a profiler trace on Pathways components from a profile request.

This will only profile the Pathways components and not the JAX client code.

Args:
profile_request: A dictionary containing the profile request options.
profile_request: A mapping containing the profile request options.
"""
with _profile_state.lock:
global _first_profile_start
Expand Down Expand Up @@ -191,7 +191,7 @@ class ProfilingConfig:
repository_path: str

@app.post("/profiling")
async def profiling(pc: ProfilingConfig) -> dict[str, str]:
async def profiling(pc: ProfilingConfig) -> Mapping[str, str]:
_logger.debug("Capturing profiling data for %s ms", pc.duration_ms)
_logger.debug("Writing profiling data to %s", pc.repository_path)
await asyncio.to_thread(jax.profiler.start_trace, pc.repository_path)
Expand Down
Loading