Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 131 additions & 112 deletions pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import threading
from typing import Any, TypeVar
import warnings

import jax
from pathwaysutils.elastic import elastic
Expand Down Expand Up @@ -156,40 +157,136 @@ def _cleanup_on_retry(self):
for array in jax.live_arrays():
array.delete()

def _elasticity_retry_decorator(
def _monitor_new_slices(
self, stop_event: threading.Event, poll_interval: float | int
):
"""Monitors for new slices and sets the `new_slice_event` if found."""
while not stop_event.wait(poll_interval):
try:
if not self.inactive_slice_indices:
_logger.debug("No inactive slices to check.")
continue

_logger.debug(
"Checking inactive slices: %s", self.inactive_slice_indices
)
inactive_slice_to_devices = {
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
}
newly_active_indices = elastic.get_active_slice_indices(
inactive_slice_to_devices
)

if newly_active_indices:
_logger.info(
"New slices found: %s. Setting new slice event.",
newly_active_indices,
)
self.new_slice_event.set()
return

_logger.debug("No new slices found.")
except Exception: # pylint: disable=broad-exception-caught
_logger.exception("Error in monitor thread")

def elastic_retry(
self,
max_retries: int,
minimum_slice_count: int | None = None,
poll_interval: float | int = 10,
timeout: float | None = None,
pre_callback: Callable[..., Any] | None = None,
on_elastic_event_callback: Callable[..., Any] | None = None,
) -> Callable[[_F], _F]:
"""Retries a function with elasticity fault tolerance.

This decorator wraps a function to automatically retry execution in case of
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
active slices before each attempt and cleans up JAX caches on failure.

If `minimum_slice_count` is not met, the function will wait until at least
`minimum_slice_count` slices are active before execution. If
`minimum_slice_count` is None, it defaults to the total number of slices
(i.e., it waits for all slices to be active).

When `minimum_slice_count` is less than the total number of slices, a
background thread will monitor for new slices becoming available and trigger
a retry if they do.

Often, the function will dispatch JAX operations and wait for them to
complete while creating a log message. If using Python logging, it is
recommended to set `logging.raiseExceptions=True` to ensure that the
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
call.

Args:
max_retries: The maximum number of times to retry the function.
pre_callback: A callback to call before each attempt of the wrapped
function.
minimum_slice_count: The minimum number of slices required to run the
function. If None, defaults to the total number of slices.
poll_interval: The number of seconds to wait between activity checks.
Defaults to 10 seconds.
timeout: The maximum number of seconds to wait for slices to become
active before each retry attempt. If None, there is no timeout.
pre_callback: A callback to call before the function is attempted.
on_elastic_event_callback: A callback to call after an elastic failure
occurs.

Returns:
A function decorator.
A decorator that retries the wrapped function.

Raises:
ElasticRuntimeError: If all retry attempts fail.
Exception: Any other exception raised by the wrapped function that is not
due to a slice down event.
"""
if minimum_slice_count is None:
target_slice_count = self.total_slice_count
else:
target_slice_count = minimum_slice_count

if max_retries <= 0:
raise ValueError("max_retries must be positive.")

def decorator(func: _F) -> _F:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
for retry_index in range(max_retries):
try:
_logger.info(
"Elastic attempt %d out of %d", retry_index + 1, max_retries
)
if pre_callback is not None:
pre_callback()
def attempt_execution(retry_index: int):
_logger.info(
"Elastic attempt %d out of %d", retry_index + 1, max_retries
)
self.active_slice_indices = elastic.wait_for_slices(
slice_count=target_slice_count,
slice_to_devices=self.slice_to_devices,
poll_interval=poll_interval,
timeout=timeout,
)
if pre_callback is not None:
pre_callback()

with jax.default_device(self.default_device):
self.new_slice_event.clear()
stop_event = threading.Event()

if target_slice_count < self.total_slice_count:
monitor_thread = threading.Thread(
target=self._monitor_new_slices,
args=(stop_event, poll_interval),
daemon=True,
)
monitor_thread.start()
else:
monitor_thread = None

with jax.default_device(self.default_device):
try:
return func(*args, **kwargs)
finally:
stop_event.set()
if monitor_thread is not None:
monitor_thread.join()

for retry_index in range(max_retries):
try:
return attempt_execution(retry_index)
except ScaleUpSignalError:
_logger.info("Scale up requested. Retrying.")
_elastic_event_cleanup()
Expand Down Expand Up @@ -230,17 +327,7 @@ def pause_resume(
) -> Callable[[_F], _F]:
"""Retries a function with pause/resume fault tolerance.

This decorator wraps a function to automatically retry execution in case of
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
active slices before each attempt and cleans up JAX caches on failure.
The function will not be attempted (or reattempted) until all of the slices
are active.

Often, the function will dispatch JAX operations and wait for them to
complete while creating a log message. If using Python logging, it is
recommended to set `logging.raiseExceptions=True` to ensure that the
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
call.
DEPRECATED: Use `elastic_retry` instead.

Args:
max_retries: The maximum number of times to retry the function.
Expand All @@ -254,60 +341,21 @@ def pause_resume(

Returns:
A decorator that retries the wrapped function.

Raises:
ElasticRuntimeError: If all retry attempts fail.
Exception: Any other exception raised by the wrapped function that is not
due to a slice down event.
"""
def internal_pre_callback():
self.active_slice_indices = elastic.wait_for_slices(
slice_count=self.total_slice_count,
slice_to_devices=self.slice_to_devices,
poll_interval=poll_interval,
timeout=timeout,
)
if pre_callback is not None:
pre_callback()

return self._elasticity_retry_decorator(
warnings.warn(
"`pause_resume` is deprecated. Please use `elastic_retry` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.elastic_retry(
max_retries=max_retries,
pre_callback=internal_pre_callback,
minimum_slice_count=None,
poll_interval=poll_interval,
timeout=timeout,
pre_callback=pre_callback,
on_elastic_event_callback=on_elastic_event_callback,
)

def _monitor_new_slices(
self, stop_event: threading.Event, poll_interval: float | int
):
"""Monitors for new slices and sets the `new_slice_event` if found."""
while not stop_event.wait(poll_interval):
try:
if not self.inactive_slice_indices:
_logger.debug("No inactive slices to check.")
continue

_logger.debug(
"Checking inactive slices: %s", self.inactive_slice_indices
)
inactive_slice_to_devices = {
i: self.slice_to_devices[i] for i in self.inactive_slice_indices
}
newly_active_indices = elastic.get_active_slice_indices(
inactive_slice_to_devices
)

if newly_active_indices:
_logger.info(
"New slices found: %s. Setting new slice event.",
newly_active_indices,
)
self.new_slice_event.set()
return

_logger.debug("No new slices found.")
except Exception: # pylint: disable=broad-exception-caught
_logger.exception("Error in monitor thread")

def replica_resize(
self,
max_resizes: int,
Expand All @@ -317,6 +365,8 @@ def replica_resize(
) -> Callable[[_F], _F]:
"""Retries a function with replica/resize fault tolerance.

DEPRECATED: Use `elastic_retry` instead.

Args:
max_resizes: The maximum number of times to retry the function after
resizing the replica count.
Expand All @@ -328,47 +378,16 @@ def replica_resize(

Returns:
A decorator that retries the wrapped function.

Raises:
ElasticRuntimeError: If all retry attempts fail.
Exception: Any other exception raised by the wrapped function that is not
due to a slice down event.
"""

def internal_pre_callback():
self.active_slice_indices = elastic.wait_for_slices(
slice_count=1,
slice_to_devices=self.slice_to_devices,
poll_interval=poll_interval,
)

if pre_callback is not None:
pre_callback()

retry_decorator = self._elasticity_retry_decorator(
warnings.warn(
"`replica_resize` is deprecated. Please use `elastic_retry` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.elastic_retry(
max_retries=max_resizes,
pre_callback=internal_pre_callback,
minimum_slice_count=1,
poll_interval=poll_interval,
pre_callback=pre_callback,
on_elastic_event_callback=on_elastic_event_callback,
)

def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
self.new_slice_event.clear()
stop_event = threading.Event()

monitor_thread = threading.Thread(
target=self._monitor_new_slices,
args=(stop_event, poll_interval),
daemon=True,
)
monitor_thread.start()
try:
return func(*args, **kwargs)
finally:
stop_event.set()
monitor_thread.join()

return retry_decorator(wrapper)

return decorator
Loading