diff --git a/config.yml b/config.yml index 4e3c6ed2..69a2c077 100644 --- a/config.yml +++ b/config.yml @@ -160,7 +160,22 @@ prefect: scicat: jobs_api_url: https://dataportal.als.lbl.gov/api/ingest/jobs +mlflow: + local: + tracking_uri: http://localhost:5001 + registry_uri: http://localhost:5001 + dev: + tracking_uri: http://mlflow-dev.computing.als.lbl.gov + registry_uri: http://mlflow-dev.computing.als.lbl.gov + prod: + tracking_uri: https://mlflow.computing.als.lbl.gov + registry_uri: https://mlflow.computing.als.lbl.gov + staging: + tracking_uri: https://mlflow-staging.computing.als.lbl.gov + registry_uri: https://mlflow-staging.computing.als.lbl.gov + hpc_submission_settings832: + # ── RECON + MULTIRES SETTINGS ─────────────────────────────────────────────── nersc_reconstruction: # ── SLURM resource allocation ───────────────────────────────────────────── qos: realtime @@ -176,6 +191,8 @@ hpc_submission_settings832: reservation: "" cpus-per-task: 128 walltime: "0:15:00" + + # ── PETIOLE SEGMENTATION SETTINGS ─────────────────────────────────────────── nersc_segmentation_sam3: # ── SLURM resource allocation ───────────────────────────────────────────── qos: regular @@ -244,3 +261,26 @@ hpc_submission_settings832: cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo + + # ── MOON SEGMENTATION SETTINGS ─────────────────────────────────────────── + nersc_segmentation_dinov3_moon: + # ── SLURM resource allocation ───────────────────────────────────────────── + qos: regular + account: als + constraint: gpu + reservation: "" + num_nodes: 4 + ntasks-per-node: 1 + nproc_per_node: 4 + gpus-per-node: 4 + cpus-per-task: 128 + walltime: "00:59:00" + # ── Inference parameters ────────────────────────────────────────────────── + script_name: "src.inference_dino_v2" + project: "moon" 
+ batch_size: 4 + # ── Paths ───────────────────────────────────────────────────────────────── + cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 + conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo + seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo/ + dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino_moon/best.ckpt diff --git a/orchestration/_tests/test_bl832/test_mlflow.py b/orchestration/_tests/test_bl832/test_mlflow.py new file mode 100644 index 00000000..4ca12720 --- /dev/null +++ b/orchestration/_tests/test_bl832/test_mlflow.py @@ -0,0 +1,500 @@ +# orchestration/_tests/test_bl832/test_mlflow.py +# +# Tests for the MLflow integration in the NERSC segmentation workflow. +# Covers: +# - get_checkpoint_info (orchestration/mlflow.py) +# - _load_job_options MLflow layer (orchestration/flows/bl832/nersc.py) +# - segmentation_sam3 checkpoint resolution via MLflow + +import json +import pytest +from uuid import uuid4 + +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + + +# ────────────────────────────────────────────────────────────────────────────── +# Session fixture +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + with prefect_test_harness(): + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) + yield + + +# ────────────────────────────────────────────────────────────────────────────── +# Shared fixtures +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture +def mock_beamline_config(mocker): + """Minimal BeamlineConfig mock with mlflow tracking_uri.""" + config = mocker.MagicMock() + config.mlflow = 
{"tracking_uri": "http://mock-mlflow:5000"} + return config + + +@pytest.fixture +def mock_config832(mocker): + """ + Mock Config832 with fully-populated nersc_segment_sam3_settings. + Matches the schema expected by _load_job_options / segmentation_sam3. + """ + mock_config = mocker.MagicMock() + mock_config.mlflow = {"tracking_uri": "http://mock-mlflow:5000"} + mock_config.nersc_segment_sam3_settings = { + "qos": "regular", + "account": "als", + "constraint": "gpu", + "reservation": "", + "num_nodes": 4, + "ntasks-per-node": 1, + "gpus-per-node": 4, + "cpus-per-task": 32, + "walltime": "00:59:00", + "batch_size": 1, + "patch_size": 400, + "confidence": [0.5], + "overlap": 0.25, + "prompts": ["cell wall", "lumen"], + "cfs_path": "/mock/cfs", + "conda_env_path": "/mock/conda/sam3", + "seg_scripts_dir": "/mock/seg_scripts/sam3", + "checkpoints_dir": "/mock/checkpoints", + "bpe_path": "/mock/bpe.model", + "original_checkpoint_path": "/mock/original.pt", + "finetuned_checkpoint_path": "/mock/checkpoints/finetuned_v6.pt", + } + mocker.patch("orchestration.flows.bl832.nersc.Config832", return_value=mock_config) + return mock_config + + +@pytest.fixture +def mock_sfapi_client(mocker): + """Mock sfapi_client.Client with a completed SAM3 segmentation job.""" + mock_client = mocker.MagicMock() + mock_user = mocker.MagicMock() + mock_user.name = "testuser" + mock_client.user.return_value = mock_user + + mock_job = mocker.MagicMock() + mock_job.jobid = "55555" + mock_job.state = "COMPLETED" + mock_compute = mocker.MagicMock() + mock_compute.submit_job.return_value = mock_job + mock_client.compute.return_value = mock_compute + + mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=mock_client) + return mock_client + + +def _make_model_version(mocker, *, version="1", tags=None): + """Helper: build a mock MlflowClient model version object.""" + mv = mocker.MagicMock() + mv.version = version + mv.tags = tags or {} + return mv + + +# 
────────────────────────────────────────────────────────────────────────────── +# get_checkpoint_info +# ────────────────────────────────────────────────────────────────────────────── + +class TestGetCheckpointInfo: + + def test_returns_checkpoint_info_when_mlflow_reachable(self, mocker, mock_beamline_config): + """Happy path: reachable server + valid production alias → ModelCheckpointInfo.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + + mv = _make_model_version(mocker, version="3", tags={ + "nersc_path": "/cfs/checkpoints/sam3_v3.pt", + "alcf_path": "/eagle/checkpoints/sam3_v3.pt", + "batch_size": "2", + "prompts": json.dumps(["cell wall", "lumen"]), + }) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config, alias="production") + + assert info is not None + assert info.model_name == "sam3-petiole" + assert info.version == "3" + assert info.alias == "production" + assert info.nersc_path == "/cfs/checkpoints/sam3_v3.pt" + assert info.alcf_path == "/eagle/checkpoints/sam3_v3.pt" + + def test_deserializes_json_inference_params(self, mocker, mock_beamline_config): + """JSON-encoded tag values (lists, dicts) are decoded into Python objects.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mv = _make_model_version(mocker, tags={ + "nersc_path": "/cfs/sam3.pt", + "prompts": json.dumps(["cell wall", "lumen"]), + "confidence": json.dumps([0.6, 0.7]), + "batch_size": "4", + }) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", 
mock_beamline_config) + + assert info.inference_params["prompts"] == ["cell wall", "lumen"] + assert info.inference_params["confidence"] == [0.6, 0.7] + assert info.inference_params["batch_size"] == 4 # "4" is valid JSON → int + + def test_returns_none_when_mlflow_unreachable(self, mocker, mock_beamline_config): + """Unreachable tracking server → None (caller falls back to config defaults).""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=False) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info is None + + def test_returns_none_when_alias_not_found(self, mocker, mock_beamline_config): + """Missing production alias → MlflowException → None.""" + from orchestration.mlflow import get_checkpoint_info + import mlflow.exceptions + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.side_effect = ( + mlflow.exceptions.MlflowException("Alias not found") + ) + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info is None + + def test_returns_none_when_nersc_path_tag_missing(self, mocker, mock_beamline_config): + """A model version without 'nersc_path' tag → None.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mv = _make_model_version(mocker, tags={"alcf_path": "/eagle/sam3.pt"}) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert info is None + + def test_nersc_and_alcf_paths_excluded_from_inference_params(self, mocker, mock_beamline_config): + 
"""nersc_path and alcf_path must NOT appear in inference_params.""" + from orchestration.mlflow import get_checkpoint_info + + mocker.patch("orchestration.mlflow._is_mlflow_reachable", return_value=True) + mv = _make_model_version(mocker, tags={ + "nersc_path": "/cfs/sam3.pt", + "alcf_path": "/eagle/sam3.pt", + "batch_size": "2", + }) + mock_client = mocker.MagicMock() + mock_client.get_model_version_by_alias.return_value = mv + mocker.patch("orchestration.mlflow.get_mlflow_client", return_value=mock_client) + + info = get_checkpoint_info("sam3-petiole", mock_beamline_config) + + assert "nersc_path" not in info.inference_params + assert "alcf_path" not in info.inference_params + assert "batch_size" in info.inference_params + + +# ────────────────────────────────────────────────────────────────────────────── +# _load_job_options — MLflow layer +# ────────────────────────────────────────────────────────────────────────────── + +class TestLoadJobOptionsMLflowLayer: + """ + _load_job_options has three layers: config → MLflow → Prefect Variable. + These tests isolate the MLflow layer by stubbing get_checkpoint_info and + keeping the Prefect Variable at defaults. 
+ """ + + def _patch_variable_defaults(self, mocker): + mocker.patch( + "orchestration.flows.bl832.nersc.Variable.get", + return_value={"defaults": True}, + ) + + def test_mlflow_nersc_path_mapped_to_checkpoint_key(self, mocker, mock_config832): + """When MLflow returns a checkpoint, nersc_path is written to mlflow_checkpoint_key.""" + from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.mlflow import ModelCheckpointInfo + + self._patch_variable_defaults(mocker) + + checkpoint_info = ModelCheckpointInfo( + model_name="sam3-petiole", + version="5", + alias="production", + nersc_path="/cfs/checkpoints/sam3_v5.pt", + alcf_path="", + inference_params={}, + ) + mocker.patch( + "orchestration.flows.bl832.nersc.get_checkpoint_info", + return_value=checkpoint_info, + ) + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + opts = _load_job_options( + "nersc-segmentation-options", + base_settings, + config=mock_config832, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + assert opts["finetuned_checkpoint_path"] == "/cfs/checkpoints/sam3_v5.pt" + + def test_mlflow_inference_params_overlay_config_defaults(self, mocker, mock_config832): + """inference_params from MLflow overwrite matching config keys.""" + from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.mlflow import ModelCheckpointInfo + + self._patch_variable_defaults(mocker) + + checkpoint_info = ModelCheckpointInfo( + model_name="sam3-petiole", + version="2", + alias="production", + nersc_path="/cfs/sam3.pt", + alcf_path="", + inference_params={ + "batch_size": 8, + "confidence": [0.6, 0.7], + "prompts": ["lumen", "cell wall", "vessel"], + }, + ) + mocker.patch( + "orchestration.flows.bl832.nersc.get_checkpoint_info", + return_value=checkpoint_info, + ) + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + opts = _load_job_options( + "nersc-segmentation-options", + base_settings, 
+ config=mock_config832, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + assert opts["batch_size"] == 8 + assert opts["confidence"] == [0.6, 0.7] + assert opts["prompts"] == ["lumen", "cell wall", "vessel"] + + def test_mlflow_layer_skipped_when_config_is_none(self, mocker, mock_config832): + """Passing config=None skips the MLflow layer entirely.""" + from orchestration.flows.bl832.nersc import _load_job_options + + self._patch_variable_defaults(mocker) + spy = mocker.patch("orchestration.flows.bl832.nersc.get_checkpoint_info") + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + opts = _load_job_options( + "nersc-segmentation-options", + base_settings, + config=None, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + spy.assert_not_called() + # Config default should be unchanged + assert opts["finetuned_checkpoint_path"] == base_settings["finetuned_checkpoint_path"] + + def test_mlflow_layer_skipped_when_model_name_is_none(self, mocker, mock_config832): + """Passing mlflow_model_name=None skips the MLflow layer.""" + from orchestration.flows.bl832.nersc import _load_job_options + + self._patch_variable_defaults(mocker) + spy = mocker.patch("orchestration.flows.bl832.nersc.get_checkpoint_info") + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + _load_job_options( + "nersc-segmentation-options", + base_settings, + config=mock_config832, + mlflow_model_name=None, + ) + + spy.assert_not_called() + + def test_config_defaults_used_when_mlflow_returns_none(self, mocker, mock_config832): + """get_checkpoint_info returning None → config defaults unchanged.""" + from orchestration.flows.bl832.nersc import _load_job_options + + self._patch_variable_defaults(mocker) + mocker.patch( + "orchestration.flows.bl832.nersc.get_checkpoint_info", + return_value=None, + ) + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + opts = 
_load_job_options( + "nersc-segmentation-options", + base_settings, + config=mock_config832, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + assert opts["finetuned_checkpoint_path"] == base_settings["finetuned_checkpoint_path"] + + def test_config_defaults_used_when_mlflow_raises(self, mocker, mock_config832): + """An exception from get_checkpoint_info is caught; config defaults are used.""" + from orchestration.flows.bl832.nersc import _load_job_options + + self._patch_variable_defaults(mocker) + mocker.patch( + "orchestration.flows.bl832.nersc.get_checkpoint_info", + side_effect=RuntimeError("Network timeout"), + ) + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + opts = _load_job_options( + "nersc-segmentation-options", + base_settings, + config=mock_config832, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + assert opts["finetuned_checkpoint_path"] == base_settings["finetuned_checkpoint_path"] + + def test_prefect_variable_wins_over_mlflow(self, mocker, mock_config832): + """Prefect Variable overrides take priority over MLflow inference params (layer 3 > layer 2).""" + from orchestration.flows.bl832.nersc import _load_job_options + from orchestration.mlflow import ModelCheckpointInfo + + # MLflow says batch_size=8; Prefect Variable says batch_size=16 → 16 wins + mocker.patch( + "orchestration.flows.bl832.nersc.Variable.get", + return_value={"defaults": False, "batch_size": 16}, + ) + + checkpoint_info = ModelCheckpointInfo( + model_name="sam3-petiole", + version="2", + alias="production", + nersc_path="/cfs/sam3.pt", + alcf_path="", + inference_params={"batch_size": 8}, + ) + mocker.patch( + "orchestration.flows.bl832.nersc.get_checkpoint_info", + return_value=checkpoint_info, + ) + + base_settings = dict(mock_config832.nersc_segment_sam3_settings) + opts = _load_job_options( + "nersc-segmentation-options", + base_settings, + 
config=mock_config832, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + assert opts["batch_size"] == 16 + + +# ────────────────────────────────────────────────────────────────────────────── +# segmentation_sam3 — checkpoint path from MLflow in the job script +# ────────────────────────────────────────────────────────────────────────────── + +class TestSegmentationSam3MLflowCheckpoint: + """ + Verify that when _load_job_options resolves a checkpoint path from MLflow, + segmentation_sam3 uses it in the submitted SLURM job script. + """ + + def test_mlflow_checkpoint_appears_in_job_script(self, mocker, mock_sfapi_client, mock_config832): + """ + When _load_job_options returns an MLflow-sourced finetuned_checkpoint_path, + that path must appear in the SLURM script submitted to Perlmutter. + """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + + mlflow_checkpoint = "/cfs/checkpoints/sam3_mlflow_v7.pt" + resolved_settings = dict(mock_config832.nersc_segment_sam3_settings) + resolved_settings["finetuned_checkpoint_path"] = mlflow_checkpoint + + mocker.patch( + "orchestration.flows.bl832.nersc._load_job_options", + return_value=resolved_settings, + ) + + captured = [] + original_job = mock_sfapi_client.compute.return_value.submit_job.return_value + + def capture_script(script): + captured.append(script) + return original_job + + mock_sfapi_client.compute.return_value.submit_job.side_effect = capture_script + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) + result = controller.segmentation_sam3(recon_folder_path="folder/recfile") + + assert captured, "submit_job was never called" + assert mlflow_checkpoint in captured[0], ( + "The MLflow checkpoint path must appear in the SLURM job script" + ) + assert 
result["success"] is True + + def test_config_default_checkpoint_used_when_mlflow_unavailable( + self, mocker, mock_sfapi_client, mock_config832 + ): + """ + When _load_job_options returns the unmodified config default (MLflow absent), + the default checkpoint path should appear in the job script. + """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch( + "orchestration.flows.bl832.nersc.Variable.get", + return_value={"defaults": True}, + ) + # MLflow is unreachable; _load_job_options falls back to config + mocker.patch( + "orchestration.flows.bl832.nersc.get_checkpoint_info", + return_value=None, + ) + + captured = [] + original_job = mock_sfapi_client.compute.return_value.submit_job.return_value + + def capture_script(script): + captured.append(script) + return original_job + + mock_sfapi_client.compute.return_value.submit_job.side_effect = capture_script + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) + controller.segmentation_sam3(recon_folder_path="folder/recfile") + + config_default = mock_config832.nersc_segment_sam3_settings["finetuned_checkpoint_path"] + assert captured, "submit_job was never called" + assert config_default in captured[0], ( + "Config default checkpoint path must be used when MLflow is unavailable" + ) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index bfe0bad0..8d7056a8 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -540,7 +540,7 @@ def test_nersc_segmentation_dinov3_task_success(mocker, mock_config832): config=mock_config832 ) - mock_controller.segmentation_dinov3.assert_called_once_with(recon_folder_path="folder/recfile") + 
mock_controller.segmentation_dinov3.assert_called_once_with(recon_folder_path="folder/recfile", project="petiole") assert result is True @@ -614,9 +614,9 @@ def test_petiole_segment_flow_both_succeed(mocker, mock_config832, mock_recon_su mock_controller.reconstruct.return_value = mock_recon_success mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - mock_transfer = mocker.MagicMock() - mock_transfer.copy.return_value = True - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mock_globus_transfer = mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mock_globus_transfer.submit.return_value = _make_future(mocker, True) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") @@ -646,9 +646,9 @@ def test_petiole_segment_flow_only_sam3_succeeds(mocker, mock_config832, mock_re mock_controller.reconstruct.return_value = mock_recon_success mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - mock_transfer = mocker.MagicMock() - mock_transfer.copy.return_value = True - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mock_globus_transfer = mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mock_globus_transfer.submit.return_value = _make_future(mocker, True) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") @@ -674,9 +674,9 @@ def test_petiole_segment_flow_both_seg_fail(mocker, mock_config832, mock_recon_s mock_controller.reconstruct.return_value = mock_recon_success mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - 
mock_transfer = mocker.MagicMock() - mock_transfer.copy.return_value = False - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mock_globus_transfer = mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mock_globus_transfer.submit.return_value = _make_future(mocker, False) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") @@ -701,8 +701,96 @@ def test_petiole_segment_flow_recon_failure(mocker, mock_config832): mock_controller = mocker.MagicMock() mock_controller.reconstruct.return_value = {"success": False, "job_id": None, "timing": None} mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mocker.MagicMock()) + mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) with pytest.raises(ValueError, match="Reconstruction at NERSC Failed"): nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) + +# ────────────────────────────────────────────────────────────────────────────── +# nersc_moon_segment_flow (recon + DINOv3-moon only, no SAM3, no combine) +# ────────────────────────────────────────────────────────────────────────────── + + +def test_moon_segment_flow_succeeds(mocker, mock_config832, mock_recon_success): + """Recon + DINOv3-moon both succeed → flow returns True.""" + from orchestration.flows.bl832.nersc import nersc_moon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_globus_transfer = 
mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mock_globus_transfer.submit.return_value = _make_future(mocker, True) + + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_dinov3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dinov3_task") + mock_dinov3_task.submit.return_value = _make_future(mocker, True) + + result = nersc_moon_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) + + assert result is True + mock_controller.reconstruct.assert_called_once() + mock_dinov3_task.submit.assert_called_once_with( + recon_folder_path="folder/recfile", config=mock_config832, project="moon" + ) + + +def test_moon_segment_flow_seg_failure(mocker, mock_config832, mock_recon_success): + """Recon succeeds but DINOv3-moon fails → flow returns False.""" + from orchestration.flows.bl832.nersc import nersc_moon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_globus_transfer = mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mock_globus_transfer.submit.return_value = _make_future(mocker, False) + + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_dinov3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dinov3_task") + mock_dinov3_task.submit.return_value = _make_future(mocker, False) + + result = nersc_moon_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) + + assert result is False + + +def test_moon_segment_flow_recon_failure(mocker, mock_config832): + """Recon failure should raise ValueError immediately.""" + from orchestration.flows.bl832.nersc import nersc_moon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct.return_value = 
{"success": False, "job_id": None, "timing": None} + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + with pytest.raises(ValueError, match="Reconstruction at NERSC failed"): + nersc_moon_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) + + +def test_moon_segment_flow_no_sam3_no_combine(mocker, mock_config832, mock_recon_success): + """SAM3 and combine tasks should never be called in the moon flow.""" + from orchestration.flows.bl832.nersc import nersc_moon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_globus_transfer = mocker.patch("orchestration.flows.bl832.nersc.globus_transfer_task") + mock_globus_transfer.submit.return_value = _make_future(mocker, True) + + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") + mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") + mock_dinov3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dinov3_task") + mock_dinov3_task.submit.return_value = _make_future(mocker, True) + + nersc_moon_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) + + mock_sam3_task.submit.assert_not_called() + mock_combine_task.submit.assert_not_called() diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 16b03629..8bbbf78c 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -29,6 +29,8 @@ def _beam_specific_config(self) -> None: 
self.alcf832_scratch = self.endpoints["alcf832_scratch"] # SciCat self.scicat = self.config["scicat"] + # MLflow + self.mlflow = self.config["mlflow"]["local"] # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] @@ -36,3 +38,4 @@ def _beam_specific_config(self) -> None: self.nersc_segment_sam3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"] self.nersc_segment_dinov3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dinov3"] self.nersc_combine_segmentation_settings = self.config["hpc_submission_settings832"]["nersc_combine_segmentations"] + self.nersc_segment_dinov3_moon_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dinov3_moon"] diff --git a/orchestration/flows/bl832/dispatcher.py b/orchestration/flows/bl832/dispatcher.py index 6f216f6d..d28a02eb 100644 --- a/orchestration/flows/bl832/dispatcher.py +++ b/orchestration/flows/bl832/dispatcher.py @@ -28,6 +28,10 @@ class FlowParameterMapper: "num_nodes", "config"], "nersc_petiole_segment_flow/nersc_petiole_segment_flow": [ + "file_path", + "num_nodes", + "config"], + "nersc_moon_segment_flow/nersc_moon_segment_flow": [ "file_path", "num_nodes", "config"] @@ -65,6 +69,7 @@ def setup_decision_settings( alcf_recon: bool, nersc_recon: bool, nersc_petiole_segment: bool, + nersc_moon_segment: bool, new_file_832: bool ) -> dict: """ @@ -73,6 +78,7 @@ def setup_decision_settings( :param alcf_recon: Boolean indicating whether to run the ALCF reconstruction flow. :param nersc_recon: Boolean indicating whether to run the NERSC reconstruction flow. :param nersc_petiole_segment: Boolean indicating whether to run the NERSC petiole segmentation flow. + :param nersc_moon_segment: Boolean indicating whether to run the NERSC moon segmentation flow. :param new_file_832: Boolean indicating whether to move files to NERSC. 
:return: A dictionary containing the settings for each flow. """ @@ -81,12 +87,14 @@ def setup_decision_settings( logger.info(f"Setting up decision settings: alcf_recon={alcf_recon}, " f"nersc_recon={nersc_recon}, " f"nersc_petiole_segment={nersc_petiole_segment}, " + f"nersc_moon_segment={nersc_moon_segment}, " f"new_file_832={new_file_832}") # Define which flows to run based on the input settings settings = { "alcf_recon_flow/alcf_recon_flow": alcf_recon, "nersc_recon_flow/nersc_recon_flow": nersc_recon, "nersc_petiole_segment_flow/nersc_petiole_segment_flow": nersc_petiole_segment, + "nersc_moon_segment_flow/nersc_moon_segment_flow": nersc_moon_segment, "new_832_file_flow/new_file_832": new_file_832 } # Save the settings in a JSON block for later retrieval by other flows @@ -172,6 +180,12 @@ async def dispatcher( run_recon_flow_async("nersc_petiole_segment_flow/nersc_petiole_segment_flow", nersc_petiole_segment_params) ) + if decision_settings.get("nersc_moon_segment_flow/nersc_moon_segment_flow"): + moon_params = FlowParameterMapper.get_flow_parameters( + "nersc_moon_segment_flow/nersc_moon_segment_flow", available_params + ) + tasks.append(run_recon_flow_async("nersc_moon_segment_flow/nersc_moon_segment_flow", moon_params)) + # Run ALCF and NERSC flows in parallel, if any if tasks: try: diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 42378d70..e11a159c 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import datetime from dotenv import load_dotenv import json @@ -16,8 +17,9 @@ from orchestration.flows.bl832.config import Config832 from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController +from orchestration.mlflow import get_checkpoint_info from orchestration.prune_controller import get_prune_controller, PruneMethod -from orchestration.transfer_controller import 
get_transfer_controller, CopyMethod
+from orchestration.transfer_controller import globus_transfer_task
 from orchestration.flows.bl832.streaming_mixin import (
     NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block
 )
@@ -28,40 +30,116 @@
 
 load_dotenv()
 
 
-def _load_job_options(variable_name: str, config_settings: dict[str, Any]) -> dict[str, Any]:
-    """
-    Load job options, using config as defaults and a Prefect Variable as overrides.
+@dataclass
+class SegmentationModelSpec:
+    """All config-resolution inputs for a single model+project combination.
 
-    Resolution order:
+    Consumed by ``_load_job_options`` and the job-script builders.
+    Adding a new model or project means adding one entry to the registry —
+    nothing else changes.
 
-    1. Load the named Prefect Variable.
-    2. If absent, malformed, or ``defaults: true`` → return ``config_settings`` unchanged.
-    3. If ``defaults: false`` → return ``config_settings`` with variable values overlaid.
+    :param variable_name: Prefect Variable name for runtime overrides.
+    :param settings: Config settings dict (from Config832) for base defaults.
+    :param mlflow_model_name: Registered MLflow model name.
+    :param mlflow_checkpoint_key: Config key populated from the MLflow
+        model's ``nersc_path`` tag, e.g. ``'dino_checkpoint_path'``
+        for DINOv3 models or ``'finetuned_checkpoint_path'`` for
+        SAM3 models.
+    :param extra_cli_flags: Additional flags injected into the inference
+        command, e.g. ``{'--project': 'moon'}``. Omit flags not needed.
+ """ + variable_name: str + settings: dict[str, Any] + mlflow_model_name: str + mlflow_checkpoint_key: str + extra_cli_flags: dict[str, str] = field(default_factory=dict) + + +def _load_job_options( + variable_name: str, + config_settings: dict[str, Any], + config: Config832 | None = None, + mlflow_model_name: str | None = None, + mlflow_checkpoint_key: str | None = None, +) -> dict[str, Any]: + """Load job options with three-layer resolution: config → MLflow → Prefect Variable. + + Resolution order (later layers win): + + 1. ``config_settings`` — authoritative defaults from the config YAML. + 2. MLflow Model Registry — if ``mlflow_model_name`` is provided, all + ``inference_params`` tags are overlaid onto opts by their config key name. + ``nersc_path`` is additionally mapped to ``mlflow_checkpoint_key`` if given. + 3. Prefect Variable (``variable_name``) — skipped if absent or ``defaults: true``. + If ``defaults: false``, provided keys override all lower layers. + + Args: + variable_name: Name of the Prefect Variable to load. + config_settings: Settings dict from Config832 used as base defaults. + config: Config832 instance needed for MLflow lookup. If ``None``, the + MLflow layer is skipped. + mlflow_model_name: Registered MLflow model name, e.g. ``'sam3-petiole'``. + If ``None``, the MLflow layer is skipped. + mlflow_checkpoint_key: Config key to populate from the MLflow model's + ``nersc_path`` tag, e.g. ``'finetuned_checkpoint_path'``. + + Returns: + Resolved options dict ready for use by the caller. + """ + # ── Layer 1: config defaults ────────────────────────────────────────────── + opts = dict(config_settings) - The config YAML is the authoritative source for all default values. The Prefect - Variable only needs to contain the keys it wishes to override, and may introduce - keys not present in config (e.g. a bare ``checkpoint`` filename for SAM3). 
+ # ── Layer 2: MLflow registry ────────────────────────────────────────────── + if config is not None and mlflow_model_name: + try: + checkpoint_info = get_checkpoint_info(mlflow_model_name, config) + if checkpoint_info: + # Map nersc_path to the caller-specified checkpoint key + if mlflow_checkpoint_key: + opts[mlflow_checkpoint_key] = checkpoint_info.nersc_path + logger.info( + f"MLflow '{mlflow_model_name}': " + f"{mlflow_checkpoint_key}={checkpoint_info.nersc_path}" + ) + # Overlay all inference params that match existing config keys + overlaid = [] + for k, v in checkpoint_info.inference_params.items(): + if k in opts: + opts[k] = v + overlaid.append(k) + else: + # Also inject new keys (e.g. alcf_path for future use) + opts[k] = v + logger.info( + f"MLflow '{mlflow_model_name}': overlaid params: {overlaid}" + ) + else: + logger.info( + f"MLflow: no production checkpoint for '{mlflow_model_name}', " + "using config defaults." + ) + except Exception as e: + logger.warning( + f"MLflow lookup failed for '{mlflow_model_name}': {e}. " + "Using config defaults." + ) - :param variable_name: Name of the Prefect Variable to load. - :param config_settings: Settings dict read directly from the Config832 object - (e.g. ``config.nersc_recon_settings``). Used as-is when defaults=True. - :return: Resolved options dict ready for use by the caller. - """ + # ── Layer 3: Prefect Variable overrides ─────────────────────────────────── try: options = Variable.get(variable_name, default={"defaults": True}, _sync=True) if isinstance(options, str): options = json.loads(options) except Exception as e: - logger.warning(f"Could not load '{variable_name}': {e}. Using config defaults.") - return dict(config_settings) + logger.warning(f"Could not load '{variable_name}': {e}. 
Skipping variable overrides.") + return opts if options.get("defaults", True): - logger.info(f"Using config defaults for '{variable_name}'") - return dict(config_settings) + logger.info(f"Prefect Variable '{variable_name}': no overrides.") + return opts - logger.info(f"Overriding config defaults with variable options for '{variable_name}'") overrides = {k: v for k, v in options.items() if k != "defaults"} - return {**config_settings, **overrides} + logger.info(f"Prefect Variable '{variable_name}': applying overrides: {list(overrides)}") + return {**opts, **overrides} class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): @@ -111,6 +189,44 @@ def create_sfapi_client() -> Client: logger.error(f"Failed to create NERSC client: {e}") raise e + def _get_segmentation_spec(self, model: str, project: str) -> SegmentationModelSpec: + """Return the SegmentationModelSpec for a model+project combination. + + :param model: Model family, e.g. ``'dinov3'`` or ``'sam3'``. + :param project: Experiment project, e.g. ``'petiole'`` or ``'moon'``. + :return: The corresponding SegmentationModelSpec. + :raises ValueError: If the combination is not registered. 
+ """ + registry: dict[tuple[str, str], SegmentationModelSpec] = { + ("dinov3", "petiole"): SegmentationModelSpec( + variable_name="nersc-dinov3-seg-options", + settings=self.config.nersc_segment_dinov3_settings, + mlflow_model_name="dinov3-petiole", + mlflow_checkpoint_key="dino_checkpoint_path", + ), + ("dinov3", "moon"): SegmentationModelSpec( + variable_name="nersc-dinov3-moon-seg-options", + settings=self.config.nersc_segment_dinov3_moon_settings, + mlflow_model_name="dinov3-moon", + mlflow_checkpoint_key="dino_checkpoint_path", + extra_cli_flags={"--project": "moon"}, + ), + ("sam3", "petiole"): SegmentationModelSpec( + variable_name="nersc-segmentation-options", + settings=self.config.nersc_segment_sam3_settings, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ), + # future: ("sam3", "moon"): SegmentationModelSpec(...), + } + key = (model, project) + if key not in registry: + raise ValueError( + f"No segmentation spec registered for model={model!r}, project={project!r}. " + f"Registered combinations: {list(registry)}" + ) + return registry[key] + def reconstruct( self, file_path: str = "", @@ -527,7 +643,13 @@ def segmentation_sam3( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - opts = _load_job_options("nersc-segmentation-options", self.config.nersc_segment_sam3_settings) + opts = _load_job_options( + variable_name="nersc-segmentation-options", + config_settings=self.config.nersc_segment_sam3_settings, + config=self.config, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) cfs_path = opts["cfs_path"] conda_env_path = opts["conda_env_path"] @@ -792,12 +914,14 @@ def segmentation_sam3( def segmentation_dinov3( self, recon_folder_path: str = "", + project: str = "petiole", ) -> bool: """ Run DINOv3 segmentation at NERSC Perlmutter via SFAPI Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, e.g. 
'folder_name/recYYYYMMDD_hhmmss_scanname/'
+        :param project: Project name for segmentation settings.
         :return: True if the job completed successfully, False otherwise.
         """
         logger.info("Starting NERSC DINOv3 segmentation process.")
@@ -806,8 +930,18 @@
         pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}"
 
         # Load from config
+        spec = self._get_segmentation_spec("dinov3", project)
+        opts = _load_job_options(
+            variable_name=spec.variable_name,
+            config_settings=spec.settings,
+            config=self.config,
+            mlflow_model_name=spec.mlflow_model_name,
+            mlflow_checkpoint_key=spec.mlflow_checkpoint_key,
+        )
-        opts = _load_job_options("nersc-dinov3-seg-options", self.config.nersc_segment_dinov3_settings)
+        extra_flags = "\n".join(
+            f"    {flag} {value} \\" for flag, value in spec.extra_cli_flags.items()
+        ) or "    \\"
 
         cfs_path = opts["cfs_path"]
         conda_env_path = opts["conda_env_path"]
@@ -903,6 +1037,7 @@
                 --output-dir "{output_dir}" \\
                 --batch-size {batch_size} \\
                 --finetuned-checkpoint "{dino_checkpoint}" \\
+                {extra_flags}
                 --save-overlay
 
             SEG_STATUS=$?
@@ -1476,6 +1611,16 @@ def nersc_recon_flow( ) logger.info("NERSC reconstruction controller initialized") + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") + if num_nodes is None: num_nodes = config.nersc_recon_settings.get("num_nodes", 4) logger.info(f"Configured to use {num_nodes} nodes for reconstruction") @@ -1516,54 +1661,47 @@ def nersc_recon_flow( logger.info(f"NERSC reconstruction success: {success}") - nersc_multi_res_success = controller.build_multi_resolution( - file_path=file_path, - ) - logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - - path = Path(file_path) - folder_name = path.parent.name - file_name = path.stem - - tiff_file_path = f"{folder_name}/rec{file_name}" - zarr_file_path = f"{folder_name}/rec{file_name}.zarr" - - logger.info(f"{tiff_file_path=}") - logger.info(f"{zarr_file_path=}") - - # Transfer reconstructed data - logger.info("Preparing transfer.") - transfer_controller = get_transfer_controller( - transfer_type=CopyMethod.GLOBUS, - config=config + logger.info("Scheduling reconstruction transfers from pscratch to CFS and data832.") + pscratch_to_cfs_tiff_future = globus_transfer_task.submit( + file_path=tiff_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch, + config=config, ) - - logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") - transfer_controller.copy( + pscratch_to_data832_tiff_future = globus_transfer_task.submit( file_path=tiff_file_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch + destination=config.data832_scratch, + config=config, ) - transfer_controller.copy( - file_path=zarr_file_path, - 
source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch + logger.info("Building multi-resolution Zarrs.") + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") - transfer_controller.copy( - file_path=tiff_file_path, + logger.info("Scheduling Zarr transfers from pscratch to CFS and data832.") + pscratch_to_cfs_zarr_future = globus_transfer_task.submit( + file_path=zarr_file_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch + destination=config.nersc832_alsdev_scratch, + config=config, ) - - transfer_controller.copy( + pscratch_to_data832_zarr_future = globus_transfer_task.submit( file_path=zarr_file_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch + destination=config.data832_scratch, + config=config, ) + # Resolve before pruning (which needs to know what landed where) + pscratch_to_cfs_tiff_future.result() + pscratch_to_cfs_zarr_future.result() + pscratch_to_data832_tiff_future.result() + pscratch_to_data832_zarr_future.result() + logger.info("All transfers complete.") + logger.info("Scheduling pruning tasks.") schedule_pruning( config=config, @@ -1611,10 +1749,6 @@ def nersc_petiole_segment_flow( logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") logger.info(f"Segmented output will be at: {scratch_path_segment}") - transfer_controller = get_transfer_controller( - transfer_type=CopyMethod.GLOBUS, - config=config - ) controller = get_controller(hpc_type=HPC.NERSC, config=config) logger.info("NERSC controller initialized") @@ -1625,6 +1759,10 @@ def nersc_petiole_segment_flow( nersc_reconstruction_success = False sam3_success = False dinov3_success = False + data832_tiff_future = None + data832_sam3_future = None + 
data832_dinov3_future = None + data832_combined_future = None data832_tiff_transfer_success = False data832_sam3_transfer_success = False data832_dinov3_transfer_success = False @@ -1673,12 +1811,13 @@ def nersc_petiole_segment_flow( # ── STEP 2: Transfer TIFFs to data832 ──────────────────────────────────── logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") try: - data832_tiff_transfer_success = transfer_controller.copy( + data832_tiff_future = globus_transfer_task.submit( file_path=scratch_path_tiff, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch + destination=config.data832_scratch, + config=config, ) - logger.info(f"Transfer reconstructed TIFF data to data832 success: {data832_tiff_transfer_success}") + logger.info("TIFF transfer to data832 submitted.") except Exception as e: logger.error(f"Failed to transfer TIFFs to data832: {e}") data832_tiff_transfer_success = False @@ -1690,7 +1829,7 @@ def nersc_petiole_segment_flow( recon_folder_path=scratch_path_tiff, config=config ) dinov3_future = nersc_segmentation_dinov3_task.submit( - recon_folder_path=scratch_path_tiff, config=config + recon_folder_path=scratch_path_tiff, config=config, project="petiole" ) # ── STEP 4: Transfer each model's output as it completes ───────────────── @@ -1701,12 +1840,13 @@ def nersc_petiole_segment_flow( logger.info("Transferring SAM3 segmentation outputs to data832") sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" try: - data832_sam3_transfer_success = transfer_controller.copy( + data832_sam3_future = globus_transfer_task.submit( file_path=sam3_segment_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch + destination=config.data832_scratch, + config=config, ) - logger.info(f"SAM3 transfer to data832 success: {data832_sam3_transfer_success}") + logger.info("SAM3 transfer to data832 submitted") except Exception as e: logger.error(f"Failed to transfer SAM3 outputs to 
data832: {e}") @@ -1716,12 +1856,13 @@ def nersc_petiole_segment_flow( logger.info("Transferring DINOv3 segmentation outputs to data832") dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" try: - data832_dinov3_transfer_success = transfer_controller.copy( + data832_dinov3_future = globus_transfer_task.submit( file_path=dinov3_segment_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch + destination=config.data832_scratch, + config=config, ) - logger.info(f"DINOv3 transfer to data832 success: {data832_dinov3_transfer_success}") + logger.info("DINOv3 transfer to data832 submitted") except Exception as e: logger.error(f"Failed to transfer DINOv3 outputs to data832: {e}") @@ -1743,12 +1884,13 @@ def nersc_petiole_segment_flow( logger.info("Transferring combined segmentation outputs to data832") combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" try: - data832_combined_transfer_success = transfer_controller.copy( + data832_combined_future = globus_transfer_task.submit( file_path=combined_segment_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch + destination=config.data832_scratch, + config=config, ) - logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") + logger.info("Combined transfer to data832 submitted") except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") @@ -1758,15 +1900,28 @@ def nersc_petiole_segment_flow( logger.info("Copying rec and seg folders from pscratch to NERSC CFS.") for cfs_path in [scratch_path_tiff, scratch_path_segment]: try: - transfer_controller.copy( + globus_transfer_task.submit( file_path=cfs_path, source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch + destination=config.nersc832_alsdev_scratch, + config=config, ) - logger.info(f"CFS transfer success: {cfs_path}") + logger.info(f"CFS transfer submitted: {cfs_path}") 
except Exception as e: logger.error(f"Failed to copy {cfs_path} to NERSC CFS: {e}") + # ── Resolve all data832 futures before pruning ──────────────────────────── + data832_tiff_transfer_success = data832_tiff_future.result() if data832_tiff_future else False + data832_sam3_transfer_success = data832_sam3_future.result() if data832_sam3_future else False + data832_dinov3_transfer_success = data832_dinov3_future.result() if data832_dinov3_future else False + data832_combined_transfer_success = data832_combined_future.result() if data832_combined_future else False + + logger.info( + f"Transfer results — tiff: {data832_tiff_transfer_success}, " + f"sam3: {data832_sam3_transfer_success}, dino: {data832_dinov3_transfer_success}, " + f"combined: {data832_combined_transfer_success}" + ) + # ── STEP 6: Pruning ─────────────────────────────────────────────────────── logger.info("Scheduling file pruning tasks.") prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) @@ -1841,6 +1996,203 @@ def nersc_petiole_segment_flow( return False +@flow(name="nersc_moon_segment_flow", flow_run_name="nersc_moon_seg-{file_path}") +def nersc_moon_segment_flow( + file_path: str, + config: Config832 | None = None, + num_nodes: int | None = None, +) -> bool: + """Reconstruct a lunar regolith scan and run DINOv3-moon segmentation. + + Runs reconstruction then DINOv3-moon (ice, particles, pores). No SAM3 or + combine step — those are petiole-specific. Transfer and pruning follow the + same pattern as nersc_petiole_segment_flow. + + :param file_path: Path to the raw .h5 file to be processed. + :param config: Configuration object for the flow. + :param num_nodes: Number of nodes for reconstruction. + :return: True if reconstruction and segmentation both succeeded. 
+ """ + logger = get_run_logger() + + if config is None: + logger.info("Initializing Config") + config = Config832() + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + scratch_path_tiff = f"{folder_name}/rec{file_name}" + scratch_path_segment = f"{folder_name}/seg{file_name}" + + logger.info(f"Starting NERSC reconstruction + DINOv3-moon flow for {file_path=}") + + controller = get_controller(hpc_type=HPC.NERSC, config=config) + + if num_nodes is None: + num_nodes = config.nersc_recon_settings.get("num_nodes", 4) + logger.info(f"Configured to use {num_nodes} nodes for reconstruction") + + # ── STEP 1: Reconstruction ──────────────────────────────────────────────── + recon_result = controller.reconstruct(file_path=file_path, num_nodes=num_nodes) + + if isinstance(recon_result, dict): + nersc_reconstruction_success = recon_result.get("success", False) + timing = recon_result.get("timing") + if timing: + logger.info("=" * 50) + logger.info("TIMING BREAKDOWN") + logger.info("=" * 50) + logger.info(f" Total job time: {timing.get('total', 'N/A')}s") + logger.info(f" Container pull: {timing.get('container_pull', 'N/A')}s") + logger.info( + f" File copy: {timing.get('file_copy', 'N/A')}s " + f"(skipped: {timing.get('copy_skipped', 'N/A')})" + ) + logger.info(f" Metadata detection: {timing.get('metadata', 'N/A')}s") + logger.info(f" RECONSTRUCTION: {timing.get('reconstruction', 'N/A')}s <-- actual recon time") + logger.info(f" Num slices: {timing.get('num_slices', 'N/A')}") + logger.info("=" * 50) + if all(k in timing for k in ["total", "reconstruction"]): + overhead = timing["total"] - timing["reconstruction"] + logger.info(f" Overhead: {overhead}s") + logger.info(f" Reconstruction %: {100 * timing['reconstruction'] / timing['total']:.1f}%") + logger.info("=" * 50) + else: + nersc_reconstruction_success = recon_result + + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + if not nersc_reconstruction_success: + 
logger.error("Reconstruction failed — aborting moon segmentation flow.") + raise ValueError("Reconstruction at NERSC failed") + + # ── STEP 2: Transfer TIFFs to data832 ──────────────────────────────────── + data832_tiff_future = None + try: + data832_tiff_future = globus_transfer_task.submit( + file_path=scratch_path_tiff, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch, + config=config, + ) + logger.info("TIFF transfer to data832 submitted.") + except Exception as e: + logger.error(f"Failed to submit TIFF transfer to data832: {e}") + + # ── STEP 3: DINOv3-moon segmentation ───────────────────────────────────── + logger.info("Submitting DINOv3-moon segmentation task.") + moon_future = nersc_segmentation_dinov3_task.submit( + recon_folder_path=scratch_path_tiff, config=config, project="moon" + ) + + moon_success = moon_future.result() + logger.info(f"DINOv3-moon segmentation result: {moon_success}") + + # ── STEP 4: Transfer segmentation outputs to data832 ───────────────────── + data832_moon_future = None + if moon_success: + moon_segment_path = f"{folder_name}/seg{file_name}/dino" + try: + data832_moon_future = globus_transfer_task.submit( + file_path=moon_segment_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch, + config=config, + ) + logger.info("DINOv3-moon transfer to data832 submitted.") + except Exception as e: + logger.error(f"Failed to submit DINOv3-moon transfer to data832: {e}") + + # ── STEP 5: Copy to NERSC CFS ───────────────────────────────────────────── + for cfs_path in [scratch_path_tiff, scratch_path_segment]: + try: + globus_transfer_task.submit( + file_path=cfs_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch, + config=config, + ) + logger.info(f"CFS transfer submitted: {cfs_path}") + except Exception as e: + logger.error(f"Failed to copy {cfs_path} to NERSC CFS: {e}") + + # ── Resolve futures before pruning 
──────────────────────────────────────── + data832_tiff_transfer_success = data832_tiff_future.result() if data832_tiff_future else False + data832_moon_transfer_success = data832_moon_future.result() if data832_moon_future else False + + logger.info( + f"Transfer results — tiff: {data832_tiff_transfer_success}, " + f"moon: {data832_moon_transfer_success}" + ) + + # ── STEP 6: Pruning ─────────────────────────────────────────────────────── + logger.info("Scheduling file pruning tasks.") + prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) + + try: + prune_controller.prune( + file_path=f"{folder_name}/{path.name}", + source_endpoint=config.nersc832_alsdev_pscratch_raw, + check_endpoint=None, + days_from_now=1.0, + ) + except Exception as e: + logger.warning(f"Failed to schedule raw data pruning: {e}") + + try: + prune_controller.prune( + file_path=scratch_path_tiff, + source_endpoint=config.nersc832_alsdev_pscratch_scratch, + check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + days_from_now=1.0, + ) + except Exception as e: + logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + if moon_success: + try: + prune_controller.prune( + file_path=scratch_path_segment, + source_endpoint=config.nersc832_alsdev_pscratch_scratch, + check_endpoint=config.data832_scratch if data832_moon_transfer_success else None, + days_from_now=1.0, + ) + except Exception as e: + logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + if data832_tiff_transfer_success: + try: + prune_controller.prune( + file_path=scratch_path_tiff, + source_endpoint=config.data832_scratch, + check_endpoint=None, + days_from_now=30.0, + ) + except Exception as e: + logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + if data832_moon_transfer_success: + try: + prune_controller.prune( + file_path=scratch_path_segment, + source_endpoint=config.data832_scratch, + check_endpoint=None, + 
days_from_now=30.0, + ) + except Exception as e: + logger.warning(f"Failed to schedule data832 moon segment pruning: {e}") + + if nersc_reconstruction_success and moon_success: + logger.info("NERSC reconstruction + DINOv3-moon flow completed successfully.") + return True + else: + logger.warning( + f"Flow completed with issues: recon={nersc_reconstruction_success}, moon={moon_success}" + ) + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), @@ -1995,14 +2347,15 @@ def nersc_segmentation_sam3_task( def nersc_segmentation_dinov3_task( recon_folder_path: str, config: Optional[Config832] = None, + project: Optional[str] = "petiole", ) -> bool: logger = get_run_logger() if config is None: logger.info("No config provided, using default Config832.") config = Config832() tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) - logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}") - success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path) + logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}, {project=}") + success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path, project=project) if not success: logger.error("DINOv3 segmentation failed.") else: diff --git a/orchestration/flows/bl832/register_mlflow.py b/orchestration/flows/bl832/register_mlflow.py new file mode 100644 index 00000000..814d622b --- /dev/null +++ b/orchestration/flows/bl832/register_mlflow.py @@ -0,0 +1,251 @@ +import logging + +from orchestration.flows.bl832.config import Config832 +from orchestration.flows.bl832.nersc import _load_job_options +from orchestration.mlflow import register_checkpoint + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def register_mlflow_checkpoints(): + config = Config832() + + scripts_dir = 
"/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/" + + register_checkpoint( + model_name="sam3-petiole", + nersc_path=f"{scripts_dir}sam3_finetune/sam3/checkpoint_v6.pt", + alcf_path="/eagle/IRIBeta/als/seg_models/sam3/checkpoint_v6.pt", + config=config, + alias="production", + description="SAM3 v6 fine-tuned on petiole micro-CT data.", + inference_params={ + # ── paths ────────────────────────────────────────────────────────── + "original_checkpoint_path": + f"{scripts_dir}sam3_finetune/sam3/sam3.pt", + "bpe_path": f"{scripts_dir}sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz", + "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311", + "seg_scripts_dir": f"{scripts_dir}inference_latest/forge_feb_seg_model_demo/", + "checkpoints_dir": f"{scripts_dir}sam3_finetune/sam3/", + # ── inference hyperparameters ─────────────────────────────────────── + "script_name": "src/inference_v6.py", + "batch_size": 1, + "patch_size": 400, + "confidence": [0.5], # list → JSON-encoded automatically + "overlap": 0.25, + "prompts": [ # list → JSON-encoded automatically + "Phloem Fibers", + "Hydrated Xylem vessels", + "Air-based Pith cells", + "Dehydrated Xylem vessels", + ], + }, + ) + + register_checkpoint( + model_name="dinov3-petiole", + nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt", + alcf_path="/eagle/IRIBeta/als/seg_models/dino/best.ckpt", + config=config, + alias="production", + description="DINOv3 fine-tuned on petiole micro-CT data.", + inference_params={ + # ── paths ────────────────────────────────────────────────────────── + "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", + # ── inference hyperparameters ─────────────────────────────────────── + "script_name": "src.inference_dino_v1", + "batch_size": 4, + "nproc_per_node": 4, + }, + ) + + register_checkpoint( + 
model_name="dinov3-moon", + nersc_path="/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino_moon/best.ckpt", + alcf_path="/eagle/IRIBeta/als/seg_models/dino_moon/best.ckpt", + config=config, + alias="production", + description="DINOv3 fine-tuned on lunar regolith micro-CT data (ice, particles, pores).", + inference_params={ + "conda_env_path": "/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo", + "seg_scripts_dir": f"{scripts_dir}inference_v5_multiseg/forge_feb_seg_model_demo/", + "script_name": "src.inference_dino_v2", + "batch_size": 4, + "nproc_per_node": 4, + }, + ) + + +def retrieve_mlflow_params_test() -> bool: + """Test that _load_job_options correctly pulls inference params from the MLflow registry. + + Verifies the three-layer resolution for both SAM3 and DINOv3: + - MLflow-registered values override config defaults for model-coupled params. + - SLURM allocation params (qos, account, etc.) are unchanged from config. + - List values (confidence, prompts) are correctly deserialized from JSON tags. + + Returns: + True if all assertions pass, False if any check fails. 
+ """ + config = Config832() + all_passed = True + + # ── SAM3 ───────────────────────────────────────────────────────────────── + logger.info("=" * 60) + logger.info("Testing SAM3 option resolution") + logger.info("=" * 60) + + sam3_opts = _load_job_options( + "nersc-segmentation-options", + config.nersc_segment_sam3_settings, + config=config, + mlflow_model_name="sam3-petiole", + mlflow_checkpoint_key="finetuned_checkpoint_path", + ) + + sam3_checks = { + # MLflow should have overridden these + "finetuned_checkpoint_path": ( + lambda v: "checkpoint" in v, + "finetuned_checkpoint_path should contain 'checkpoint'" + ), + "conda_env_path": ( + lambda v: "sam3" in v, + "conda_env_path should reference sam3 env" + ), + "prompts": ( + lambda v: isinstance(v, list) and len(v) > 0, + "prompts should be a non-empty list (JSON-deserialized)" + ), + "confidence": ( + lambda v: isinstance(v, list), + "confidence should be a list (JSON-deserialized)" + ), + "batch_size": ( + lambda v: isinstance(v, int), + "batch_size should be an int" + ), + # SLURM params should still come from config + "qos": ( + lambda v: v == config.nersc_segment_sam3_settings["qos"], + "qos should be unchanged from config" + ), + "account": ( + lambda v: v == config.nersc_segment_sam3_settings["account"], + "account should be unchanged from config" + ), + } + + for key, (check_fn, message) in sam3_checks.items(): + value = sam3_opts.get(key) + passed = check_fn(value) if value is not None else False + status = "✓" if passed else "✗" + logger.info(f" [{status}] {key}={value!r} — {message}") + if not passed: + all_passed = False + + # ── DINOv3 ─────────────────────────────────────────────────────────────── + logger.info("=" * 60) + logger.info("Testing DINOv3 option resolution") + logger.info("=" * 60) + + dino_opts = _load_job_options( + "nersc-dinov3-seg-options", + config.nersc_segment_dinov3_settings, + config=config, + mlflow_model_name="dinov3-petiole", + 
mlflow_checkpoint_key="dino_checkpoint_path", + ) + + dino_checks = { + "dino_checkpoint_path": ( + lambda v: v.endswith(".ckpt"), + "dino_checkpoint_path should end with .ckpt" + ), + "conda_env_path": ( + lambda v: len(v) > 0, + "conda_env_path should be non-empty" + ), + "batch_size": ( + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" + ), + "script_name": ( + lambda v: "dino" in v.lower(), + "script_name should reference dino" + ), + # SLURM params unchanged + "qos": ( + lambda v: v == config.nersc_segment_dinov3_settings["qos"], + "qos should be unchanged from config" + ), + "num_nodes": ( + lambda v: isinstance(v, int) and v > 0, + "num_nodes should be a positive int" + ), + } + + for key, (check_fn, message) in dino_checks.items(): + value = dino_opts.get(key) + passed = check_fn(value) if value is not None else False + status = "✓" if passed else "✗" + logger.info(f" [{status}] {key}={value!r} — {message}") + if not passed: + all_passed = False + + # ── DINOv3-moon ─────────────────────────────────────────────────────────── + logger.info("=" * 60) + logger.info("Testing DINOv3-moon option resolution") + logger.info("=" * 60) + + moon_opts = _load_job_options( + "nersc-dinov3-moon-seg-options", + config.nersc_segment_dinov3_moon_settings, + config=config, + mlflow_model_name="dinov3-moon", + mlflow_checkpoint_key="dino_checkpoint_path", + ) + + moon_checks = { + "dino_checkpoint_path": ( + lambda v: v.endswith(".ckpt"), + "dino_checkpoint_path should end with .ckpt" + ), + "script_name": ( + lambda v: "v2" in v.lower(), + "script_name should reference inference_dino_v2" + ), + "batch_size": ( + lambda v: isinstance(v, int) and v > 0, + "batch_size should be a positive int" + ), + "qos": ( + lambda v: v == config.nersc_segment_dinov3_moon_settings["qos"], + "qos should be unchanged from config" + ), + } + + for key, (check_fn, message) in moon_checks.items(): + value = moon_opts.get(key) + passed = check_fn(value) if value is 
not None else False + status = "✓" if passed else "✗" + logger.info(f" [{status}] {key}={value!r} — {message}") + if not passed: + all_passed = False + + # ── Summary ─────────────────────────────────────────────────────────────── + logger.info("=" * 60) + if all_passed: + logger.info("✓ All MLflow integration checks passed.") + else: + logger.error("✗ One or more MLflow integration checks failed.") + logger.info("=" * 60) + + return all_passed + + +if __name__ == "__main__": + register_mlflow_checkpoints() + retrieve_mlflow_params_test() diff --git a/orchestration/mlflow.py b/orchestration/mlflow.py new file mode 100644 index 00000000..cff8487c --- /dev/null +++ b/orchestration/mlflow.py @@ -0,0 +1,281 @@ +import logging +from dataclasses import dataclass, field +import json +import requests +from typing import Any + +import mlflow +from mlflow.tracking import MlflowClient + +from orchestration.config import BeamlineConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCheckpointInfo: + """Checkpoint and inference metadata for a registered model version. + + Attributes: + model_name: Registered model name in MLflow (e.g. 'sam3-petiole'). + version: MLflow model version string. + alias: The alias resolved to find this version (e.g. 'production'). + nersc_path: Primary checkpoint path on NERSC CFS. + alcf_path: Primary checkpoint path on ALCF Eagle. + inference_params: Inference hyperparameters and paths stored as tags, + keyed by their config YAML name for direct overlay onto opts. + """ + + model_name: str + version: str + alias: str + nersc_path: str + alcf_path: str + inference_params: dict[str, Any] = field(default_factory=dict) + + +def _is_mlflow_reachable(tracking_uri: str, timeout: float = 2.0) -> bool: + """Check whether the MLflow tracking server is reachable. + + Makes a single lightweight GET to /health with a short timeout. 
+ Used to short-circuit registry lookups before MLflow's HTTP client + fires its retry loop (which defaults to 6 retries with backoff). + + Args: + tracking_uri: Base URL of the MLflow tracking server. + timeout: Connection timeout in seconds. + + Returns: + True if the server responds with HTTP 200, False otherwise. + """ + try: + response = requests.get(f"{tracking_uri}/health", timeout=timeout) + return response.status_code == 200 + except Exception: + return False + + +def get_mlflow_client(config: BeamlineConfig) -> MlflowClient: + """Construct an MlflowClient pointed at the configured tracking server. + + Args: + config: Beamline configuration object, used to read the tracking URI. + + Returns: + An authenticated MlflowClient instance. + """ + tracking_uri = config.mlflow["tracking_uri"] + mlflow.set_tracking_uri(tracking_uri) + return MlflowClient(tracking_uri=tracking_uri) + + +def get_checkpoint_info( + model_name: str, + config: BeamlineConfig, + alias: str = "production", +) -> ModelCheckpointInfo | None: + """Retrieve checkpoint path metadata for a registered model version. + + Looks up the model version registered under ``alias`` in the MLflow + Model Registry and returns its ``nersc_path`` and ``alcf_path`` tags. + Returns ``None`` (and logs a warning) if the model or alias is not found, + allowing callers to fall back to config defaults. + + Args: + model_name: Registered model name, e.g. ``'sam3-petiole'``. + config: Beamline configuration object. + alias: Model version alias to resolve, e.g. ``'production'``. + + Returns: + A ``ModelCheckpointInfo`` with checkpoint paths, or ``None`` if not found. + """ + if not _is_mlflow_reachable(config.mlflow["tracking_uri"]): + logger.warning( + f"MLflow server unreachable at {config.mlflow['tracking_uri']}. " + "Falling back to config defaults." 
+ ) + return None + + client = get_mlflow_client(config) + + try: + mv = client.get_model_version_by_alias(model_name, alias) + except mlflow.exceptions.MlflowException as e: + logger.warning( + f"Could not resolve alias '{alias}' for model '{model_name}': {e}. " + "Falling back to config defaults." + ) + return None + + tags = mv.tags or {} + nersc_path = tags.get("nersc_path", "") + alcf_path = tags.get("alcf_path", "") + + if not nersc_path: + logger.warning( + f"Model '{model_name}' v{mv.version} has no 'nersc_path' tag. " + "Falling back to config defaults." + ) + return None + + # Deserialize all remaining tags; JSON-decode lists/dicts automatically + reserved = {"nersc_path", "alcf_path"} + inference_params: dict[str, Any] = {} + for k, v in tags.items(): + if k in reserved: + continue + try: + inference_params[k] = json.loads(v) + except (json.JSONDecodeError, TypeError): + inference_params[k] = v + + logger.info( + f"Resolved '{model_name}' alias='{alias}' -> v{mv.version} " + f"with {len(inference_params)} inference params." + ) + return ModelCheckpointInfo( + model_name=model_name, + version=mv.version, + alias=alias, + nersc_path=nersc_path, + alcf_path=alcf_path, + inference_params=inference_params, + ) + + +def register_checkpoint( + model_name: str, + nersc_path: str, + config: BeamlineConfig, + alcf_path: str = "", + alias: str = "production", + description: str = "", + inference_params: dict[str, Any] | None = None, +) -> str: + """Register a model checkpoint in the MLflow Model Registry. + + Creates or updates a registered model and logs a new version with + ``nersc_path`` and ``alcf_path`` stored as version-level tags. + The new version is immediately assigned the given alias. + + Args: + model_name: Registered model name, e.g. ``'sam3-petiole'``. + nersc_path: Absolute path to the checkpoint on NERSC CFS. + config: Beamline configuration object. + alcf_path: Absolute path to the checkpoint on ALCF Eagle (optional). 
+ alias: Alias to assign to the new version after registration. + description: Human-readable description for the model version. + inference_params: Model-coupled inference hyperparameters and paths to + store as tags, e.g. batch_size, prompts, conda_env_path. + + Returns: + The new model version string. + + Raises: + mlflow.exceptions.MlflowException: If registration or tagging fails. + """ + client = get_mlflow_client(config) + + try: + client.get_registered_model(model_name) + except mlflow.exceptions.MlflowException: + logger.info(f"Creating registered model '{model_name}'.") + client.create_registered_model(model_name) + + mlflow.set_tracking_uri(config.mlflow["tracking_uri"]) + with mlflow.start_run( + run_name=f"register_{model_name}", + tags={"mlflow.note.content": description}, + ) as run: + mlflow.log_param("nersc_path", nersc_path) + mlflow.log_param("alcf_path", alcf_path) + if inference_params: + mlflow.log_params({ + k: (json.dumps(v) if isinstance(v, (list, dict)) else v) + for k, v in inference_params.items() + }) + run_id = run.info.run_id + + mv = mlflow.register_model(model_uri=f"runs:/{run_id}/model", name=model_name) + version = mv.version + + client.set_model_version_tag(model_name, version, "nersc_path", nersc_path) + if alcf_path: + client.set_model_version_tag(model_name, version, "alcf_path", alcf_path) + + if inference_params: + for k, v in inference_params.items(): + encoded = json.dumps(v) if isinstance(v, (list, dict)) else str(v) + client.set_model_version_tag(model_name, version, k, encoded) + + client.set_registered_model_alias(model_name, alias, version) + logger.info( + f"Registered '{model_name}' v{version} alias='{alias}' " + f"with {len(inference_params or {})} inference params." 
+ ) + return version + + +def log_segmentation_metrics( + run_name: str, + model_name: str, + job_id: str, + config: BeamlineConfig, + timing: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + parent_run_id: str | None = None, +) -> str: + """Log segmentation job metrics as a child MLflow run. + + Creates a child run under ``parent_run_id`` (if provided) and records + timing metrics, SLURM job ID, and any additional params. + + Args: + run_name: Display name for this run. + model_name: Model identifier, used to tag the run for filtering. + job_id: SLURM job ID for this segmentation job. + config: Beamline configuration object. + timing: Timing dict returned by ``_fetch_seg_timing_from_output``. + params: Arbitrary key-value pairs to log as MLflow params. + parent_run_id: If set, this run is nested under the parent. + + Returns: + The MLflow run ID for the logged child run. + """ + tracking_uri = config.mlflow["tracking_uri"] + mlflow.set_tracking_uri(tracking_uri) + + run_tags: dict[str, str] = {"model": model_name, "slurm_job_id": job_id} + + with mlflow.start_run( + run_name=run_name, + nested=parent_run_id is not None, + parent_run_id=parent_run_id, + tags=run_tags, + ) as run: + mlflow.log_param("slurm_job_id", job_id) + mlflow.log_param("model", model_name) + + if params: + mlflow.log_params(params) + + if timing: + metrics: dict[str, float] = {} + if "total_seconds" in timing: + metrics["total_seconds"] = float(timing["total_seconds"]) + if "num_images" in timing: + metrics["num_images"] = float(timing["num_images"]) + if "throughput" in timing: + metrics["throughput_images_per_min"] = float(timing["throughput"]) + if "time_per_image" in timing: + # Stored as "3.23s" — strip the unit + raw = str(timing["time_per_image"]).rstrip("s") + try: + metrics["time_per_image_seconds"] = float(raw) + except ValueError: + pass + if metrics: + mlflow.log_metrics(metrics) + + logger.info(f"Logged segmentation run '{run_name}' as MLflow run 
{run.info.run_id}") + return run.info.run_id diff --git a/pyproject.toml b/pyproject.toml index 7fee39f6..1b1f7010 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "mkdocs", "mkdocs-material", "mkdocs-mermaid2-plugin", + "mlflow==2.22.0", "numpy>=1.26.4", "pillow", "prefect==3.4.2", diff --git a/requirements.txt b/requirements.txt index 527e9c0d..f61194d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ httpx>=0.22.0 mkdocs mkdocs-material mkdocs-mermaid2-plugin +mlflow==2.22.0 numpy>=1.26.4 pillow prefect==3.4.2