diff --git a/orchestration/_tests/test_bl832/__init__.py b/orchestration/_tests/test_bl832/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py new file mode 100644 index 00000000..98f8f243 --- /dev/null +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -0,0 +1,310 @@ +# orchestration/_tests/test_bl832/test_nersc.py +import pytest +from uuid import uuid4 + +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + with prefect_test_harness(): + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) + yield + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_config(mocker): + config = mocker.MagicMock() + config.ghcr_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + return config + + +@pytest.fixture +def mock_sfapi_client(mocker): + """sfapi_client.Client mock with user, compute, submit_job, and job chained.""" + client = mocker.MagicMock() + + mock_user = mocker.MagicMock() + mock_user.name = "testuser" + client.user.return_value = mock_user + + mock_job = mocker.MagicMock() + mock_job.jobid = "12345" + mock_job.state = "COMPLETED" + client.compute.return_value.submit_job.return_value = mock_job + client.compute.return_value.job.return_value = mock_job + + return client + + +@pytest.fixture +def mock_iriapi_client(mocker): + """httpx.Client mock for IRI API responses.""" + client = mocker.MagicMock() + + submit_response = mocker.MagicMock() + submit_response.json.return_value = {"id": "99999"} + client.post.return_value = 
submit_response + + status_response = mocker.MagicMock() + status_response.json.return_value = {"status": {"state": "completed"}} + client.get.return_value = status_response + + return client + +# --------------------------------------------------------------------------- +# _create_sfapi_client +# --------------------------------------------------------------------------- + + +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) + + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() + + +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + 
mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) + + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() + + +# --------------------------------------------------------------------------- +# reconstruct — SFAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_sfapi_success(mocker, mock_sfapi_client, mock_config): + """SFAPI reconstruct submits a job and waits for completion.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.job.assert_called_once_with(jobid="12345") + mock_sfapi_client.compute.return_value.job.return_value.complete.assert_called_once() + + +def test_reconstruct_sfapi_submission_failure(mocker, mock_sfapi_client, mock_config): + """SFAPI reconstruct returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("SFAPI error") + + controller = NERSCTomographyHPCController( + 
client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is False + + +# --------------------------------------------------------------------------- +# reconstruct — IRIAPI +# --------------------------------------------------------------------------- + +def test_reconstruct_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI reconstruct POSTs a job and polls for COMPLETED state.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + assert mock_iriapi_client.post.call_args.args[0] == "/api/v1/compute/job/compute" + assert "executable" in mock_iriapi_client.post.call_args.kwargs["json"] + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/compute/99999" + ) + + +def test_reconstruct_iriapi_job_failed(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI reconstruct returns False when job state is failed.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} # was {"state": "FAILED"} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.reconstruct(file_path="folder/scan.h5") + + assert result is False + + +def 
test_reconstruct_iriapi_missing_username(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI reconstruct raises ValueError when NERSC_USERNAME is unset.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.delenv("NERSC_USERNAME", raising=False) + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + with pytest.raises(ValueError, match="NERSC_USERNAME"): + controller.reconstruct(file_path="folder/scan.h5") + + +# --------------------------------------------------------------------------- +# build_multi_resolution — SFAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_sfapi_success(mocker, mock_sfapi_client, mock_config): + """SFAPI build_multi_resolution submits and waits successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + assert mock_sfapi_client.compute.call_count == 2 + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + + +def test_build_multi_resolution_sfapi_failure(mocker, mock_sfapi_client, mock_config): + """SFAPI build_multi_resolution returns False when submission raises.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("error") + + controller = NERSCTomographyHPCController( + client=mock_sfapi_client, + config=mock_config, + 
login_method=NERSCLoginMethod.SFAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False + + +# --------------------------------------------------------------------------- +# build_multi_resolution — IRIAPI +# --------------------------------------------------------------------------- + +def test_build_multi_resolution_iriapi_success(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution POSTs and polls successfully.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is True + mock_iriapi_client.post.assert_called_once() + mock_iriapi_client.get.assert_called_once_with( + "/api/v1/compute/status/compute/99999" + ) + + +def test_build_multi_resolution_iriapi_failure(mocker, mock_iriapi_client, mock_config, monkeypatch): + """IRIAPI build_multi_resolution returns False when job state is failed.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + + monkeypatch.setenv("NERSC_USERNAME", "alsdev") + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_iriapi_client.get.return_value.json.return_value = {"status": {"state": "failed"}} + + controller = NERSCTomographyHPCController( + client=mock_iriapi_client, + config=mock_config, + login_method=NERSCLoginMethod.IRIAPI, + ) + + result = controller.build_multi_resolution(file_path="folder/scan.h5") + + assert result is False diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index 66203d19..e0d4a854 100644 --- a/orchestration/_tests/test_sfapi_flow.py 
+++ b/orchestration/_tests/test_sfapi_flow.py @@ -1,8 +1,5 @@ # orchestration/_tests/test_sfapi_flow.py - -from pathlib import Path import pytest -from unittest.mock import MagicMock, patch, mock_open from uuid import uuid4 from prefect.blocks.system import Secret @@ -11,276 +8,56 @@ @pytest.fixture(autouse=True, scope="session") def prefect_test_fixture(): - """ - A pytest fixture that automatically sets up and tears down the Prefect test harness - for the entire test session. It creates and saves test secrets and configurations - required for Globus integration. - - Yields: - None - """ with prefect_test_harness(): - globus_client_id = Secret(value=str(uuid4())) - globus_client_id.save(name="globus-client-id", overwrite=True) - globus_client_secret = Secret(value=str(uuid4())) - globus_client_secret.save(name="globus-client-secret", overwrite=True) - + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) yield -# ---------------------------- -# Tests for create_sfapi_client -# ---------------------------- - - -def test_create_sfapi_client_success(): - """ - Test successful creation of the SFAPI client. 
- """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Mock data for client_id and client_secret files - mock_client_id = 'value' - mock_client_secret = '{"key": "value"}' - - # Create separate mock_open instances for each file - mock_open_client_id = mock_open(read_data=mock_client_id) - mock_open_client_secret = mock_open(read_data=mock_client_secret) - - with patch("orchestration.flows.bl832.nersc.os.getenv") as mock_getenv, \ - patch("orchestration.flows.bl832.nersc.os.path.isfile") as mock_isfile, \ - patch("builtins.open", side_effect=[ - mock_open_client_id.return_value, - mock_open_client_secret.return_value - ]), \ - patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key") as mock_import_key, \ - patch("orchestration.flows.bl832.nersc.Client") as MockClient: - - # Mock environment variables - mock_getenv.side_effect = lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - - # Mock file existence - mock_isfile.return_value = True - - # Mock JsonWebKey.import_key to return a mock secret - mock_import_key.return_value = "mock_secret" - - # Create the client - client = NERSCTomographyHPCController.create_sfapi_client() - - # Assert that Client was instantiated with 'value' and 'mock_secret' - MockClient.assert_called_once_with("value", "mock_secret") - - # Assert that the returned client is the mocked client - assert client == MockClient.return_value, "Client should be the mocked sfapi_client.Client instance" - - -def test_create_sfapi_client_missing_paths(): - """ - Test creation of the SFAPI client with missing credential paths. 
- """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - with patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None): - with pytest.raises(ValueError, match="Missing NERSC credentials paths."): - NERSCTomographyHPCController.create_sfapi_client() - - -def test_create_sfapi_client_missing_files(): - """ - Test creation of the SFAPI client with missing credential files. - """ - with ( - # Mock environment variables - patch( - "orchestration.flows.bl832.nersc.os.getenv", - side_effect=lambda x: { - "PATH_NERSC_CLIENT_ID": "/path/to/client_id", - "PATH_NERSC_PRI_KEY": "/path/to/client_secret" - }.get(x, None) - ), - - # Mock file existence to simulate missing files - patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - ): - # Import the module after applying patches to ensure mocks are in place - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - # Expect a FileNotFoundError due to missing credential files - with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): - NERSCTomographyHPCController.create_sfapi_client() - -# ---------------------------- -# Fixture for Mocking SFAPI Client -# ---------------------------- - - -@pytest.fixture -def mock_sfapi_client(): - """ - Mock the sfapi_client.Client class with necessary methods. 
- """ - with patch("orchestration.flows.bl832.nersc.Client") as MockClient: - mock_client_instance = MockClient.return_value - - # Mock the user method - mock_user = MagicMock() - mock_user.name = "testuser" - mock_client_instance.user.return_value = mock_user - - # Mock the compute method to return a mocked compute object - mock_compute = MagicMock() - mock_job = MagicMock() - mock_job.jobid = "12345" - mock_job.state = "COMPLETED" - mock_compute.submit_job.return_value = mock_job - mock_client_instance.compute.return_value = mock_compute - - yield mock_client_instance - - -# ---------------------------- -# Fixture for Mocking Config832 -# ---------------------------- - -@pytest.fixture -def mock_config832(): - """ - Mock the Config832 class to provide necessary configurations. - """ - with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: - mock_config = MockConfig.return_value - mock_config.harbor_images832 = { - "recon_image": "mock_recon_image", - "multires_image": "mock_multires_image", - } - mock_config.apps = {"als_transfer": "some_config"} - yield mock_config - - -# ---------------------------- -# Tests for NERSCTomographyHPCController -# ---------------------------- - -def test_reconstruct_success(mock_sfapi_client, mock_config832): - """ - Test successful reconstruction job submission. 
- """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - - # Assert that the method returns True - assert result is True, "reconstruct should return True on successful job completion." - - -def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): - """ - Test reconstruction job submission failure. - """ - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.reconstruct(file_path=file_path) - - # Assert that the method returns False - assert result is False, "reconstruct should return False on submission failure." - - -def test_build_multi_resolution_success(mock_sfapi_client, mock_config832): - """ - Test successful multi-resolution job submission. 
- """ +def test_create_sfapi_client_success(mocker): + """Valid credentials produce a Client instance.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() - - # Verify that complete was called on the job - mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - # Assert that the method returns True - assert result is True, "build_multi_resolution should return True on successful job completion." - - -def test_build_multi_resolution_submission_failure(mock_sfapi_client, mock_config832): - """ - Test multi-resolution job submission failure. 
- """ + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch( + "builtins.open", + side_effect=[ + mocker.mock_open(read_data="my-client-id")(), + mocker.mock_open(read_data='{"kty": "RSA", "n": "x", "e": "y"}')(), + ] + ) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController._create_sfapi_client() + + mock_client_cls.assert_called_once_with("my-client-id", "mock_secret") + assert client is mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + """Unset env vars raise ValueError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - file_path = "path/to/file.h5" - - # Simulate submission failure - mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - result = controller.build_multi_resolution(file_path=file_path) + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) - # Assert that the method returns False - assert result is False, "build_multi_resolution should return False on submission failure." + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController._create_sfapi_client() -def test_job_submission(mock_sfapi_client): - """ - Test job submission and status updates. 
- """ +def test_create_sfapi_client_missing_files(mocker): + """Env vars set but files absent raise FileNotFoundError.""" from orchestration.flows.bl832.nersc import NERSCTomographyHPCController - from sfapi_client.compute import Machine - - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=MagicMock()) - file_path = "path/to/file.h5" - - # Mock Path to extract file and folder names - with patch.object(Path, 'parent', new_callable=MagicMock) as mock_parent, \ - patch.object(Path, 'stem', new_callable=MagicMock) as mock_stem: - mock_parent.name = "to" - mock_stem.return_value = "file" - - with patch("orchestration.flows.bl832.nersc.time.sleep", return_value=None): - controller.reconstruct(file_path=file_path) - - # Verify that compute was called with Machine.perlmutter - mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) - # Verify that submit_job was called once - mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) - # Verify the returned job has the expected attributes - submitted_job = mock_sfapi_client.compute.return_value.submit_job.return_value - assert submitted_job.jobid == "12345", "Job ID should match the mock job ID." - assert submitted_job.state == "COMPLETED", "Job state should be COMPLETED." 
+ with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController._create_sfapi_client() diff --git a/orchestration/flows/bl832/job_controller.py b/orchestration/flows/bl832/job_controller.py index b2ff064b..1a23d02a 100644 --- a/orchestration/flows/bl832/job_controller.py +++ b/orchestration/flows/bl832/job_controller.py @@ -10,6 +10,19 @@ load_dotenv() +class NERSCLoginMethod(Enum): + """Selects which NERSC API login method to use when creating a NERSC client. + + Each method corresponds to a different set of credentials and API base URL. + """ + + SFAPI = "sfapi" + """Standard Superfacility API via Iris-registered OAuth2 credentials.""" + + IRIAPI = "iriapi" + """Integrated Research Infrastructure API via IRI-registered OAuth2 credentials.""" + + class TomographyHPCController(ABC): """ Abstract class for tomography HPC controllers. @@ -65,7 +78,8 @@ class HPC(Enum): def get_controller( hpc_type: HPC, - config: Config832 + config: Config832, + login_method: "NERSCLoginMethod | None" = None, ) -> TomographyHPCController: """ Factory function that returns an HPC controller instance for the given HPC environment. 
@@ -86,10 +100,14 @@ def get_controller( config=config ) elif hpc_type == HPC.NERSC: - from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController, NERSCLoginMethod + resolved_login_method = login_method if isinstance(login_method, NERSCLoginMethod) else NERSCLoginMethod.SFAPI return NERSCTomographyHPCController( - client=NERSCTomographyHPCController.create_sfapi_client(), - config=config + client=NERSCTomographyHPCController.create_nersc_client( + login_method=resolved_login_method + ), + config=config, + login_method=resolved_login_method, ) elif hpc_type == HPC.OLCF: # TODO: Implement OLCF controller diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 727cbbaf..7a7c3b8f 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,10 +1,10 @@ import datetime from dotenv import load_dotenv +import httpx import json import logging import os from pathlib import Path -import re import time from authlib.jose import JsonWebKey @@ -15,18 +15,34 @@ from typing import Optional from orchestration.flows.bl832.config import Config832 -from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController -from orchestration.transfer_controller import get_transfer_controller, CopyMethod + +from orchestration.flows.bl832.job_controller import get_controller, HPC, NERSCLoginMethod, TomographyHPCController from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block ) +from orchestration.globus.get_globus_token import ( + get_iri_access_token, + DEFAULT_TOKEN_FILE, +) from orchestration.prefect import schedule_prefect_flow +from orchestration.transfer_controller import get_transfer_controller, CopyMethod logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) load_dotenv() +# Applies only to 
NERSCLoginMethod.IRIAPI +_IRIAPI_GLOBUS_CLIENT_ID_ENV: str = "GLOBUS_CLIENT_ID" +_IRI_COMPUTE_RESOURCE: str = "compute" +_IRIAPI_TOKEN_FILE_ENV: str = "PATH_GLOBUS_TOKEN_FILE" + +_API_BASE_URLS: dict[NERSCLoginMethod, str] = { + NERSCLoginMethod.SFAPI: "https://api.nersc.gov/api/v1.2", + NERSCLoginMethod.IRIAPI: "https://api.iri.nersc.gov", +} + + class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. @@ -36,14 +52,101 @@ class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin) def __init__( self, - client: Client, - config: Config832 + config: Config832, + client: Client | httpx.Client | None = None, + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, ) -> None: TomographyHPCController.__init__(self, config) self.client = client + self.login_method = login_method + + @staticmethod + def create_nersc_client( + login_method: NERSCLoginMethod = NERSCLoginMethod.SFAPI, + ) -> Client: + """Create and return a NERSC client for the requested login method. + + Two fundamentally different auth strategies are supported: + + - :attr:`NERSCLoginMethod.SFAPI`: uses an Iris-registered OAuth2 + client ID + private key (NERSC OIDC flow). Set ``PATH_NERSC_CLIENT_ID`` + and ``PATH_NERSC_PRI_KEY`` to the paths of those files. + + - :attr:`NERSCLoginMethod.IRIAPI`: uses a Globus bearer token written + by ``globus_token.py``. Set ``PATH_GLOBUS_TOKEN_FILE`` to the token + file path, or rely on the default (``~/.globus/auth_tokens.json``). + + Args: + login_method: Which NERSC API to authenticate against. + Defaults to :attr:`NERSCLoginMethod.SFAPI`. + + Returns: + An authenticated :class:`sfapi_client.Client` instance. + + Raises: + ValueError: If SFAPI credential environment variables are unset. + FileNotFoundError: If credential or token files are absent. + RuntimeError: If the Globus token is expired. + Exception: If the underlying client construction fails. 
+        """
+        logger.info(f"Creating NERSC client using login method: {login_method.value}")
+        api_url = _API_BASE_URLS[login_method]
+        logger.info(f"Targeting API base URL: {api_url}")
+
+        if login_method is NERSCLoginMethod.SFAPI:
+            client = NERSCTomographyHPCController._create_sfapi_client()
+
+        elif login_method is NERSCLoginMethod.IRIAPI:
+            client = NERSCTomographyHPCController._create_iriapi_client()
+
+        else:
+            raise ValueError(f"Unhandled NERSCLoginMethod: {login_method}")
+
+        logger.info(
+            f"NERSC client created successfully "
+            f"(method={login_method.value}, api_url={api_url})."
+        )
+        return client
 
     @staticmethod
-    def create_sfapi_client() -> Client:
+    def _create_iriapi_client() -> httpx.Client:
+        """Create a NERSC client for the IRI API using a Globus bearer token.
+
+        Requires ``GLOBUS_CLIENT_ID`` and ``GLOBUS_CLIENT_SECRET`` in the
+        environment. Reuses a cached token if valid; otherwise mints a new one
+        via the client credentials grant. No browser or user interaction.
+
+        Returns:
+            An authenticated :class:`httpx.Client` targeting the IRI API.
+
+        Raises:
+            ValueError: If ``GLOBUS_CLIENT_ID`` or ``GLOBUS_CLIENT_SECRET`` are unset.
+            RuntimeError: If the acquired token is missing required scopes.
+        """
+        client_id = os.getenv(_IRIAPI_GLOBUS_CLIENT_ID_ENV)
+
+        if not client_id:
+            raise ValueError(
+                f"Globus client ID is unset. Set {_IRIAPI_GLOBUS_CLIENT_ID_ENV}."
+ ) + + token_file_env = os.getenv(_IRIAPI_TOKEN_FILE_ENV) + token_file = Path(token_file_env) if token_file_env else DEFAULT_TOKEN_FILE + + access_token = get_iri_access_token( + token_file=token_file, + force_login=False, + prompt_login=False + ) + + return httpx.Client( + base_url=_API_BASE_URLS[NERSCLoginMethod.IRIAPI], + headers={"Authorization": f"Bearer {access_token}"}, + timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=10.0), + ) + + @staticmethod + def _create_sfapi_client() -> Client: """Create and return an NERSC client instance""" # When generating the SFAPI Key in Iris, make sure to select "asldev" as the user! @@ -74,6 +177,122 @@ def create_sfapi_client() -> Client: logger.error(f"Failed to create NERSC client: {e}") raise e + def _get_nersc_username(self) -> str: + """Get the NERSC username for constructing pscratch paths. + + Uses the sfapi_client user endpoint for SFAPI, or reads + ``NERSC_USERNAME`` from the environment for IRIAPI. + + Returns: + NERSC username string. + + Raises: + ValueError: If IRIAPI is selected and NERSC_USERNAME is unset. + """ + if self.login_method is NERSCLoginMethod.SFAPI: + return self.client.user().name + else: + username = os.getenv("NERSC_USERNAME") + if not username: + raise ValueError( + "NERSC_USERNAME must be set in the environment when using IRIAPI." + ) + return username + + def _submit_job(self, job_script: str) -> str: + """Submit a Slurm job script and return the job ID. + + Dispatches to the appropriate submission mechanism based on + ``self.login_method``. + + Args: + job_script: The full Slurm batch script to submit. + + Returns: + The submitted job ID as a string. + + Raises: + RuntimeError: If job submission fails. 
+ """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + return str(job.jobid) + + elif self.login_method is NERSCLoginMethod.IRIAPI: + username = self._get_nersc_username() + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" + + script_body = "\n".join( + line for line in job_script.splitlines() + if not line.startswith("#SBATCH") and not line.startswith("#!/") + ).strip() + + job_spec = { + "executable": "/bin/bash", + "arguments": ["-c", script_body], + "stdout_path": f"{pscratch_path}/tomo_recon_logs/iri_job.out", + "stderr_path": f"{pscratch_path}/tomo_recon_logs/iri_job.err", + "resources": { + "node_count": 1, + "processes_per_node": 1, + "cpu_cores_per_process": 64, + "exclusive_node_use": True, + }, + "attributes": { + "duration": 1800, + "queue_name": "realtime", + "account": "als", + "custom_attributes": {"constraint": "cpu"}, + }, + } + + response = self.client.post( + f"/api/v1/compute/job/{_IRI_COMPUTE_RESOURCE}", + json=job_spec, + ) + response.raise_for_status() + return str(response.json()["id"]) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + + def _wait_for_job(self, job_id: str) -> bool: + """Block until a submitted job completes. + + Dispatches to the appropriate polling mechanism based on + ``self.login_method``. + + Args: + job_id: The job ID returned by `_submit_job`. + + Returns: + True if the job completed successfully, False otherwise. 
+ """ + if self.login_method is NERSCLoginMethod.SFAPI: + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=job_id) + job.complete() + return True + + elif self.login_method is NERSCLoginMethod.IRIAPI: + while True: + response = self.client.get( + f"/api/v1/compute/status/{_IRI_COMPUTE_RESOURCE}/{job_id}" # ← was "perlmutter" + ) + response.raise_for_status() + state = response.json().get("status", {}).get("state") + logger.info(f"Job {job_id} state: {state}") + if state == "completed": + return True + if state in ("failed", "canceled", "timeout"): + logger.error(f"Job {job_id} ended with state: {state}") + return False + time.sleep(60) + + else: + raise ValueError(f"Unhandled NERSCLoginMethod: {self.login_method}") + def reconstruct( self, file_path: str = "", @@ -83,7 +302,7 @@ def reconstruct( """ logger.info("Starting NERSC reconstruction process.") - user = self.client.user() + username = self._get_nersc_username() raw_path = self.config.nersc832_alsdev_raw.root_path logger.info(f"{raw_path=}") @@ -97,7 +316,7 @@ def reconstruct( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -159,42 +378,16 @@ def reconstruct( """ try: - logger.info("Submitting reconstruction job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting reconstruction job to Perlmutter.") + job_id = self._submit_job(job_script) + logger.info(f"Submitted job ID: {job_id}") time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait 
until the job completes - logger.info("Reconstruction job completed successfully.") - return True - + success = self._wait_for_job(job_id) + logger.info(f"Reconstruction job {'completed' if success else 'failed'}.") + return success except Exception as e: - logger.info(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - # Unknown error: cannot recover - return False + logger.error(f"Error during reconstruction job submission or completion: {e}") + return False def build_multi_resolution( self, @@ -204,7 +397,7 @@ def build_multi_resolution( logger.info("Starting NERSC multiresolution process.") - user = self.client.user() + username = self._get_nersc_username() multires_image = self.config.ghcr_images832["multires_image"] logger.info(f"{multires_image=}") @@ -215,7 +408,7 @@ def build_multi_resolution( scratch_path = self.config.nersc832_alsdev_scratch.root_path logger.info(f"{scratch_path=}") - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + pscratch_path = f"/pscratch/sd/{username[0]}/{username}" logger.info(f"{pscratch_path=}") path = Path(file_path) @@ -255,42 +448,16 @@ def build_multi_resolution( date """ try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - + logger.info("Submitting Tiff to Zarr job to 
@flow(name="nersc_recon_test_iriapi_flow", flow_run_name="nersc_recon-{file_path}")
def nersc_recon_test_iriapi_flow(
    file_path: str,
    config: Optional[Config832] = None,
) -> bool:
    """Perform tomography reconstruction on NERSC via the IRI API.

    Test variant of ``nersc_recon_flow`` that authenticates with the IRI API
    instead of SFAPI. Globus transfers, pruning, and SciCat ingestion are
    intentionally omitted from this flow.

    Args:
        file_path: Path of the .h5 file to reconstruct, relative to the
            beamline data root.
        config: Optional pre-built Config832; a fresh one is created when None.

    Returns:
        True if the NERSC reconstruction step succeeded, False otherwise.
        Multi-resolution success is logged but deliberately not folded into
        the return value.
    """
    logger = get_run_logger()

    if config is None:
        logger.info("Initializing Config")
        config = Config832()

    logger.info(f"Starting NERSC reconstruction flow for {file_path=}")
    controller = get_controller(
        hpc_type=HPC.NERSC,
        config=config,
        login_method=NERSCLoginMethod.IRIAPI,
    )
    logger.info("NERSC reconstruction controller initialized")

    nersc_reconstruction_success = controller.reconstruct(
        file_path=file_path,
    )
    logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}")

    nersc_multi_res_success = controller.build_multi_resolution(
        file_path=file_path,
    )
    logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}")

    path = Path(file_path)
    folder_name = path.parent.name
    file_name = path.stem

    # Derived output locations, logged for operator reference only — the
    # transfer/pruning steps that would consume them are omitted here.
    tiff_file_path = f"{folder_name}/rec{file_name}"
    zarr_file_path = f"{folder_name}/rec{file_name}.zarr"
    logger.info(f"{tiff_file_path=}")
    logger.info(f"{zarr_file_path=}")

    # TODO: Ingest into SciCat
    return bool(nersc_reconstruction_success)
+ ) + ) + parser.add_argument( + "--token-file", + type=Path, + default=default_token_file, + help=f"Path for saved token JSON (default: {default_token_file})", + ) + parser.add_argument( + "--print-token", + action="store_true", + help="Print the access token to stdout (off by default).", + ) + parser.add_argument( + "--force-login", + action="store_true", + help="Skip refresh and force interactive browser login.", + ) + parser.add_argument( + "--refresh-only", + action="store_true", + help="Refresh saved tokens only; do not fall back to interactive login.", + ) + parser.add_argument( + "--prompt-login", + action="store_true", + help="Add prompt=login to the Globus authorize URL to force re-authentication.", + ) + parser.add_argument( + "--validate-iri", + action="store_true", + help="Validate the IRI token by calling the IRI account/projects endpoint.", + ) + parser.add_argument( + "--iri-validate-url", + default=DEFAULT_IRI_VALIDATE_URL, + help=( + "IRI endpoint used by --validate-iri " + f"(default: {DEFAULT_IRI_VALIDATE_URL})" + ), + ) + return parser.parse_args() + + +def parse_scope_string(scope_string: str) -> set[str]: + return set(scope_string.split()) if scope_string else set() + + +def ensure_private_parent_dir(path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + os.chmod(path.parent, 0o700) + + +def load_tokens(token_file: Path) -> dict | None: + if not token_file.exists(): + return None + with token_file.open("r", encoding="utf-8") as f: + return json.load(f) + + +def save_tokens(token_file: Path, tokens: dict) -> None: + ensure_private_parent_dir(token_file) + tmp = token_file.with_suffix(".tmp") + with os.fdopen( + os.open(tmp, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600), + "w", + encoding="utf-8", + ) as f: + json.dump(tokens, f, indent=2) + os.replace(tmp, token_file) + os.chmod(token_file, stat.S_IRUSR | stat.S_IWUSR) + + +def get_refresh_token(stored_tokens: dict) -> str | None: + if "refresh_token" in stored_tokens: + 
return stored_tokens.get("refresh_token") + + auth_tokens = stored_tokens.get(RESOURCE_SERVER) + if isinstance(auth_tokens, dict): + return auth_tokens.get("refresh_token") + + return None + + +def get_iri_token(token_response_data: dict) -> dict: + for token_data in token_response_data.get("other_tokens", []): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + return token_data + raise RuntimeError(f"Missing token for required IRI scope: {IRI_SCOPE}") + + +def get_iri_refresh_token(stored_tokens: dict) -> str | None: + try: + return get_iri_token(stored_tokens).get("refresh_token") + except RuntimeError: + return None + + +def replace_iri_token(token_response_data: dict, iri_token_data: dict) -> dict: + merged = dict(token_response_data) + other_tokens = list(merged.get("other_tokens", [])) + for index, token_data in enumerate(other_tokens): + if IRI_SCOPE in parse_scope_string(token_data.get("scope", "")): + other_tokens[index] = iri_token_data + break + else: + other_tokens.append(iri_token_data) + merged["other_tokens"] = other_tokens + return merged + + +def validate_auth_data(auth_data: dict) -> dict: + if auth_data.get("resource_server") != RESOURCE_SERVER: + raise RuntimeError( + f"Missing token for required resource server: {RESOURCE_SERVER}" + ) + + granted = parse_scope_string(auth_data.get("scope", "")) + missing = REQUIRED_SCOPES - granted + if missing: + raise RuntimeError(f"Missing required scopes: {sorted(missing)}") + + return get_iri_token(auth_data) + + +def validate_iri_token(iri_token_data: dict, validate_url: str) -> dict | list: + request = urllib.request.Request( + validate_url, + headers={ + "accept": "application/json", + "Authorization": f"Bearer {iri_token_data['access_token']}", + }, + method="GET", + ) + try: + with urllib.request.urlopen(request) as response: + body = response.read().decode("utf-8") + data = json.loads(body) if body else {} + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8") + 
details = body.strip() or exc.reason + raise RuntimeError( + f"IRI validation failed with HTTP {exc.code} from {validate_url}: {details}" + ) from exc + except urllib.error.URLError as exc: + raise RuntimeError( + f"IRI validation request failed for {validate_url}: {exc.reason}" + ) from exc + except json.JSONDecodeError as exc: + raise RuntimeError( + f"IRI validation returned non-JSON data from {validate_url}" + ) from exc + + if isinstance(data, dict): + session_info = data.get("session_info") + if isinstance(session_info, dict): + authentications = session_info.get("authentications") + if isinstance(authentications, dict) and not authentications: + raise RuntimeError( + "IRI validation succeeded but session_info.authentications is empty. " + "Re-run with --force-login --prompt-login and use a Chrome incognito window." + ) + + return data + + +def interactive_login( + client: globus_sdk.NativeAppAuthClient, + *, + prompt_login: bool = False, +) -> dict: + client.oauth2_start_flow( + requested_scopes=" ".join(sorted(REQUESTED_SCOPES)), + refresh_tokens=True, + ) + print("Open this URL, login, and consent:") + prompt = "login" if prompt_login else globus_sdk.MISSING + print(client.oauth2_get_authorize_url(prompt=prompt)) + code = input("\nEnter authorization code: ").strip() + if not code: + raise RuntimeError( + "No authorization code entered. Re-run the script and paste the code " + "shown by Globus after login." + ) + try: + token_response = client.oauth2_exchange_code_for_tokens(code) + except GlobusAPIError as exc: + if exc.http_status == 400: + raise RuntimeError( + "Authorization code exchange failed. The code was empty, invalid, " + "expired, or already used. Re-run the script and complete the " + "Globus login flow again." + ) from exc + raise RuntimeError( + f"Authorization code exchange failed with HTTP {exc.http_status}. " + "Re-run the script and try again." 
def get_iri_access_token(
    token_file: Path = DEFAULT_TOKEN_FILE,
    force_login: bool = False,
    prompt_login: bool = False,
) -> str:
    """Return a valid IRI access token, refreshing or logging in as needed.

    Resolution order: (1) unless ``force_login`` is set, try to refresh
    tokens previously saved at ``token_file``; (2) fall back to interactive
    Globus browser login; (3) if refreshed tokens turn out to lack the IRI
    scope, retry once with an interactive login. Whatever tokens are
    obtained are written back to ``token_file``.

    Args:
        token_file: Path used to load and save token data
            (default: ~/.globus/auth_tokens.json).
        force_login: If True, skip the refresh attempt and require
            interactive login.
        prompt_login: If True, add ``prompt=login`` to the authorize URL so
            Globus forces re-authentication even with an active browser
            session.

    Returns:
        A valid IRI access token string with the required scopes.

    Raises:
        RuntimeError: If token refresh fails and interactive login is not
            possible or fails, or if the resulting tokens do not include a
            valid IRI access token.
    """
    auth_client = globus_sdk.NativeAppAuthClient(CLIENT_ID)

    token_data = None
    came_from_refresh = False
    if not force_login:
        cached = load_tokens(token_file)
        if cached:
            token_data, came_from_refresh = refresh_stored_tokens(auth_client, cached)

    if token_data is None:
        token_data = interactive_login(auth_client, prompt_login=prompt_login)

    try:
        iri_token = validate_auth_data(token_data)
    except RuntimeError as exc:
        # Refreshed tokens can be stale in scope; allow one interactive retry.
        if not (came_from_refresh and "Missing token for required IRI scope" in str(exc)):
            raise
        token_data = interactive_login(auth_client, prompt_login=prompt_login)
        iri_token = validate_auth_data(token_data)

    save_tokens(token_file, token_data)
    return iri_token["access_token"]
+ ) + auth_data = interactive_login(client, prompt_login=args.prompt_login) + iri_token_data = validate_auth_data(auth_data) + else: + raise + + save_tokens(args.token_file, auth_data) + + if args.validate_iri: + validation_data = validate_iri_token(iri_token_data, args.iri_validate_url) + print(f"IRI validation succeeded against {args.iri_validate_url}") + if isinstance(validation_data, dict): + session_info = validation_data.get("session_info") + if isinstance(session_info, dict): + session_id = session_info.get("session_id") + if session_id: + print(f"IRI session_id: {session_id}") + elif isinstance(validation_data, list): + print(f"IRI validation response items: {len(validation_data)}") + + expires_at = iri_token_data.get("expires_at_seconds") + if expires_at: + ttl = int(expires_at - time.time()) + print(f"\nIRI access token valid for ~{max(ttl, 0)} seconds.") + + print(f"Saved token data to {args.token_file}") + print(f"Granted Globus Auth scopes: {auth_data.get('scope', '')}") + print(f"IRI token resource server: {iri_token_data.get('resource_server')}") + print(f"IRI token scopes: {iri_token_data.get('scope', '')}") + + if args.print_token: + print("\nIRI access token:") + print(iri_token_data["access_token"]) + else: + print( + "IRI access token not printed " + "(use --print-token to display it for the NERSC IRI API)." + ) + + +if __name__ == "__main__": + main()