From de67fbc6e77f424225a8e842dcb6025726d7d673 Mon Sep 17 00:00:00 2001 From: Luke Baumann Date: Tue, 6 Jan 2026 14:04:21 -0800 Subject: [PATCH] Require JAX>=0.8.0 ---- Directly use jax.extend.ifrt_proxy. This change updates pathwaysutils to import and use `jax.extend.ifrt_proxy.ifrt_proxy` directly. The re-export of this function from `pathwaysutils.jax` is removed, along with version-specific compatibility code for older JAX versions. PiperOrigin-RevId: 852927594 --- pathwaysutils/jax/__init__.py | 18 +----------------- pathwaysutils/proxy_backend.py | 6 +++--- pathwaysutils/test/proxy_backend_test.py | 4 ++-- 3 files changed, 6 insertions(+), 22 deletions(-) diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index c74d9a8..e5bc106 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -17,9 +17,8 @@ `pathwaysutils`'s compatibility window. """ -import functools -import jax +import functools class _FakeJaxFunction: @@ -46,20 +45,6 @@ def __call__(self, *args, **kwargs): raise ImportError(self.error_message) -try: - # jax>=0.7.1 - from jax.extend import backend # pylint: disable=g-import-not-at-top - - ifrt_proxy = backend.ifrt_proxy - del backend -except AttributeError: - # jax<0.7.1 - from jax.lib import xla_extension # pylint: disable=g-import-not-at-top - - ifrt_proxy = xla_extension.ifrt_proxy - del xla_extension - - try: # jax>=0.8.0 from jaxlib import _pathways # pylint: disable=g-import-not-at-top @@ -112,6 +97,5 @@ def ifrt_reshard_available() -> bool: del jax -del jax del _FakeJaxFunction del functools diff --git a/pathwaysutils/proxy_backend.py b/pathwaysutils/proxy_backend.py index d599f21..cf1f806 100644 --- a/pathwaysutils/proxy_backend.py +++ b/pathwaysutils/proxy_backend.py @@ -15,15 +15,15 @@ import jax from jax.extend import backend -from pathwaysutils import jax as pw_jax +from jax.extend.backend import ifrt_proxy def register_backend_factory() -> None: backend.register_backend_factory( "proxy", - lambda: pw_jax.ifrt_proxy.get_client( + lambda: ifrt_proxy.get_client( jax.config.read("jax_backend_target"), - pw_jax.ifrt_proxy.ClientConnectionOptions(), + ifrt_proxy.ClientConnectionOptions(), ), priority=-1, ) diff --git a/pathwaysutils/test/proxy_backend_test.py b/pathwaysutils/test/proxy_backend_test.py index 2d6b613..fb8ad8c 100644 --- a/pathwaysutils/test/proxy_backend_test.py +++ b/pathwaysutils/test/proxy_backend_test.py @@ -18,7 +18,7 @@ from absl.testing import absltest import jax from jax.extend import backend -from pathwaysutils import jax as pw_jax +from jax.extend.backend import ifrt_proxy from pathwaysutils import proxy_backend @@ -46,7 +46,7 @@ def test_no_proxy_backend_registration_raises_error(self): def test_proxy_backend_registration(self): self.enter_context( mock.patch.object( - pw_jax.ifrt_proxy, + ifrt_proxy, "get_client", return_value=mock.MagicMock(), )