From 278cd86134f5ddff05d1a9c02660351ccd861994 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:39:22 -0700 Subject: [PATCH 01/12] Initial commit for adding NERSC IRI-API support alongside SFAPI for job submission --- orchestration/flows/bl832/nersc.py | 130 +++++++++++++++- orchestration/globus/token.py | 235 +++++++++++++++++++++++++++++ 2 files changed, 363 insertions(+), 2 deletions(-) create mode 100644 orchestration/globus/token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 727cbbaf..bfad0efd 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,5 +1,6 @@ import datetime from dotenv import load_dotenv +import enum import json import logging import os @@ -15,18 +16,51 @@ from typing import Optional from orchestration.flows.bl832.config import Config832 + from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController -from orchestration.transfer_controller import get_transfer_controller, CopyMethod from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) +from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE from orchestration.prefect import schedule_prefect_flow +from orchestration.transfer_controller import get_transfer_controller, CopyMethod logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() +class NERSCLoginMethod(enum.Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + +# Applies only to NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" +_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" +_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +}) + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + + class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -43,7 +77,99 @@ def __init__( self.client = client @staticmethod - def create_sfapi_client() -> Client: + def create_nersc_client( + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + ) -> Client: + """Create and return a NERSC client for the requested login method. + + Two fundamentally different auth strategies are supported: + + - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2 + client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID`` + and ``PATH_NERSC_PRI_KEY`` to the paths of those files. + + - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written + by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token + file path, or rely on the default (``~/.globus/auth_tokens.json``). + + Args: + login_method: Which NERSC API to authenticate against. + Defaults to :attr:`NERSCLoginMethod.SFAPI`. + + Returns: + An authenticated :class:`sfapi_client.Client` instance. + + Raises: + ValueError: If SFAPI credential environment variables are unset. + FileNotFoundError: If credential or token files are absent. + RuntimeError: If the Globus token is expired. + Exception: If the underlying client construction fails. + """ + logger.info(f"Creating NERSC client using login method: {login_method.value}") + api_url = _API_BASE_URLS[login_method] + logger.info(f"Targeting API base URL: {api_url}") + + if login_method is NERSCLoginMethod.SFAPI: + client = NERSCTomographyHPCController._create_sfapi_client() + + elif login_method is NERSCLoginMethod.IRIAPI: + client = NERSCTomographyHPCController._create_iriapi_client() + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}") + + logger.info( + f"NERSC client created successfully " + f"(method={login_method.value}, api_url={api_url})." + ) + return client + + @staticmethod + def _create_iriapi_client() -> Client: + """Create a NERSC client for the IRI API using a Globus bearer token. + + Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the + environment. Reuses a cached token if valid; otherwise mints a new one + via the client credentials grant. No browser or user interaction. + + Returns: + An authenticated :class:`sfapi_client.Client` targeting the IRI API. + + Raises: + ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. + RuntimeError: If the acquired token is missing required scopes. + """ + client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) + client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) + + if not client_id: + raise ValueError( + f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." + ) + if not client_secret: + raise ValueError( + f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " + "A Globus Confidential App client is required for automated IRI API auth." + ) + + token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) + token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE + + access_token = get_access_token_confidential( + client_id=client_id, + client_secret=client_secret, + required_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, + resource_server=_IRIAPI_GLOBUS_RESOURCE_SERVER, + token_file=token_file, + ) + + return Client( + token=access_token, + api_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + ) + + @staticmethod + def _create_sfapi_client() -> Client: """Create and return an NERSC client instance""" # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py new file mode 100644 index 00000000..81b5438f --- /dev/null +++ b/orchestration/globus/token.py @@ -0,0 +1,235 @@ +import json +import logging +import os +from pathlib import Path +import stat +import time + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +logger = logging.getLogger(__name__) + +# Default token file location, matching the Globus SDK convention. +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" +GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" + + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client (machine-to-machine). + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + # 1. Do we already have a valid token? + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + # 2. Mint a new token — same call whether first time or expired. + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"New Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def load_token_file(token_file: Path) -> dict | None: + """Load saved Globus token data from disk. + + Args: + token_file: Path to the JSON token file. + + Returns: + Parsed token dict, or None if the file does not exist. + """ + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_token_file(token_file: Path, tokens: dict) -> None: + """Atomically save Globus token data to disk with owner-only permissions. + + Writes to a temporary file then renames to avoid partial writes. + + Args: + token_file: Destination path for the JSON token file. + tokens: Token dict to serialise. + """ + _ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + required_scopes: frozenset[str], + resource_server: str, +) -> dict: + """Run an interactive browser-based Globus login flow. + + Prints an authorization URL, waits for the user to paste an auth code, + and exchanges it for tokens. + + Args: + client: Globus NativeAppAuthClient to drive the flow. + required_scopes: Set of OAuth2 scopes to request. + resource_server: Resource server key to extract from the token response + (e.g. ``"auth.globus.org"``). + + Returns: + Token dict for the given resource server. + """ + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(required_scopes)), + refresh_tokens=True, + ) + logger.info("Open this URL in your browser to authenticate with Globus:") + logger.info(client.oauth2_get_authorize_url()) + code = input("\nEnter authorization code: ").strip() + token_response = client.oauth2_exchange_code_for_tokens(code) + return token_response.by_resource_server[resource_server] + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, + refresh_token: str, + resource_server: str, +) -> dict | None: + """Attempt a silent Globus token refresh. + + Args: + client: Globus NativeAppAuthClient to drive the refresh. + refresh_token: The stored refresh token. + resource_server: Resource server key to extract from the token response. + + Returns: + Fresh token dict for the given resource server, or None if refresh failed. + """ + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.by_resource_server[resource_server] + except GlobusAPIError as e: + logger.warning( + f"Globus token refresh failed ({e.http_status}); " + "falling back to interactive login." + ) + return None + + +def get_access_token( + client_id: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, + force_login: bool = False, +) -> str: + """Get a valid Globus access token, refreshing or logging in as needed. + + Attempts a silent refresh from the saved token file first. Falls back to + interactive browser login if no saved tokens exist, the refresh token is + absent, or the refresh fails. Saves the resulting tokens back to disk. + + Args: + client_id: Globus NativeApp client ID. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token file. Defaults to + ``~/.globus/auth_tokens.json``. + force_login: If True, skip refresh and force interactive login. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + globus_client = globus_sdk.NativeAppAuthClient(client_id) + + auth_data: dict | None = None + + if not force_login: + stored = load_token_file(resolved_token_file) + if stored and stored.get("refresh_token"): + auth_data = refresh_tokens( + globus_client, stored["refresh_token"], resource_server + ) + + if auth_data is None: + logger.info("Initiating interactive Globus login.") + auth_data = interactive_login(globus_client, required_scopes, resource_server) + + granted = set(auth_data.get("scope", "").split()) + missing = required_scopes - granted + if missing: + raise RuntimeError( + f"Globus token is missing required scopes: {sorted(missing)}" + ) + + save_token_file(resolved_token_file, auth_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return auth_data["access_token"] + + +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) From d439ce02a4ffc4d68828b2e1c6c5f41686052130 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 16 Mar 2026 15:55:45 -0700 Subject: [PATCH 02/12] Adding an abstraction for _submit_job() and _wait_for_job() that use the correct mechanism based on IRI/SF-API --- orchestration/flows/bl832/nersc.py | 68 +++++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index bfad0efd..17055843 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -71,10 +71,12 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, client: Client, - config: Config832 + config: Config832, + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) self.client = client + self.login_method = login_method @staticmethod def create_nersc_client( @@ -200,6 +202,70 @@ def _create_sfapi_client() -> Client: logger.error(f"Failed to create NERSC client: {e}") raise e + def _submit_job(self, job_script: str) -> str: + """Submit a Slurm job script and return the job ID. + + Dispatches to the appropriate submission mechanism based on + ``self.login_method``. + + Args: + job_script: The full Slurm batch script to submit. + + Returns: + The submitted job ID as a string. + + Raises: + RuntimeError: If job submission fails. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + return str(job.jobid) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + response = self.client.post( + "/api/v1/compute/job/perlmutter", + json={"script": job_script}, + ) + response.raise_for_status() + return str(response.json()["job_id"]) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _wait_for_job(self, job_id: str) -> bool: + """Block until a submitted job completes. + + Dispatches to the appropriate polling mechanism based on + ``self.login_method``. + + Args: + job_id: The job ID returned by :meth:`_submit_job`. + + Returns: + True if the job completed successfully, False otherwise. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=job_id) + job.complete() + return True + + elif self.login_method is NERSCLoginMethod.IRIAPI: + while True: + response = self.client.get( + f"/api/v1/compute/status/perlmutter/{job_id}" + ) + response.raise_for_status() + state = response.json().get("state") + logger.info(f"Job {job_id} state: {state}") + if state in ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT"): + return state == "COMPLETED" + time.sleep(60) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", From 98f9a4478b248f57507bf9e8e8fcd48b8d97fb4c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:22:49 -0700 Subject: [PATCH 03/12] moving NERSCLoginMethod(Enum) to the job_controller.py module --- orchestration/flows/bl832/job_controller.py | 26 +++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index b2ff064b..1a23d02a 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -10,6 +10,19 @@ load_dotenv() +class NERSCLoginMethod(Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + class TomographyHPCController(ABC): """ Abstract class for tomography HPC controllers. @@ -65,7 +78,8 @@ class HPC(Enum): def get_controller( hpc_type: HPC, - config: Config832 + config: Config832, + login_method: "NERSCLoginMethod | None" = None, ) -> TomographyHPCController: """ Factory function that returns an HPC controller instance for the given HPC environment. @@ -86,10 +100,14 @@ def get_controller( config=config ) elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_sfapi_client(), - config=config + client=NERSCTomographyHPCController.create_nersc_client( + login_method=resolved_login_method + ), + config=config, + login_method=resolved_login_method, ) elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller From 8bcd92994332e076c292482b5e887cdff5f5d550 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:25:00 -0700 Subject: [PATCH 04/12] Removed NERSCLoginMethod(Enum) from nersc.py. Created a temporary test flow for reconstruction to test job submission. In reconstruct(), replaced the SFAPI-specific job submission/polling code with the general _submit_job() and _wait_for_job() methods. --- orchestration/flows/bl832/nersc.py | 229 +++++++++++++++++++++-------- 1 file changed, 171 insertions(+), 58 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 17055843..1ddb3a50 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,6 +1,6 @@ import datetime from dotenv import load_dotenv -import enum +import httpx import json import logging import os @@ -17,7 +17,7 @@ from orchestration.flows.bl832.config import Config832 -from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController +from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) @@ -30,19 +30,6 @@ load_dotenv() -class NERSCLoginMethod(enum.Enum): - """Selects which NERSC API login method to use when creating a NERSC client. - - Each method corresponds to a different set of credentials and API base URL. - """ - - SFAPI = "sfapi" - """Standard Superfacility API via Iris-registered OAuth2 credentials.""" - - IRIAPI = "iriapi" - """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" - - # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client @@ -70,8 +57,8 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, - client: Client, config: Config832, + client: Client | httpx.Client | None = None, login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) @@ -165,9 +152,9 @@ def _create_iriapi_client() -> Client: token_file=token_file, ) - return Client( - token=access_token, - api_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + return httpx.Client( + base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + headers={"Authorization": f"Bearer {access_token}"}, ) @staticmethod @@ -202,6 +189,28 @@ def _create_sfapi_client() -> Client: logger.error(f"Failed to create NERSC client: {e}") raise e + def _get_nersc_username(self) -> str: + """Get the NERSC username for constructing pscratch paths. + + Uses the sfapi_client user endpoint for SFAPI, or reads + ``NERSC_USERNAME`` from the environment for IRIAPI. + + Returns: + NERSC username string. + + Raises: + ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + return self.client.user().name + else: + username = os.getenv("NERSC_USERNAME") + if not username: + raise ValueError( + "NERSC_USERNAME must be set in the environment when using IRIAPI." + ) + return username + def _submit_job(self, job_script: str) -> str: """Submit a Slurm job script and return the job ID. @@ -240,7 +249,7 @@ def _wait_for_job(self, job_id: str) -> bool: ``self.login_method``. Args: - job_id: The job ID returned by :meth:`_submit_job`. + job_id: The job ID returned by `_submit_job`. Returns: True if the job completed successfully, False otherwise. @@ -275,7 +284,8 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - user = self.client.user() + # user = self.client.user() + username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -289,7 +299,7 @@ def reconstruct( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -351,42 +361,54 @@ def reconstruct( """ try: - logger.info("Submitting reconstruction job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting reconstruction job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"Reconstruction job {'completed' if success else 'failed'}.") + return success except Exception as e: - logger.info(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - # Unknown error: cannot recover - return False + logger.error(f"Error during reconstruction job submission or completion: {e}") + return False + + # try: + # logger.info("Submitting reconstruction job script to Perlmutter.") + # perlmutter = self.client.compute(Machine.perlmutter) + # job = perlmutter.submit_job(job_script) + # logger.info(f"Submitted job ID: {job.jobid}") + + # try: + # job.update() + # except Exception as update_err: + # logger.warning(f"Initial job update failed, continuing: {update_err}") + + # time.sleep(60) + # logger.info(f"Job {job.jobid} current state: {job.state}") + + # job.complete() # Wait until the job completes + # logger.info("Reconstruction job completed successfully.") + # return True + + # except Exception as e: + # logger.info(f"Error during job submission or completion: {e}") + # match = re.search(r"Job not found:\s*(\d+)", str(e)) + + # if match: + # jobid = match.group(1) + # logger.info(f"Attempting to recover job {jobid}.") + # try: + # job = self.client.perlmutter.job(jobid=jobid) + # time.sleep(30) + # job.complete() + # logger.info("Reconstruction job completed successfully after recovery.") + # return True + # except Exception as recovery_err: + # logger.error(f"Failed to recover job {jobid}: {recovery_err}") + # return False + # else: + # # Unknown error: cannot recover + # return False def build_multi_resolution( self, @@ -712,6 +734,97 @@ def nersc_recon_flow( return False +@flow(name="nersc_recon_test_flow", flow_run_name="nersc_recon-{file_path}") +def nersc_recon_test_flow( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Perform tomography reconstruction on NERSC. + + :param file_path: Path to the file to reconstruct. + """ + logger = get_run_logger() + + if config is None: + logger.info("Initializing Config") + config = Config832() + + logger.info(f"Starting NERSC reconstruction flow for {file_path=}") + controller = get_controller( + hpc_type=HPC.NERSC, + config=config, + login_method=NERSCLoginMethod.IRIAPI + ) + logger.info("NERSC reconstruction controller initialized") + + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + # logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + # nersc_multi_res_success = controller.build_multi_resolution( + # file_path=file_path, + # ) + # logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + + # path = Path(file_path) + # folder_name = path.parent.name + # file_name = path.stem + + # tiff_file_path = f"{folder_name}/rec{file_name}" + # zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + # logger.info(f"{tiff_file_path=}") + # logger.info(f"{zarr_file_path=}") + + # # Transfer reconstructed data + # logger.info("Preparing transfer.") + # transfer_controller = get_transfer_controller( + # transfer_type=CopyMethod.GLOBUS, + # config=config + # ) + + # logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") + # transfer_controller.copy( + # file_path=tiff_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.nersc832_alsdev_scratch + # ) + + # transfer_controller.copy( + # file_path=zarr_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.nersc832_alsdev_scratch + # ) + + # logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") + # transfer_controller.copy( + # file_path=tiff_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + + # transfer_controller.copy( + # file_path=zarr_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + + # logger.info("Scheduling pruning tasks.") + # schedule_pruning( + # config=config, + # raw_file_path=file_path, + # tiff_file_path=tiff_file_path, + # zarr_file_path=zarr_file_path + # ) + + # TODO: Ingest into SciCat + if nersc_reconstruction_success: + return True + else: + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), @@ -741,8 +854,8 @@ def nersc_streaming_flow( if __name__ == "__main__": config = Config832() - nersc_recon_flow( - file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + nersc_recon_test_flow( + file_path="dabramov/20241216_153047_ddd.h5.h5", config=config ) # nersc_streaming_flow( From e2fe24bb736674ab848d7550ccc9952dd68f450f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:50:35 -0700 Subject: [PATCH 05/12] Updating pytests --- orchestration/_tests/test_bl832/__init__.py | 0 orchestration/_tests/test_bl832/test_nersc.py | 310 ++++++++++++++++++ orchestration/_tests/test_sfapi_flow.py | 301 +++-------------- 3 files changed, 349 insertions(+), 262 deletions(-) create mode 100644 orchestration/_tests/test_bl832/__init__.py create mode 100644 orchestration/_tests/test_bl832/test_nersc.py diff --git a/orchestration/_tests/test_bl832/__init__.py b/orchestration/_tests/test_bl832/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py new file mode 100644 index 00000000..7994a3ac --- /dev/null +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -0,0 +1,310 @@ +# orchestration/_tests/test_bl832/test_nersc.py +import pytest +from uuid import uuid4 + +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + with prefect_test_harness(): + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) + yield + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_config(mocker): + config = mocker.MagicMock() + config.ghcr_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + return config + + +@pytest.fixture +def mock_sfapi_client(mocker): + """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" + client = mocker.MagicMock() + + mock_user = mocker.MagicMock() + mock_user.name = "testuser" + client.user.return_value = mock_user + + mock_job = mocker.MagicMock() + mock_job.jobid = "12345" + mock_job.state = "COMPLETED" + client.compute.return_value.submit_job.return_value = mock_job + client.compute.return_value.job.return_value = mock_job + + return client + + +@pytest.fixture +def mock_iriapi_client(mocker): + """httpx.Client mock for IRI API responses.""" + client = mocker.MagicMock() + + submit_response = mocker.MagicMock() + submit_response.json.return_value = {"job_id": "99999"} + client.post.return_value = submit_response + + status_response = mocker.MagicMock() + status_response.json.return_value = {"state": "COMPLETED"} + client.get.return_value = status_response + + return client + + +# --------------------------------------------------------------------------- +# _create_sfapi_client +# --------------------------------------------------------------------------- + +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) + + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() + + +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) + + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() + + +# --------------------------------------------------------------------------- +# reconstruct — SFAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_sfapi_success(mocker, mock_sfapi_client, mock_config): + """SFAPI reconstruct submits a job and waits for completion.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.job.assert_called_once_with(jobid="12345") + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() + + +def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_config): + """SFAPI reconstruct returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("SFAPI error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is False + + +# --------------------------------------------------------------------------- +# reconstruct — IRIAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/perlmutter" + assert "script" in mock_iriapi_client.post.call_args.kwargs["json"] + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/perlmutter/99999" + ) + + +def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI reconstruct returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is False + + +def test_reconstruct_iriapi_missing_username(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI reconstruct raises ValueError when NERSC_USERNAME is unset.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.delenv("NERSC_USERNAME", raising=False) + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + with pytest.raises(ValueError, match="NERSC_USERNAME"): + controller.reconstruct(file_path="folder/scan.h5") + + +# --------------------------------------------------------------------------- +# build_multi_resolution — SFAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_sfapi_success(mocker, mock_sfapi_client, mock_config): + """SFAPI build_multi_resolution submits and waits successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + + +def test_build_multi_resolution_sfapi_failure(mocker, mock_sfapi_client, mock_config): + """SFAPI build_multi_resolution returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False + + +# --------------------------------------------------------------------------- +# build_multi_resolution — IRIAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution POSTs and polls successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/perlmutter/99999" + ) + + +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution returns False when job state is FAILED.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index 66203d19..e0d4a854 100644 --- a/orchestration/_tests/test_sfapi_flow.py +++ b/orchestration/_tests/test_sfapi_flow.py @@ -1,8 +1,5 @@ # orchestration/_tests/test_sfapi_flow.py - -from pathlib import Path import pytest -from unittest.mock import MagicMock, patch, mock_open from uuid import uuid4 from prefect.blocks.system import Secret @@ -11,276 +8,56 @@ @pytest.fixture(autouse=True, scope="session") def prefect_test_fixture(): - """ - A pytest fixture that automatically sets up and tears down the Prefect test harness - for the entire test session. It creates and saves test secrets and configurations - required for Globus integration. - - Yields: - None - """ with prefect_test_harness(): - globus_client_id = Secret(value=str(uuid4())) - globus_client_id.save(name="globus-client-id", overwrite=True) - globus_client_secret = Secret(value=str(uuid4())) - globus_client_secret.save(name="globus-client-secret", overwrite=True) - + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) yield -# ---------------------------- -# Tests for create_sfapi_client -# ---------------------------- - - -def test_create_sfapi_client_success(): - """ - Test successful creation of the SFAPI client. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Mock data for client_id and client_secret files - mock_client_id = 'value' - mock_client_secret = '{"key": "value"}' - - # Create separate mock_open instances for each file - mock_open_client_id = mock_open(read_data=mock_client_id) - mock_open_client_secret = mock_open(read_data=mock_client_secret) - - with patch("orchestration.flows.bl832.nersc.os.getenv") as mock_getenv, \ - patch("orchestration.flows.bl832.nersc.os.path.isfile") as mock_isfile, \ - patch("builtins.open", side_effect=[ - mock_open_client_id.return_value, - mock_open_client_secret.return_value - ]), \ - patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key") as mock_import_key, \ - patch("orchestration.flows.bl832.nersc.Client") as MockClient: - - # Mock environment variables - mock_getenv.side_effect = lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - - # Mock file existence - mock_isfile.return_value = True - - # Mock JsonWebKey.import_key to return a mock secret - mock_import_key.return_value = "mock_secret" - - # Create the client - client = NERSCTomographyHPCController.create_sfapi_client() - - # Assert that Client was instantiated with 'value' and 'mock_secret' - MockClient.assert_called_once_with("value", "mock_secret") - - # Assert that the returned client is the mocked client - assert client == MockClient.return_value, "Client should be the mocked sfapi_client.Client instance" - - -def test_create_sfapi_client_missing_paths(): - """ - Test creation of the SFAPI client with missing credential paths. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - with patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None): - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() - - -def test_create_sfapi_client_missing_files(): - """ - Test creation of the SFAPI client with missing credential files. - """ - with ( - # Mock environment variables - patch( - "orchestration.flows.bl832.nersc.os.getenv", - side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - ), - - # Mock file existence to simulate missing files - patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - ): - # Import the module after applying patches to ensure mocks are in place - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Expect a FileNotFoundError due to missing credential files - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - -# ---------------------------- -# Fixture for Mocking SFAPI Client -# ---------------------------- - - -@pytest.fixture -def mock_sfapi_client(): - """ - Mock the sfapi_client.Client class with necessary methods. - """ - with patch("orchestration.flows.bl832.nersc.Client") as MockClient: - mock_client_instance = MockClient.return_value - - # Mock the user method - mock_user = MagicMock() - mock_user.name = "testuser" - mock_client_instance.user.return_value = mock_user - - # Mock the compute method to return a mocked compute object - mock_compute = MagicMock() - mock_job = MagicMock() - mock_job.jobid = "12345" - mock_job.state = "COMPLETED" - mock_compute.submit_job.return_value = mock_job - mock_client_instance.compute.return_value = mock_compute - - yield mock_client_instance - - -# ---------------------------- -# Fixture for Mocking Config832 -# ---------------------------- - -@pytest.fixture -def mock_config832(): - """ - Mock the Config832 class to provide necessary configurations. - """ - with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: - mock_config = MockConfig.return_value - mock_config.harbor_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - mock_config.apps = {"als_transfer": "some_config"} - yield mock_config - - -# ---------------------------- -# Tests for NERSCTomographyHPCController -# ---------------------------- - -def test_reconstruct_success(mock_sfapi_client, mock_config832): - """ - Test successful reconstruction job submission. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert result is True, "reconstruct should return True on successful job completion." - - -def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): - """ - Test reconstruction job submission failure. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Assert that the method returns False - assert result is False, "reconstruct should return False on submission failure." - - -def test_build_multi_resolution_success(mock_sfapi_client, mock_config832): - """ - Test successful multi-resolution job submission. - """ +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - # Assert that the method returns True - assert result is True, "build_multi_resolution should return True on successful job completion." - - -def test_build_multi_resolution_submission_failure(mock_sfapi_client, mock_config832): - """ - Test multi-resolution job submission failure. - """ + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - # Assert that the method returns False - assert result is False, "build_multi_resolution should return False on submission failure." + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() -def test_job_submission(mock_sfapi_client): - """ - Test job submission and status updates. - """ +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=MagicMock()) - file_path = "path/to/file.h5" - - # Mock Path to extract file and folder names - with patch.object(Path, 'parent', new_callable=MagicMock) as mock_parent, \ - patch.object(Path, 'stem', new_callable=MagicMock) as mock_stem: - mock_parent.name = "to" - mock_stem.return_value = "file" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - # Verify the returned job has the expected attributes - submitted_job = mock_sfapi_client.compute.return_value.submit_job.return_value - assert submitted_job.jobid == "12345", "Job ID should match the mock job ID." - assert submitted_job.state == "COMPLETED", "Job state should be COMPLETED." + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() From 2f3b085554725b0b65828b548e258d7946f2820a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 17 Mar 2026 10:51:14 -0700 Subject: [PATCH 06/12] Updating multires() method to use the generic _submit_job() and _wait_for_job() helpers --- orchestration/flows/bl832/nersc.py | 79 +++++++++++++++++------------- 1 file changed, 45 insertions(+), 34 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 1ddb3a50..a412ff8d 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -5,7 +5,6 @@ import logging import os from pathlib import Path -import re import time from authlib.jose import JsonWebKey @@ -418,7 +417,8 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - user = self.client.user() + # user = self.client.user() + username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -429,7 +429,7 @@ def build_multi_resolution( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -469,42 +469,53 @@ def build_multi_resolution( date """ try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") + logger.info("Submitting Tiff to Zarr job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") + time.sleep(60) + success = self._wait_for_job(job_id) + logger.info(f"Multiresolution job {'completed' if success else 'failed'}.") + return success + except Exception as e: + logger.error(f"Error during multiresolution job submission or completion: {e}") + return False + # try: + # logger.info("Submitting Tiff to Zarr job script to Perlmutter.") + # perlmutter = self.client.compute(Machine.perlmutter) + # job = perlmutter.submit_job(job_script) + # logger.info(f"Submitted job ID: {job.jobid}") - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") + # try: + # job.update() + # except Exception as update_err: + # logger.warning(f"Initial job update failed, continuing: {update_err}") - time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") + # time.sleep(60) + # logger.info(f"Job {job.jobid} current state: {job.state}") - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") + # job.complete() # Wait until the job completes + # logger.info("Reconstruction job completed successfully.") - return True + # return True - except Exception as e: - logger.warning(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False + # except Exception as e: + # logger.warning(f"Error during job submission or completion: {e}") + # match = re.search(r"Job not found:\s*(\d+)", str(e)) + + # if match: + # jobid = match.group(1) + # logger.info(f"Attempting to recover job {jobid}.") + # try: + # job = self.client.perlmutter.job(jobid=jobid) + # time.sleep(30) + # job.complete() + # logger.info("Reconstruction job completed successfully after recovery.") + # return True + # except Exception as recovery_err: + # logger.error(f"Failed to recover job {jobid}: {recovery_err}") + # return False + # else: + # return False def start_streaming_service( self, From 7c50366690abc355adb915943c79bd2c7604a66f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 30 Mar 2026 14:53:53 -0700 Subject: [PATCH 07/12] successfully ran reconstruction using the IRI-API --- orchestration/flows/bl832/nersc.py | 79 ++++-- orchestration/globus/token.py | 390 +++++++++++++++++++++-------- scripts/get_globus_token.py | 337 +++++++++++++++++++++++++ 3 files changed, 681 insertions(+), 125 deletions(-) create mode 100644 scripts/get_globus_token.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index a412ff8d..06146597 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -20,7 +20,11 @@ from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) -from orchestration.globus.token import get_access_token_confidential, DEFAULT_TOKEN_FILE +from orchestration.globus.token import ( + get_access_token, + DEFAULT_TOKEN_FILE, + IRI_SCOPE, +) from orchestration.prefect import schedule_prefect_flow from orchestration.transfer_controller import get_transfer_controller, CopyMethod @@ -31,7 +35,9 @@ # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" -_IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client +_IRI_COMPUTE_RESOURCE: str = "compute" +_IRI_SCRATCH_RESOURCE: str = "scratch" +# _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" _IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" _IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ @@ -39,6 +45,7 @@ "profile", "email", "urn:globus:auth:scope:auth.globus.org:view_identities", + IRI_SCOPE, }) _API_BASE_URLS: dict[NERSCLoginMethod, str] = { @@ -127,33 +134,33 @@ def _create_iriapi_client() -> Client: ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset. RuntimeError: If the acquired token is missing required scopes. """ - client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) + client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) + # client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) if not client_id: raise ValueError( f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." ) - if not client_secret: - raise ValueError( - f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " - "A Globus Confidential App client is required for automated IRI API auth." - ) + # if not client_secret: + # raise ValueError( + # f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " + # "A Globus Confidential App client is required for automated IRI API auth." + # ) token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - access_token = get_access_token_confidential( + access_token = get_access_token( client_id=client_id, - client_secret=client_secret, - required_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, - resource_server=_IRIAPI_GLOBUS_RESOURCE_SERVER, + requested_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, token_file=token_file, + force_login=False, ) return httpx.Client( base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], headers={"Authorization": f"Bearer {access_token}"}, + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), ) @staticmethod @@ -231,12 +238,39 @@ def _submit_job(self, job_script: str) -> str: return str(job.jobid) elif self.login_method is NERSCLoginMethod.IRIAPI: + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + + script_body = "\n".join( + line for line in job_script.splitlines() + if not line.startswith("#SBATCH") and not line.startswith("#!/") + ).strip() + + job_spec = { + "executable": "/bin/bash", + "arguments": ["-c", script_body], + "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", + "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", + "resources": { + "node_count": 1, + "processes_per_node": 1, + "cpu_cores_per_process": 64, + "exclusive_node_use": True, + }, + "attributes": { + "duration": 1800, + "queue_name": "realtime", + "account": "als", + "custom_attributes": {"constraint": "cpu"}, + }, + } + response = self.client.post( - "/api/v1/compute/job/perlmutter", - json={"script": job_script}, + f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", + json=job_spec, ) response.raise_for_status() - return str(response.json()["job_id"]) + return str(response.json()["id"]) else: raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") @@ -262,13 +296,16 @@ def _wait_for_job(self, job_id: str) -> bool: elif self.login_method is NERSCLoginMethod.IRIAPI: while True: response = self.client.get( - f"/api/v1/compute/status/perlmutter/{job_id}" + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/{job_id}" # ← was "perlmutter" ) response.raise_for_status() - state = response.json().get("state") + state = response.json().get("status", {}).get("state") logger.info(f"Job {job_id} state: {state}") - if state in ("COMPLETED", "FAILED", "CANCELLED", "TIMEOUT"): - return state == "COMPLETED" + if state == "completed": + return True + if state in ("failed", "canceled", "timeout"): + logger.error(f"Job {job_id} ended with state: {state}") + return False time.sleep(60) else: @@ -866,7 +903,7 @@ def nersc_streaming_flow( config = Config832() nersc_recon_test_flow( - file_path="dabramov/20241216_153047_ddd.h5.h5", + file_path="dabramov/20241216_153047_ddd.h5", config=config ) # nersc_streaming_flow( diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py index 81b5438f..4970eaa7 100644 --- a/orchestration/globus/token.py +++ b/orchestration/globus/token.py @@ -1,3 +1,4 @@ +# orchestration/globus/token.py import json import logging import os @@ -12,69 +13,20 @@ # Default token file location, matching the Globus SDK convention. DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" -GLOBUS_OIDC_TOKEN_URL: str = "https://auth.globus.org/v2/oauth2/token" +# IRI API Globus scope and resource server. +# The IRI access token lives in other_tokens under this scope, not at the +# top level of the auth.globus.org response. +IRI_SCOPE: str = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client (machine-to-machine). - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - # 1. Do we already have a valid token? - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - # 2. Mint a new token — same call whether first time or expired. - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] +# --------------------------------------------------------------------------- +# File I/O +# --------------------------------------------------------------------------- def load_token_file(token_file: Path) -> dict | None: """Load saved Globus token data from disk. @@ -112,105 +64,345 @@ def save_token_file(token_file: Path, tokens: dict) -> None: os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) +def _ensure_private_parent_dir(path: Path) -> None: + """Create parent directories for path with owner-only permissions. + + Args: + path: The file path whose parent directory should be created. + """ + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +# --------------------------------------------------------------------------- +# IRI token helpers +# --------------------------------------------------------------------------- + +def _parse_scope_string(scope_string: str) -> set[str]: + """Split a space-separated scope string into a set. + + Args: + scope_string: Space-separated OAuth2 scope string. + + Returns: + Set of individual scope strings. + """ + return set(scope_string.split()) if scope_string else set() + + +def extract_iri_token(token_response_data: dict) -> dict: + """Extract the IRI access token entry from a Globus token response. + + The IRI token is not returned at the top level — it lives inside + ``other_tokens``, identified by :data:`IRI_SCOPE`. + + Args: + token_response_data: Full token response dict as returned by the + Globus SDK (i.e. ``token_response.data``). + + Returns: + Token dict for the IRI resource server. + + Raises: + RuntimeError: If no token matching the IRI scope is found. + """ + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError( + f"Missing token for required IRI scope: {IRI_SCOPE}. " + "Re-run with --force-login and ensure consent is granted for the IRI scope." + ) + + +def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + """Return a copy of token_response_data with the IRI entry replaced. + + Args: + token_response_data: Full stored token response dict. + iri_token_data: Updated IRI token dict to splice in. + + Returns: + Updated token response dict. + """ + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for i, token_data in enumerate(other_tokens): + if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): + other_tokens[i] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def _get_iri_refresh_token(stored_tokens: dict) -> str | None: + """Extract the IRI refresh token from stored token data, if present. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The IRI refresh token string, or None if absent. + """ + try: + return extract_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def _get_auth_refresh_token(stored_tokens: dict) -> str | None: + """Extract the top-level Globus Auth refresh token from stored data. + + Args: + stored_tokens: Full stored token response dict. + + Returns: + The auth refresh token string, or None if absent. + """ + if "refresh_token" in stored_tokens: + return stored_tokens["refresh_token"] + auth_tokens = stored_tokens.get("auth.globus.org") + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + return None + + +# --------------------------------------------------------------------------- +# NativeApp flow (interactive) +# --------------------------------------------------------------------------- + def interactive_login( client: globus_sdk.NativeAppAuthClient, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], + prompt_login: bool = False, ) -> dict: """Run an interactive browser-based Globus login flow. Prints an authorization URL, waits for the user to paste an auth code, - and exchanges it for tokens. + and returns the full token response data including ``other_tokens``. Args: client: Globus NativeAppAuthClient to drive the flow. - required_scopes: Set of OAuth2 scopes to request. - resource_server: Resource server key to extract from the token response - (e.g. ``"auth.globus.org"``). + requested_scopes: Set of OAuth2 scopes to request. Should include + :data:`IRI_SCOPE` to obtain an IRI API token. + prompt_login: If True, add ``prompt=login`` to the authorize URL to + force a fresh identity-provider login. Returns: - Token dict for the given resource server. + Full token response dict (``token_response.data``), including + ``other_tokens``. + + Raises: + RuntimeError: If no authorization code is entered, or if the code + exchange fails. """ client.oauth2_start_flow( - requested_scopes=" ".join(sorted(required_scopes)), + requested_scopes=" ".join(sorted(requested_scopes)), refresh_tokens=True, ) logger.info("Open this URL in your browser to authenticate with Globus:") - logger.info(client.oauth2_get_authorize_url()) + prompt = "login" if prompt_login else globus_sdk.MISSING + logger.info(client.oauth2_get_authorize_url(prompt=prompt)) code = input("\nEnter authorization code: ").strip() - token_response = client.oauth2_exchange_code_for_tokens(code) - return token_response.by_resource_server[resource_server] + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the " + "code shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as e: + if e.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed — the code was empty, " + "invalid, expired, or already used. Re-run and try again." + ) from e + raise RuntimeError( + f"Authorization code exchange failed with HTTP {e.http_status}." + ) from e + return token_response.data -def refresh_tokens( +def _refresh_single_token( client: globus_sdk.NativeAppAuthClient, refresh_token: str, - resource_server: str, ) -> dict | None: - """Attempt a silent Globus token refresh. + """Attempt a single Globus token refresh, returning raw response data. Args: - client: Globus NativeAppAuthClient to drive the refresh. + client: NativeAppAuthClient to drive the refresh. refresh_token: The stored refresh token. - resource_server: Resource server key to extract from the token response. Returns: - Fresh token dict for the given resource server, or None if refresh failed. + Raw token response data dict, or None if the refresh failed. """ try: token_response = client.oauth2_refresh_token(refresh_token) - return token_response.by_resource_server[resource_server] + return token_response.data except GlobusAPIError as e: logger.warning( f"Globus token refresh failed ({e.http_status}); " - "falling back to interactive login." + "will fall back to interactive login." ) return None +def _refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, + stored_tokens: dict, +) -> tuple[dict | None, bool]: + """Try to refresh stored tokens, preferring the IRI refresh token. + + Attempts the IRI-specific refresh token first, then falls back to the + top-level Globus Auth refresh token. + + Args: + client: NativeAppAuthClient to drive the refresh. + stored_tokens: Full stored token response dict. + + Returns: + Tuple of ``(updated_token_data, success)``. On failure both values + are ``(None, False)``. + """ + iri_refresh = _get_iri_refresh_token(stored_tokens) + if iri_refresh: + iri_token_data = _refresh_single_token(client, iri_refresh) + if iri_token_data is not None: + return _replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh = _get_auth_refresh_token(stored_tokens) + if auth_refresh: + auth_data = _refresh_single_token(client, auth_refresh) + if auth_data is not None: + return auth_data, True + + return None, False + + def get_access_token( client_id: str, - required_scopes: frozenset[str], - resource_server: str, + requested_scopes: frozenset[str], token_file: Path | None = None, force_login: bool = False, + prompt_login: bool = False, ) -> str: - """Get a valid Globus access token, refreshing or logging in as needed. + """Get a valid IRI API access token via the NativeApp interactive flow. Attempts a silent refresh from the saved token file first. Falls back to interactive browser login if no saved tokens exist, the refresh token is absent, or the refresh fails. Saves the resulting tokens back to disk. + The IRI token is extracted from ``other_tokens`` in the response — it is + not the top-level Globus Auth token. + Args: client_id: Globus NativeApp client ID. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. + requested_scopes: Set of OAuth2 scopes to request. Must include + :data:`IRI_SCOPE` to obtain a usable IRI API token. token_file: Path to the JSON token file. Defaults to ``~/.globus/auth_tokens.json``. force_login: If True, skip refresh and force interactive login. + prompt_login: If True, add ``prompt=login`` to the authorize URL. Returns: - A valid Globus access token string. + A valid IRI API access token string. Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. + RuntimeError: If the IRI scope token is missing from the response. """ resolved_token_file = token_file or DEFAULT_TOKEN_FILE globus_client = globus_sdk.NativeAppAuthClient(client_id) - auth_data: dict | None = None + token_response_data: dict | None = None + used_refresh = False if not force_login: stored = load_token_file(resolved_token_file) - if stored and stored.get("refresh_token"): - auth_data = refresh_tokens( - globus_client, stored["refresh_token"], resource_server + if stored: + token_response_data, used_refresh = _refresh_stored_tokens( + globus_client, stored ) - if auth_data is None: + if token_response_data is None: logger.info("Initiating interactive Globus login.") - auth_data = interactive_login(globus_client, required_scopes, resource_server) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + + # Extract IRI token — if a refresh ran but didn't return the IRI token, + # fall back to interactive login before raising. + try: + iri_token = extract_iri_token(token_response_data) + except RuntimeError: + if used_refresh: + logger.warning( + "Refreshed tokens did not include the IRI token; " + "falling back to interactive login." + ) + token_response_data = interactive_login( + globus_client, requested_scopes, prompt_login=prompt_login + ) + iri_token = extract_iri_token(token_response_data) + else: + raise + + save_token_file(resolved_token_file, token_response_data) + logger.info(f"Globus token saved to {resolved_token_file}.") + + return iri_token["access_token"] + + +# --------------------------------------------------------------------------- +# Confidential Client flow (machine-to-machine) +# --------------------------------------------------------------------------- + +def get_access_token_confidential( + client_id: str, + client_secret: str, + required_scopes: frozenset[str], + resource_server: str, + token_file: Path | None = None, +) -> str: + """Get a valid Globus access token using a Confidential Client. + + No browser or user interaction required. If a valid unexpired token exists + on disk it is reused; otherwise a new one is minted via the client + credentials grant and saved. + + Args: + client_id: Globus Confidential App client ID. + client_secret: Globus Confidential App client secret. + required_scopes: Set of OAuth2 scopes that must be present on the token. + resource_server: Resource server key to extract from the token response. + token_file: Path to the JSON token cache file. Defaults to + ``~/.globus/auth_tokens.json``. + + Returns: + A valid Globus access token string. + + Raises: + RuntimeError: If the acquired token is missing required scopes. + KeyError: If ``access_token`` is absent from the token response. + """ + resolved_token_file = token_file or DEFAULT_TOKEN_FILE + + stored = load_token_file(resolved_token_file) + if stored: + expires_at = stored.get("expires_at_seconds") + if expires_at and time.time() < expires_at: + logger.info("Using cached Globus token (still valid).") + return stored["access_token"] + logger.info("Cached Globus token is expired; minting a new one.") + else: + logger.info("No cached Globus token found; minting a new one.") + + globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) + token_response = globus_client.oauth2_client_credentials_tokens( + requested_scopes=" ".join(sorted(required_scopes)) + ) + auth_data = token_response.by_resource_server[resource_server] granted = set(auth_data.get("scope", "").split()) missing = required_scopes - granted @@ -220,16 +412,6 @@ def get_access_token( ) save_token_file(resolved_token_file, auth_data) - logger.info(f"Globus token saved to {resolved_token_file}.") + logger.info(f"New Globus token saved to {resolved_token_file}.") return auth_data["access_token"] - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) diff --git a/scripts/get_globus_token.py b/scripts/get_globus_token.py new file mode 100644 index 00000000..6b615378 --- /dev/null +++ b/scripts/get_globus_token.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +import argparse +import json +import os +import stat +import time +import urllib.error +import urllib.request +from pathlib import Path + +import globus_sdk +from globus_sdk.exc import GlobusAPIError + +CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" +RESOURCE_SERVER = "auth.globus.org" +IRI_SCOPE = ( + "https://auth.globus.org/scopes/" + "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" +) +REQUIRED_SCOPES = { + "openid", + "profile", + "email", + "urn:globus:auth:scope:auth.globus.org:view_identities", +} +REQUESTED_SCOPES = REQUIRED_SCOPES | {IRI_SCOPE} +DEFAULT_IRI_VALIDATE_URL = "https://api.iri.nersc.gov/api/v1/account/projects" + + +def parse_args() -> argparse.Namespace: + default_token_file = Path.home() / ".globus" / "auth_tokens.json" + parser = argparse.ArgumentParser( + description=( + "Get Globus Auth tokens with required scopes. " + "Tokens are saved to a secure local file by default." + ) + ) + parser.add_argument( + "--token-file", + type=Path, + default=default_token_file, + help=f"Path for saved token JSON (default: {default_token_file})", + ) + parser.add_argument( + "--print-token", + action="store_true", + help="Print the access token to stdout (off by default).", + ) + parser.add_argument( + "--force-login", + action="store_true", + help="Skip refresh and force interactive browser login.", + ) + parser.add_argument( + "--refresh-only", + action="store_true", + help="Refresh saved tokens only; do not fall back to interactive login.", + ) + parser.add_argument( + "--prompt-login", + action="store_true", + help="Add prompt=login to the Globus authorize URL to force re-authentication.", + ) + parser.add_argument( + "--validate-iri", + action="store_true", + help="Validate the IRI token by calling the IRI account/projects endpoint.", + ) + parser.add_argument( + "--iri-validate-url", + default=DEFAULT_IRI_VALIDATE_URL, + help=( + "IRI endpoint used by --validate-iri " + f"(default: {DEFAULT_IRI_VALIDATE_URL})" + ), + ) + return parser.parse_args() + + +def parse_scope_string(scope_string: str) -> set[str]: + return set(scope_string.split()) if scope_string else set() + + +def ensure_private_parent_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +def load_tokens(token_file: Path) -> dict | None: + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_tokens(token_file: Path, tokens: dict) -> None: + ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def get_refresh_token(stored_tokens: dict) -> str | None: + if "refresh_token" in stored_tokens: + return stored_tokens.get("refresh_token") + + auth_tokens = stored_tokens.get(RESOURCE_SERVER) + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + + return None + + +def get_iri_token(token_response_data: dict) -> dict: + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") + + +def get_iri_refresh_token(stored_tokens: dict) -> str | None: + try: + return get_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for index, token_data in enumerate(other_tokens): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + other_tokens[index] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def validate_auth_data(auth_data: dict) -> dict: + if auth_data.get("resource_server") != RESOURCE_SERVER: + raise RuntimeError( + f"Missing token for required resource server: {RESOURCE_SERVER}" + ) + + granted = parse_scope_string(auth_data.get("scope", "")) + missing = REQUIRED_SCOPES - granted + if missing: + raise RuntimeError(f"Missing required scopes: {sorted(missing)}") + + return get_iri_token(auth_data) + + +def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: + request = urllib.request.Request( + validate_url, + headers={ + "accept": "application/json", + "Authorization": f"Bearer {iri_token_data['access_token']}", + }, + method="GET", + ) + try: + with urllib.request.urlopen(request) as response: + body = response.read().decode("utf-8") + data = json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8") + details = body.strip() or exc.reason + raise RuntimeError( + f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError( + f"IRI validation request failed for {validate_url}: {exc.reason}" + ) from exc + except json.JSONDecodeError as exc: + raise RuntimeError( + f"IRI validation returned non-JSON data from {validate_url}" + ) from exc + + if isinstance(data, dict): + session_info = data.get("session_info") + if isinstance(session_info, dict): + authentications = session_info.get("authentications") + if isinstance(authentications, dict) and not authentications: + raise RuntimeError( + "IRI validation succeeded but session_info.authentications is empty. " + "Re-run with --force-login --prompt-login and use a Chrome incognito window." + ) + + return data + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + *, + prompt_login: bool = False, +) -> dict: + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), + refresh_tokens=True, + ) + print("Open this URL, login, and consent:") + prompt = "login" if prompt_login else globus_sdk.MISSING + print(client.oauth2_get_authorize_url(prompt=prompt)) + code = input("\nEnter authorization code: ").strip() + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the code " + "shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as exc: + if exc.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed. The code was empty, invalid, " + "expired, or already used. Re-run the script and complete the " + "Globus login flow again." + ) from exc + raise RuntimeError( + f"Authorization code exchange failed with HTTP {exc.http_status}. " + "Re-run the script and try again." + ) from exc + return token_response.data + + +def refresh_tokens( + client: globus_sdk.NativeAppAuthClient, refresh_token: str +) -> dict | None: + try: + token_response = client.oauth2_refresh_token(refresh_token) + return token_response.data + except GlobusAPIError as exc: + print( + f"Refresh failed ({exc.http_status}); switching to interactive login." + ) + return None + + +def refresh_stored_tokens( + client: globus_sdk.NativeAppAuthClient, stored_tokens: dict +) -> tuple[dict | None, bool]: + iri_refresh_token = get_iri_refresh_token(stored_tokens) + if iri_refresh_token: + iri_token_data = refresh_tokens(client, iri_refresh_token) + if iri_token_data is not None: + return replace_iri_token(stored_tokens, iri_token_data), True + + auth_refresh_token = get_refresh_token(stored_tokens) + if auth_refresh_token: + auth_data = refresh_tokens(client, auth_refresh_token) + if auth_data is not None: + return auth_data, True + + return None, False + + +def main() -> None: + args = parse_args() + if args.force_login and args.refresh_only: + raise RuntimeError("Choose only one of --force-login or --refresh-only") + + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + + auth_data = None + used_refresh = False + if not args.force_login: + stored = load_tokens(args.token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + + if auth_data is None: + if args.refresh_only: + raise RuntimeError( + "Refresh-only mode failed. No usable saved refresh token was found " + "or token refresh did not return the required IRI token." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + print( + "Refreshed tokens did not include the IRI token; " + "switching to interactive login." + ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + + save_tokens(args.token_file, auth_data) + + if args.validate_iri: + validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) + print(f"IRI validation succeeded against {args.iri_validate_url}") + if isinstance(validation_data, dict): + session_info = validation_data.get("session_info") + if isinstance(session_info, dict): + session_id = session_info.get("session_id") + if session_id: + print(f"IRI session_id: {session_id}") + elif isinstance(validation_data, list): + print(f"IRI validation response items: {len(validation_data)}") + + expires_at = iri_token_data.get("expires_at_seconds") + if expires_at: + ttl = int(expires_at - time.time()) + print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") + + print(f"Saved token data to {args.token_file}") + print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") + print(f"IRI token resource server: {iri_token_data.get('resource_server')}") + print(f"IRI token scopes: {iri_token_data.get('scope', '')}") + + if args.print_token: + print("\nIRI access token:") + print(iri_token_data["access_token"]) + else: + print( + "IRI access token not printed " + "(use --print-token to display it for the NERSC IRI API)." + ) + + +if __name__ == "__main__": + main() From 3a1d8cb91bbb0e7cefa17f1a45a27d9bbd4d0afe Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:14 -0700 Subject: [PATCH 08/12] removing token.py and moving the logic to get_globus_token.py --- orchestration/globus/token.py | 417 ---------------------------------- 1 file changed, 417 deletions(-) delete mode 100644 orchestration/globus/token.py diff --git a/orchestration/globus/token.py b/orchestration/globus/token.py deleted file mode 100644 index 4970eaa7..00000000 --- a/orchestration/globus/token.py +++ /dev/null @@ -1,417 +0,0 @@ -# orchestration/globus/token.py -import json -import logging -import os -from pathlib import Path -import stat -import time - -import globus_sdk -from globus_sdk.exc import GlobusAPIError - -logger = logging.getLogger(__name__) - -# Default token file location, matching the Globus SDK convention. -DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" - -# IRI API Globus scope and resource server. -# The IRI access token lives in other_tokens under this scope, not at the -# top level of the auth.globus.org response. -IRI_SCOPE: str = ( - "https://auth.globus.org/scopes/" - "ed3e577d-f7f3-4639-b96e-ff5a8445d699/iri_api" -) -IRI_RESOURCE_SERVER: str = "ed3e577d-f7f3-4639-b96e-ff5a8445d699" - - -# --------------------------------------------------------------------------- -# File I/O -# --------------------------------------------------------------------------- - -def load_token_file(token_file: Path) -> dict | None: - """Load saved Globus token data from disk. - - Args: - token_file: Path to the JSON token file. - - Returns: - Parsed token dict, or None if the file does not exist. - """ - if not token_file.exists(): - return None - with token_file.open("r", encoding="utf-8") as f: - return json.load(f) - - -def save_token_file(token_file: Path, tokens: dict) -> None: - """Atomically save Globus token data to disk with owner-only permissions. - - Writes to a temporary file then renames to avoid partial writes. - - Args: - token_file: Destination path for the JSON token file. - tokens: Token dict to serialise. - """ - _ensure_private_parent_dir(token_file) - tmp = token_file.with_suffix(".tmp") - with os.fdopen( - os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), - "w", - encoding="utf-8", - ) as f: - json.dump(tokens, f, indent=2) - os.replace(tmp, token_file) - os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) - - -def _ensure_private_parent_dir(path: Path) -> None: - """Create parent directories for path with owner-only permissions. - - Args: - path: The file path whose parent directory should be created. - """ - path.parent.mkdir(parents=True, exist_ok=True) - os.chmod(path.parent, 0o700) - - -# --------------------------------------------------------------------------- -# IRI token helpers -# --------------------------------------------------------------------------- - -def _parse_scope_string(scope_string: str) -> set[str]: - """Split a space-separated scope string into a set. - - Args: - scope_string: Space-separated OAuth2 scope string. - - Returns: - Set of individual scope strings. - """ - return set(scope_string.split()) if scope_string else set() - - -def extract_iri_token(token_response_data: dict) -> dict: - """Extract the IRI access token entry from a Globus token response. - - The IRI token is not returned at the top level — it lives inside - ``other_tokens``, identified by :data:`IRI_SCOPE`. - - Args: - token_response_data: Full token response dict as returned by the - Globus SDK (i.e. ``token_response.data``). - - Returns: - Token dict for the IRI resource server. - - Raises: - RuntimeError: If no token matching the IRI scope is found. - """ - for token_data in token_response_data.get("other_tokens", []): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - return token_data - raise RuntimeError( - f"Missing token for required IRI scope: {IRI_SCOPE}. " - "Re-run with --force-login and ensure consent is granted for the IRI scope." - ) - - -def _replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: - """Return a copy of token_response_data with the IRI entry replaced. - - Args: - token_response_data: Full stored token response dict. - iri_token_data: Updated IRI token dict to splice in. - - Returns: - Updated token response dict. - """ - merged = dict(token_response_data) - other_tokens = list(merged.get("other_tokens", [])) - for i, token_data in enumerate(other_tokens): - if IRI_SCOPE in _parse_scope_string(token_data.get("scope", "")): - other_tokens[i] = iri_token_data - break - else: - other_tokens.append(iri_token_data) - merged["other_tokens"] = other_tokens - return merged - - -def _get_iri_refresh_token(stored_tokens: dict) -> str | None: - """Extract the IRI refresh token from stored token data, if present. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The IRI refresh token string, or None if absent. - """ - try: - return extract_iri_token(stored_tokens).get("refresh_token") - except RuntimeError: - return None - - -def _get_auth_refresh_token(stored_tokens: dict) -> str | None: - """Extract the top-level Globus Auth refresh token from stored data. - - Args: - stored_tokens: Full stored token response dict. - - Returns: - The auth refresh token string, or None if absent. - """ - if "refresh_token" in stored_tokens: - return stored_tokens["refresh_token"] - auth_tokens = stored_tokens.get("auth.globus.org") - if isinstance(auth_tokens, dict): - return auth_tokens.get("refresh_token") - return None - - -# --------------------------------------------------------------------------- -# NativeApp flow (interactive) -# --------------------------------------------------------------------------- - -def interactive_login( - client: globus_sdk.NativeAppAuthClient, - requested_scopes: frozenset[str], - prompt_login: bool = False, -) -> dict: - """Run an interactive browser-based Globus login flow. - - Prints an authorization URL, waits for the user to paste an auth code, - and returns the full token response data including ``other_tokens``. - - Args: - client: Globus NativeAppAuthClient to drive the flow. - requested_scopes: Set of OAuth2 scopes to request. Should include - :data:`IRI_SCOPE` to obtain an IRI API token. - prompt_login: If True, add ``prompt=login`` to the authorize URL to - force a fresh identity-provider login. - - Returns: - Full token response dict (``token_response.data``), including - ``other_tokens``. - - Raises: - RuntimeError: If no authorization code is entered, or if the code - exchange fails. - """ - client.oauth2_start_flow( - requested_scopes=" ".join(sorted(requested_scopes)), - refresh_tokens=True, - ) - logger.info("Open this URL in your browser to authenticate with Globus:") - prompt = "login" if prompt_login else globus_sdk.MISSING - logger.info(client.oauth2_get_authorize_url(prompt=prompt)) - code = input("\nEnter authorization code: ").strip() - if not code: - raise RuntimeError( - "No authorization code entered. Re-run the script and paste the " - "code shown by Globus after login." - ) - try: - token_response = client.oauth2_exchange_code_for_tokens(code) - except GlobusAPIError as e: - if e.http_status == 400: - raise RuntimeError( - "Authorization code exchange failed — the code was empty, " - "invalid, expired, or already used. Re-run and try again." - ) from e - raise RuntimeError( - f"Authorization code exchange failed with HTTP {e.http_status}." - ) from e - return token_response.data - - -def _refresh_single_token( - client: globus_sdk.NativeAppAuthClient, - refresh_token: str, -) -> dict | None: - """Attempt a single Globus token refresh, returning raw response data. - - Args: - client: NativeAppAuthClient to drive the refresh. - refresh_token: The stored refresh token. - - Returns: - Raw token response data dict, or None if the refresh failed. - """ - try: - token_response = client.oauth2_refresh_token(refresh_token) - return token_response.data - except GlobusAPIError as e: - logger.warning( - f"Globus token refresh failed ({e.http_status}); " - "will fall back to interactive login." - ) - return None - - -def _refresh_stored_tokens( - client: globus_sdk.NativeAppAuthClient, - stored_tokens: dict, -) -> tuple[dict | None, bool]: - """Try to refresh stored tokens, preferring the IRI refresh token. - - Attempts the IRI-specific refresh token first, then falls back to the - top-level Globus Auth refresh token. - - Args: - client: NativeAppAuthClient to drive the refresh. - stored_tokens: Full stored token response dict. - - Returns: - Tuple of ``(updated_token_data, success)``. On failure both values - are ``(None, False)``. - """ - iri_refresh = _get_iri_refresh_token(stored_tokens) - if iri_refresh: - iri_token_data = _refresh_single_token(client, iri_refresh) - if iri_token_data is not None: - return _replace_iri_token(stored_tokens, iri_token_data), True - - auth_refresh = _get_auth_refresh_token(stored_tokens) - if auth_refresh: - auth_data = _refresh_single_token(client, auth_refresh) - if auth_data is not None: - return auth_data, True - - return None, False - - -def get_access_token( - client_id: str, - requested_scopes: frozenset[str], - token_file: Path | None = None, - force_login: bool = False, - prompt_login: bool = False, -) -> str: - """Get a valid IRI API access token via the NativeApp interactive flow. - - Attempts a silent refresh from the saved token file first. Falls back to - interactive browser login if no saved tokens exist, the refresh token is - absent, or the refresh fails. Saves the resulting tokens back to disk. - - The IRI token is extracted from ``other_tokens`` in the response — it is - not the top-level Globus Auth token. - - Args: - client_id: Globus NativeApp client ID. - requested_scopes: Set of OAuth2 scopes to request. Must include - :data:`IRI_SCOPE` to obtain a usable IRI API token. - token_file: Path to the JSON token file. Defaults to - ``~/.globus/auth_tokens.json``. - force_login: If True, skip refresh and force interactive login. - prompt_login: If True, add ``prompt=login`` to the authorize URL. - - Returns: - A valid IRI API access token string. - - Raises: - RuntimeError: If the IRI scope token is missing from the response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - globus_client = globus_sdk.NativeAppAuthClient(client_id) - - token_response_data: dict | None = None - used_refresh = False - - if not force_login: - stored = load_token_file(resolved_token_file) - if stored: - token_response_data, used_refresh = _refresh_stored_tokens( - globus_client, stored - ) - - if token_response_data is None: - logger.info("Initiating interactive Globus login.") - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - - # Extract IRI token — if a refresh ran but didn't return the IRI token, - # fall back to interactive login before raising. - try: - iri_token = extract_iri_token(token_response_data) - except RuntimeError: - if used_refresh: - logger.warning( - "Refreshed tokens did not include the IRI token; " - "falling back to interactive login." - ) - token_response_data = interactive_login( - globus_client, requested_scopes, prompt_login=prompt_login - ) - iri_token = extract_iri_token(token_response_data) - else: - raise - - save_token_file(resolved_token_file, token_response_data) - logger.info(f"Globus token saved to {resolved_token_file}.") - - return iri_token["access_token"] - - -# --------------------------------------------------------------------------- -# Confidential Client flow (machine-to-machine) -# --------------------------------------------------------------------------- - -def get_access_token_confidential( - client_id: str, - client_secret: str, - required_scopes: frozenset[str], - resource_server: str, - token_file: Path | None = None, -) -> str: - """Get a valid Globus access token using a Confidential Client. - - No browser or user interaction required. If a valid unexpired token exists - on disk it is reused; otherwise a new one is minted via the client - credentials grant and saved. - - Args: - client_id: Globus Confidential App client ID. - client_secret: Globus Confidential App client secret. - required_scopes: Set of OAuth2 scopes that must be present on the token. - resource_server: Resource server key to extract from the token response. - token_file: Path to the JSON token cache file. Defaults to - ``~/.globus/auth_tokens.json``. - - Returns: - A valid Globus access token string. - - Raises: - RuntimeError: If the acquired token is missing required scopes. - KeyError: If ``access_token`` is absent from the token response. - """ - resolved_token_file = token_file or DEFAULT_TOKEN_FILE - - stored = load_token_file(resolved_token_file) - if stored: - expires_at = stored.get("expires_at_seconds") - if expires_at and time.time() < expires_at: - logger.info("Using cached Globus token (still valid).") - return stored["access_token"] - logger.info("Cached Globus token is expired; minting a new one.") - else: - logger.info("No cached Globus token found; minting a new one.") - - globus_client = globus_sdk.ConfidentialAppAuthClient(client_id, client_secret) - token_response = globus_client.oauth2_client_credentials_tokens( - requested_scopes=" ".join(sorted(required_scopes)) - ) - auth_data = token_response.by_resource_server[resource_server] - - granted = set(auth_data.get("scope", "").split()) - missing = required_scopes - granted - if missing: - raise RuntimeError( - f"Globus token is missing required scopes: {sorted(missing)}" - ) - - save_token_file(resolved_token_file, auth_data) - logger.info(f"New Globus token saved to {resolved_token_file}.") - - return auth_data["access_token"] From 5d22a45ad2e0c8ba070686a83dc75a26da069058 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:21:55 -0700 Subject: [PATCH 09/12] moving get_globus_token.py to orchestration/globus/ to be used as a module --- .../globus}/get_globus_token.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) rename {scripts => orchestration/globus}/get_globus_token.py (84%) diff --git a/scripts/get_globus_token.py b/orchestration/globus/get_globus_token.py similarity index 84% rename from scripts/get_globus_token.py rename to orchestration/globus/get_globus_token.py index 6b615378..c47057e8 100644 --- a/scripts/get_globus_token.py +++ b/orchestration/globus/get_globus_token.py @@ -11,6 +11,7 @@ import globus_sdk from globus_sdk.exc import GlobusAPIError +DEFAULT_TOKEN_FILE: Path = Path.home() / ".globus" / "auth_tokens.json" CLIENT_ID = "fae5c579-490a-4d76-b6eb-d78f65caeb63" RESOURCE_SERVER = "auth.globus.org" IRI_SCOPE = ( @@ -264,6 +265,52 @@ def refresh_stored_tokens( return None, False +def get_iri_access_token( + token_file: Path = DEFAULT_TOKEN_FILE, + force_login: bool = False, + prompt_login: bool = False, +) -> str: + """ + Get a valid IRI access token, refreshing or prompting for login as needed. + Tokens are saved to the specified token_file path (default: ~/.globus/auth_tokens.json). + By default, the function will attempt to refresh saved tokens before falling back + to interactive login. Use force_login=True to skip refresh and require interactive login. + Use prompt_login=True to add prompt=login to the authorization URL, which forces + re-authentication even if the user has an active Globus session in their browser. + + Args: + token_file: Path to save and load token data (default: ~/.globus/auth_tokens.json) + force_login: If True, skip token refresh and require interactive login + prompt_login: If True, add prompt=login to the authorization URL to force re-authentication + + Returns: + A valid IRI access token string with the required scopes. + + Raises: + RuntimeError: If token refresh fails and interactive login is not allowed or fails, + or if the resulting tokens do not include a valid IRI access token. + """ + client = globus_sdk.NativeAppAuthClient(CLIENT_ID) + auth_data = None + used_refresh = False + if not force_login: + stored = load_tokens(token_file) + if stored: + auth_data, used_refresh = refresh_stored_tokens(client, stored) + if auth_data is None: + auth_data = interactive_login(client, prompt_login=prompt_login) + try: + iri_token_data = validate_auth_data(auth_data) + except RuntimeError as exc: + if used_refresh and "Missing token for required IRI scope" in str(exc): + auth_data = interactive_login(client, prompt_login=prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + save_tokens(token_file, auth_data) + return iri_token_data["access_token"] + + def main() -> None: args = parse_args() if args.force_login and args.refresh_only: From 92639d56ddd5b63691eddb30e1ab3f11e2037456 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:22:56 -0700 Subject: [PATCH 10/12] Cleaning up nersc.py --- orchestration/flows/bl832/nersc.py | 184 +++++------------------------ 1 file changed, 27 insertions(+), 157 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 06146597..4952c319 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -20,10 +20,9 @@ from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) -from orchestration.globus.token import ( - get_access_token, +from orchestration.globus.get_globus_token import ( + get_iri_access_token, DEFAULT_TOKEN_FILE, - IRI_SCOPE, ) from orchestration.prefect import schedule_prefect_flow from orchestration.transfer_controller import get_transfer_controller, CopyMethod @@ -36,17 +35,7 @@ # Applies only to NERSCLoginMethod.IRIAPI _IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" _IRI_COMPUTE_RESOURCE: str = "compute" -_IRI_SCRATCH_RESOURCE: str = "scratch" -# _IRIAPI_GLOBUS_CLIENT_SECRET_ENV: str = "GLOBUS_CLIENT_SECRET" # set → confidential client _IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" -_IRIAPI_GLOBUS_RESOURCE_SERVER: str = "auth.globus.org" -_IRIAPI_GLOBUS_REQUIRED_SCOPES: frozenset[str] = frozenset({ - "openid", - "profile", - "email", - "urn:globus:auth:scope:auth.globus.org:view_identities", - IRI_SCOPE, -}) _API_BASE_URLS: dict[NERSCLoginMethod, str] = { NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", @@ -135,26 +124,19 @@ def _create_iriapi_client() -> Client: RuntimeError: If the acquired token is missing required scopes. """ client_id = "fae5c579-490a-4d76-b6eb-d78f65caeb63" # os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV) - # client_secret = os.getenv(_IRIAPI_GLOBUS_CLIENT_SECRET_ENV) if not client_id: raise ValueError( f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}." ) - # if not client_secret: - # raise ValueError( - # f"Globus client secret is unset. Set {_IRIAPI_GLOBUS_CLIENT_SECRET_ENV}. " - # "A Globus Confidential App client is required for automated IRI API auth." - # ) token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE - access_token = get_access_token( - client_id=client_id, - requested_scopes=_IRIAPI_GLOBUS_REQUIRED_SCOPES, + access_token = get_iri_access_token( token_file=token_file, force_login=False, + prompt_login=False ) return httpx.Client( @@ -408,44 +390,6 @@ def reconstruct( logger.error(f"Error during reconstruction job submission or completion: {e}") return False - # try: - # logger.info("Submitting reconstruction job script to Perlmutter.") - # perlmutter = self.client.compute(Machine.perlmutter) - # job = perlmutter.submit_job(job_script) - # logger.info(f"Submitted job ID: {job.jobid}") - - # try: - # job.update() - # except Exception as update_err: - # logger.warning(f"Initial job update failed, continuing: {update_err}") - - # time.sleep(60) - # logger.info(f"Job {job.jobid} current state: {job.state}") - - # job.complete() # Wait until the job completes - # logger.info("Reconstruction job completed successfully.") - # return True - - # except Exception as e: - # logger.info(f"Error during job submission or completion: {e}") - # match = re.search(r"Job not found:\s*(\d+)", str(e)) - - # if match: - # jobid = match.group(1) - # logger.info(f"Attempting to recover job {jobid}.") - # try: - # job = self.client.perlmutter.job(jobid=jobid) - # time.sleep(30) - # job.complete() - # logger.info("Reconstruction job completed successfully after recovery.") - # return True - # except Exception as recovery_err: - # logger.error(f"Failed to recover job {jobid}: {recovery_err}") - # return False - # else: - # # Unknown error: cannot recover - # return False - def build_multi_resolution( self, file_path: str = "", @@ -516,43 +460,6 @@ def build_multi_resolution( except Exception as e: logger.error(f"Error during multiresolution job submission or completion: {e}") return False - # try: - # logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - # perlmutter = self.client.compute(Machine.perlmutter) - # job = perlmutter.submit_job(job_script) - # logger.info(f"Submitted job ID: {job.jobid}") - - # try: - # job.update() - # except Exception as update_err: - # logger.warning(f"Initial job update failed, continuing: {update_err}") - - # time.sleep(60) - # logger.info(f"Job {job.jobid} current state: {job.state}") - - # job.complete() # Wait until the job completes - # logger.info("Reconstruction job completed successfully.") - - # return True - - # except Exception as e: - # logger.warning(f"Error during job submission or completion: {e}") - # match = re.search(r"Job not found:\s*(\d+)", str(e)) - - # if match: - # jobid = match.group(1) - # logger.info(f"Attempting to recover job {jobid}.") - # try: - # job = self.client.perlmutter.job(jobid=jobid) - # time.sleep(30) - # job.complete() - # logger.info("Reconstruction job completed successfully after recovery.") - # return True - # except Exception as recovery_err: - # logger.error(f"Failed to recover job {jobid}: {recovery_err}") - # return False - # else: - # return False def start_streaming_service( self, @@ -711,7 +618,8 @@ def nersc_recon_flow( logger.info(f"Starting NERSC reconstruction flow for {file_path=}") controller = get_controller( hpc_type=HPC.NERSC, - config=config + config=config, + login_method=NERSCLoginMethod.SFAPI ) logger.info("NERSC reconstruction controller initialized") @@ -782,8 +690,8 @@ def nersc_recon_flow( return False -@flow(name="nersc_recon_test_flow", flow_run_name="nersc_recon-{file_path}") -def nersc_recon_test_flow( +@flow(name="nersc_recon_test_iriapi_flow", flow_run_name="nersc_recon-{file_path}") +def nersc_recon_test_iriapi_flow( file_path: str, config: Optional[Config832] = None, ) -> bool: @@ -809,62 +717,24 @@ def nersc_recon_test_flow( nersc_reconstruction_success = controller.reconstruct( file_path=file_path, ) - # logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") - # nersc_multi_res_success = controller.build_multi_resolution( - # file_path=file_path, - # ) - # logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - - # path = Path(file_path) - # folder_name = path.parent.name - # file_name = path.stem - - # tiff_file_path = f"{folder_name}/rec{file_name}" - # zarr_file_path = f"{folder_name}/rec{file_name}.zarr" - - # logger.info(f"{tiff_file_path=}") - # logger.info(f"{zarr_file_path=}") - - # # Transfer reconstructed data - # logger.info("Preparing transfer.") - # transfer_controller = get_transfer_controller( - # transfer_type=CopyMethod.GLOBUS, - # config=config - # ) - - # logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") - # transfer_controller.copy( - # file_path=tiff_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.nersc832_alsdev_scratch - # ) - - # transfer_controller.copy( - # file_path=zarr_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.nersc832_alsdev_scratch - # ) - - # logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") - # transfer_controller.copy( - # file_path=tiff_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) - - # transfer_controller.copy( - # file_path=zarr_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) - - # logger.info("Scheduling pruning tasks.") - # schedule_pruning( - # config=config, - # raw_file_path=file_path, - # tiff_file_path=tiff_file_path, - # zarr_file_path=zarr_file_path - # ) + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, + ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") + + # Transfers and pruning omitted from test flow. # TODO: Ingest into SciCat if nersc_reconstruction_success: @@ -902,7 +772,7 @@ def nersc_streaming_flow( if __name__ == "__main__": config = Config832() - nersc_recon_test_flow( + nersc_recon_test_iriapi_flow( file_path="dabramov/20241216_153047_ddd.h5", config=config ) From a1b53c8aebd6e99c5a28beba917210ae4a4b92ce Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:28:37 -0700 Subject: [PATCH 11/12] cleaning up old commented code --- orchestration/flows/bl832/nersc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 4952c319..7a7c3b8f 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -302,7 +302,6 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - # user = self.client.user() username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path @@ -398,7 +397,6 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - # user = self.client.user() username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] From 0e26469d68372e0ae6a16bd408f8ef9a48ee93a6 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:34:53 -0700 Subject: [PATCH 12/12] Updating unit tests --- orchestration/_tests/test_bl832/test_nersc.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 7994a3ac..98f8f243 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -52,20 +52,20 @@ def mock_iriapi_client(mocker): client = mocker.MagicMock() submit_response = mocker.MagicMock() - submit_response.json.return_value = {"job_id": "99999"} + submit_response.json.return_value = {"id": "99999"} client.post.return_value = submit_response status_response = mocker.MagicMock() - status_response.json.return_value = {"state": "COMPLETED"} + status_response.json.return_value = {"status": {"state": "completed"}} client.get.return_value = status_response return client - # --------------------------------------------------------------------------- # _create_sfapi_client # --------------------------------------------------------------------------- + def test_create_sfapi_client_success(mocker): """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController @@ -181,20 +181,20 @@ def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config, mon assert result is True mock_iriapi_client.post.assert_called_once() - assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/perlmutter" - assert "script" in mock_iriapi_client.post.call_args.kwargs["json"] + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/compute" + assert "executable" in mock_iriapi_client.post.call_args.kwargs["json"] mock_iriapi_client.get.assert_called_once_with( - "/api/v1/compute/status/perlmutter/99999" + "/api/v1/compute/status/compute/99999" ) def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config, monkeypatch): - """IRIAPI reconstruct returns False when job state is FAILED.""" + """IRIAPI reconstruct returns False when job state is failed.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} controller = NERSCTomographyHPCController( client=mock_iriapi_client, @@ -287,17 +287,17 @@ def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_ assert result is True mock_iriapi_client.post.assert_called_once() mock_iriapi_client.get.assert_called_once_with( - "/api/v1/compute/status/perlmutter/99999" + "/api/v1/compute/status/compute/99999" ) def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): - """IRIAPI build_multi_resolution returns False when job state is FAILED.""" + """IRIAPI build_multi_resolution returns False when job state is failed.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod monkeypatch.setenv("NERSC_USERNAME", "alsdev") mocker.patch("orchestration.flows.bl832.nersc.time.sleep") - mock_iriapi_client.get.return_value.json.return_value = {"state": "FAILED"} + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} controller = NERSCTomographyHPCController( client=mock_iriapi_client,