From 37e62252f0194c569b2d4e93aaf6e4363e6d5b7d Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 17 Mar 2026 11:27:30 -0700 Subject: [PATCH] Refactor elastic retry decorators into a single elastic_retry method. This change unifies the `pause_resume` and `replica_resize` functionalities into a single `elastic_retry` decorator. The new decorator uses a `minimum_slice_count` parameter to control whether to wait for all slices (defaulting to pause/resume behavior) or a smaller subset (enabling replica/resize behavior). The old `pause_resume` and `replica_resize` methods are now deprecated and act as wrappers around `elastic_retry`. PiperOrigin-RevId: 885122724 --- pathwaysutils/elastic/manager.py | 243 +++++++++++++++++-------------- 1 file changed, 131 insertions(+), 112 deletions(-) diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 13e59da..8cde6cd 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -23,6 +23,7 @@ import logging import threading from typing import Any, TypeVar +import warnings import jax from pathwaysutils.elastic import elastic @@ -156,40 +157,136 @@ def _cleanup_on_retry(self): for array in jax.live_arrays(): array.delete() - def _elasticity_retry_decorator( + def _monitor_new_slices( + self, stop_event: threading.Event, poll_interval: float | int + ): + """Monitors for new slices and sets the `new_slice_event` if found.""" + while not stop_event.wait(poll_interval): + try: + if not self.inactive_slice_indices: + _logger.debug("No inactive slices to check.") + continue + + _logger.debug( + "Checking inactive slices: %s", self.inactive_slice_indices + ) + inactive_slice_to_devices = { + i: self.slice_to_devices[i] for i in self.inactive_slice_indices + } + newly_active_indices = elastic.get_active_slice_indices( + inactive_slice_to_devices + ) + + if newly_active_indices: + _logger.info( + "New slices found: %s. Setting new slice event.", + newly_active_indices, + ) + self.new_slice_event.set() + return + + _logger.debug("No new slices found.") + except Exception: # pylint: disable=broad-exception-caught + _logger.exception("Error in monitor thread") + + def elastic_retry( self, max_retries: int, + minimum_slice_count: int | None = None, + poll_interval: float | int = 10, + timeout: float | None = None, pre_callback: Callable[..., Any] | None = None, on_elastic_event_callback: Callable[..., Any] | None = None, ) -> Callable[[_F], _F]: """Retries a function with elasticity fault tolerance. + This decorator wraps a function to automatically retry execution in case of + `jax.errors.JaxRuntimeError` caused by slice down events. It waits for + active slices before each attempt and cleans up JAX caches on failure. + + If `minimum_slice_count` is not met, the function will wait until at least + `minimum_slice_count` slices are active before execution. If + `minimum_slice_count` is None, it defaults to the total number of slices + (i.e., it waits for all slices to be active). + + When `minimum_slice_count` is less than the total number of slices, a + background thread will monitor for new slices becoming available and trigger + a retry if they do. + + Often, the function will dispatch JAX operations and wait for them to + complete while creating a log message. If using Python logging, it is + recommended to set `logging.raiseExceptions=True` to ensure that the + `jax.errors.JaxRuntimeError` is not silently ignored within the logging + call. + Args: max_retries: The maximum number of times to retry the function. - pre_callback: A callback to call before each attempt of the wrapped - function. + minimum_slice_count: The minimum number of slices required to run the + function. If None, defaults to the total number of slices. + poll_interval: The number of seconds to wait between activity checks. + Defaults to 10 seconds. + timeout: The maximum number of seconds to wait for slices to become + active before each retry attempt. If None, there is no timeout. + pre_callback: A callback to call before the function is attempted. on_elastic_event_callback: A callback to call after an elastic failure occurs. Returns: - A function decorator. + A decorator that retries the wrapped function. + + Raises: + ElasticRuntimeError: If all retry attempts fail. + Exception: Any other exception raised by the wrapped function that is not + due to a slice down event. """ + if minimum_slice_count is None: + target_slice_count = self.total_slice_count + else: + target_slice_count = minimum_slice_count if max_retries <= 0: raise ValueError("max_retries must be positive.") + def decorator(func: _F) -> _F: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: - for retry_index in range(max_retries): - try: - _logger.info( - "Elastic attempt %d out of %d", retry_index + 1, max_retries - ) - if pre_callback is not None: - pre_callback() + def attempt_execution(retry_index: int): + _logger.info( + "Elastic attempt %d out of %d", retry_index + 1, max_retries + ) + self.active_slice_indices = elastic.wait_for_slices( + slice_count=target_slice_count, + slice_to_devices=self.slice_to_devices, + poll_interval=poll_interval, + timeout=timeout, + ) + if pre_callback is not None: + pre_callback() + + with jax.default_device(self.default_device): + self.new_slice_event.clear() + stop_event = threading.Event() + + if target_slice_count < self.total_slice_count: + monitor_thread = threading.Thread( + target=self._monitor_new_slices, + args=(stop_event, poll_interval), + daemon=True, + ) + monitor_thread.start() + else: + monitor_thread = None - with jax.default_device(self.default_device): + try: return func(*args, **kwargs) + finally: + stop_event.set() + if monitor_thread is not None: + monitor_thread.join() + + for retry_index in range(max_retries): + try: + return attempt_execution(retry_index) except ScaleUpSignalError: _logger.info("Scale up requested. Retrying.") _elastic_event_cleanup() @@ -230,17 +327,7 @@ def pause_resume( ) -> Callable[[_F], _F]: """Retries a function with pause/resume fault tolerance. - This decorator wraps a function to automatically retry execution in case of - `jax.errors.JaxRuntimeError` caused by slice down events. It waits for - active slices before each attempt and cleans up JAX caches on failure. - The function will not be attempted (or reattempted) until all of the slices - are active. - - Often, the function will dispatch JAX operations and wait for them to - complete while creating a log message. If using Python logging, it is - recommended to set `logging.raiseExceptions=True` to ensure that the - `jax.errors.JaxRuntimeError` is not silently ignored within the logging - call. + DEPRECATED: Use `elastic_retry` instead. Args: max_retries: The maximum number of times to retry the function. @@ -254,60 +341,21 @@ def pause_resume( Returns: A decorator that retries the wrapped function. - - Raises: - ElasticRuntimeError: If all retry attempts fail. - Exception: Any other exception raised by the wrapped function that is not - due to a slice down event. """ - def internal_pre_callback(): - self.active_slice_indices = elastic.wait_for_slices( - slice_count=self.total_slice_count, - slice_to_devices=self.slice_to_devices, - poll_interval=poll_interval, - timeout=timeout, - ) - if pre_callback is not None: - pre_callback() - - return self._elasticity_retry_decorator( + warnings.warn( + "`pause_resume` is deprecated. Please use `elastic_retry` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.elastic_retry( max_retries=max_retries, - pre_callback=internal_pre_callback, + minimum_slice_count=None, + poll_interval=poll_interval, + timeout=timeout, + pre_callback=pre_callback, on_elastic_event_callback=on_elastic_event_callback, ) - def _monitor_new_slices( - self, stop_event: threading.Event, poll_interval: float | int - ): - """Monitors for new slices and sets the `new_slice_event` if found.""" - while not stop_event.wait(poll_interval): - try: - if not self.inactive_slice_indices: - _logger.debug("No inactive slices to check.") - continue - - _logger.debug( - "Checking inactive slices: %s", self.inactive_slice_indices - ) - inactive_slice_to_devices = { - i: self.slice_to_devices[i] for i in self.inactive_slice_indices - } - newly_active_indices = elastic.get_active_slice_indices( - inactive_slice_to_devices - ) - - if newly_active_indices: - _logger.info( - "New slices found: %s. Setting new slice event.", - newly_active_indices, - ) - self.new_slice_event.set() - return - - _logger.debug("No new slices found.") - except Exception: # pylint: disable=broad-exception-caught - _logger.exception("Error in monitor thread") - def replica_resize( self, max_resizes: int, @@ -317,6 +365,8 @@ def replica_resize( ) -> Callable[[_F], _F]: """Retries a function with replica/resize fault tolerance. + DEPRECATED: Use `elastic_retry` instead. + Args: max_resizes: The maximum number of times to retry the function after resizing the replica count. @@ -328,47 +378,16 @@ def replica_resize( Returns: A decorator that retries the wrapped function. - - Raises: - ElasticRuntimeError: If all retry attempts fail. - Exception: Any other exception raised by the wrapped function that is not - due to a slice down event. """ - - def internal_pre_callback(): - self.active_slice_indices = elastic.wait_for_slices( - slice_count=1, - slice_to_devices=self.slice_to_devices, - poll_interval=poll_interval, - ) - - if pre_callback is not None: - pre_callback() - - retry_decorator = self._elasticity_retry_decorator( + warnings.warn( + "`replica_resize` is deprecated. Please use `elastic_retry` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.elastic_retry( max_retries=max_resizes, - pre_callback=internal_pre_callback, + minimum_slice_count=1, + poll_interval=poll_interval, + pre_callback=pre_callback, on_elastic_event_callback=on_elastic_event_callback, ) - - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - self.new_slice_event.clear() - stop_event = threading.Event() - - monitor_thread = threading.Thread( - target=self._monitor_new_slices, - args=(stop_event, poll_interval), - daemon=True, - ) - monitor_thread.start() - try: - return func(*args, **kwargs) - finally: - stop_event.set() - monitor_thread.join() - - return retry_decorator(wrapper) - - return decorator