From e09c193bfebe0bd39113d8f9965f561a6bf4dc35 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 6 Jan 2026 10:36:40 -0800 Subject: [PATCH 01/72] initial commit with multinode support --- orchestration/flows/bl832/nersc.py | 375 +++- scripts/perlmutter/sfapi_reconstruction.py | 1755 ++++++++++++++++ .../sfapi_reconstruction_multinode.py | 1761 +++++++++++++++++ scripts/perlmutter/tiff_to_zarr.py | 104 + 4 files changed, 3941 insertions(+), 54 deletions(-) create mode 100644 scripts/perlmutter/sfapi_reconstruction.py create mode 100644 scripts/perlmutter/sfapi_reconstruction_multinode.py create mode 100644 scripts/perlmutter/tiff_to_zarr.py diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 727cbbaf..e1d00522 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -123,7 +123,7 @@ def reconstruct( #SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err #SBATCH -N 1 #SBATCH --ntasks-per-node 1 -#SBATCH --cpus-per-task 64 +#SBATCH --cpus-per-task 128 #SBATCH --time=0:15:00 #SBATCH --exclusive @@ -150,6 +150,10 @@ def reconstruct( echo "Running reconstruction container..." 
srun podman-hpc run \ +--env NUMEXPR_MAX_THREADS=128 \\ +--env NUMEXPR_NUM_THREADS=128 \\ +--env OMP_NUM_THREADS=128 \\ +--env MKL_NUM_THREADS=128 \\ --volume {recon_scripts_dir}/sfapi_reconstruction.py:/alsuser/sfapi_reconstruction.py \ --volume {pscratch_path}/8.3.2:/alsdata \ --volume {pscratch_path}/8.3.2:/alsuser/ \ @@ -196,6 +200,200 @@ def reconstruct( # Unknown error: cannot recover return False + def reconstruct_multinode( + self, + file_path: str = "", + num_nodes: int = 2, + ) -> bool: + + """ + Use NERSC for tomography reconstruction + + :param file_path: Path to the file to reconstruct + :param num_nodes: Number of nodes to use for parallel reconstruction + """ + logger.info("Starting NERSC reconstruction process.") + + user = self.client.user() + + raw_path = self.config.nersc832_alsdev_raw.root_path + logger.info(f"{raw_path=}") + + recon_image = self.config.ghcr_images832["recon_image"] + logger.info(f"{recon_image=}") + + recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path + logger.info(f"{recon_scripts_dir=}") + + scratch_path = self.config.nersc832_alsdev_scratch.root_path + logger.info(f"{scratch_path=}") + + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + logger.info(f"{pscratch_path=}") + + path = Path(file_path) + folder_name = path.parent.name + if not folder_name: + folder_name = "" + + file_name = f"{path.stem}.h5" + + logger.info(f"File name: {file_name}") + logger.info(f"Folder name: {folder_name}") + logger.info(f"Number of nodes: {num_nodes}") + + # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately + job_script = f"""#!/bin/bash +#SBATCH -q realtime +#SBATCH -A als +#SBATCH -C cpu +#SBATCH --job-name=tomo_recon_{folder_name}_{file_name} +#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err +#SBATCH -N {num_nodes} +#SBATCH --ntasks={num_nodes} +#SBATCH --cpus-per-task=128 +#SBATCH --time=0:15:00 
+#SBATCH --exclusive + +date +echo "Running reconstruction with {num_nodes} nodes" + +echo "Pre-pulling container image..." +podman-hpc pull {recon_image} + +echo "Creating directory {pscratch_path}/8.3.2/raw/{folder_name}" +mkdir -p {pscratch_path}/8.3.2/raw/{folder_name} +mkdir -p {pscratch_path}/8.3.2/scratch/{folder_name} + +echo "Copying file {raw_path}/{folder_name}/{file_name} to {pscratch_path}/8.3.2/raw/{folder_name}/" +cp {raw_path}/{folder_name}/{file_name} {pscratch_path}/8.3.2/raw/{folder_name} +if [ $? -ne 0 ]; then + echo "Failed to copy data to pscratch." + exit 1 +fi + +chmod 2775 {pscratch_path}/8.3.2/raw/{folder_name} +chmod 2775 {pscratch_path}/8.3.2/scratch/{folder_name} +chmod 664 {pscratch_path}/8.3.2/raw/{folder_name}/{file_name} + +echo "Verifying copied files..." +ls -l {pscratch_path}/8.3.2/raw/{folder_name}/ + +NNODES={num_nodes} +RAW_FILE="{pscratch_path}/8.3.2/raw/{folder_name}/{file_name}" + +# Get the number of slices from the HDF5 file using the container +echo "Reading slice count from HDF5 file..." + +NUM_SLICES=$(podman-hpc run --rm \\ + --volume {pscratch_path}/8.3.2:/alsdata \\ + {recon_image} \\ + python -c " +import h5py +with h5py.File('/alsdata/raw/{folder_name}/{file_name}', 'r') as f: + if '/exchange/data' in f: + print(f['/exchange/data'].shape[1]) + else: + for key in f.keys(): + grp = f[key] + if 'nslices' in grp.attrs: + print(int(grp.attrs['nslices'])) + break +" 2>&1 | grep -E '^[0-9]+$' | head -1) + +echo "Detected NUM_SLICES: $NUM_SLICES" + +if [ -z "$NUM_SLICES" ]; then + echo "Failed to read number of slices from HDF5 file" + exit 1 +fi + +if ! [[ "$NUM_SLICES" =~ ^[0-9]+$ ]]; then + echo "Failed to read number of slices. 
Got: $NUM_SLICES" + exit 1 +fi + +echo "Total slices: $NUM_SLICES" +echo "Distributing across $NNODES nodes" + +SLICES_PER_NODE=$((NUM_SLICES / NNODES)) +echo "Slices per node: ~$SLICES_PER_NODE" + +# Launch reconstruction on each node +for i in $(seq 0 $((NNODES - 1))); do + SINO_START=$((i * SLICES_PER_NODE)) + + # Last node takes any remainder slices + if [ $i -eq $((NNODES - 1)) ]; then + SINO_END=$NUM_SLICES + else + SINO_END=$(((i + 1) * SLICES_PER_NODE)) + fi + + echo "Launching node $i: slices $SINO_START to $SINO_END" + + srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \ + --env NUMEXPR_MAX_THREADS=128 \ + --env NUMEXPR_NUM_THREADS=128 \ + --env OMP_NUM_THREADS=128 \ + --env MKL_NUM_THREADS=128 \ + --volume {recon_scripts_dir}/sfapi_reconstruction_multinode.py:/alsuser/sfapi_reconstruction_multinode.py \ + --volume {pscratch_path}/8.3.2/raw/{folder_name}:/alsuser/{folder_name} \ + --volume {pscratch_path}/8.3.2/scratch:/scratch \ + {recon_image} \ + bash -c "cd /alsuser && python sfapi_reconstruction_multinode.py {file_name} {folder_name} $SINO_START $SINO_END" & +done + +echo "Waiting for all $NNODES nodes to complete..." +wait +WAIT_STATUS=$? 
+ +if [ $WAIT_STATUS -ne 0 ]; then + echo "One or more reconstruction tasks failed" + exit 1 +fi + +echo "All nodes completed successfully" +date +""" + try: + logger.info("Submitting reconstruction job script to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() # Wait until the job completes + logger.info("Reconstruction job completed successfully.") + return True + + except Exception as e: + logger.info(f"Error during job submission or completion: {e}") + match = re.search(r"Job not found:\s*(\d+)", str(e)) + + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.perlmutter.job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("Reconstruction job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + return False + else: + return False + def build_multi_resolution( self, file_path: str = "", @@ -433,6 +631,7 @@ def schedule_pruning( @flow(name="nersc_recon_flow", flow_run_name="nersc_recon-{file_path}") def nersc_recon_flow( file_path: str, + num_nodes: int = 1, config: Optional[Config832] = None, ) -> bool: """ @@ -453,68 +652,80 @@ def nersc_recon_flow( ) logger.info("NERSC reconstruction controller initialized") - nersc_reconstruction_success = controller.reconstruct( - file_path=file_path, - ) + if num_nodes > 1: + nersc_reconstruction_success = controller.reconstruct_multinode( + file_path=file_path, + num_nodes=num_nodes + ) + elif num_nodes == 1: + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + else: + raise 
ValueError("num_nodes must be at least 1") + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") - nersc_multi_res_success = controller.build_multi_resolution( - file_path=file_path, - ) - logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - path = Path(file_path) - folder_name = path.parent.name - file_name = path.stem + # Commented out for testing purposes -- should be re-enabled for production - tiff_file_path = f"{folder_name}/rec{file_name}" - zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + # nersc_multi_res_success = controller.build_multi_resolution( + # file_path=file_path, + # ) + # logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - logger.info(f"{tiff_file_path=}") - logger.info(f"{zarr_file_path=}") + # path = Path(file_path) + # folder_name = path.parent.name + # file_name = path.stem - # Transfer reconstructed data - logger.info("Preparing transfer.") - transfer_controller = get_transfer_controller( - transfer_type=CopyMethod.GLOBUS, - config=config - ) + # tiff_file_path = f"{folder_name}/rec{file_name}" + # zarr_file_path = f"{folder_name}/rec{file_name}.zarr" - logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") - transfer_controller.copy( - file_path=tiff_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch - ) + # logger.info(f"{tiff_file_path=}") + # logger.info(f"{zarr_file_path=}") - transfer_controller.copy( - file_path=zarr_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch - ) - - logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") - transfer_controller.copy( - file_path=tiff_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - - transfer_controller.copy( - file_path=zarr_file_path, - 
source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - - logger.info("Scheduling pruning tasks.") - schedule_pruning( - config=config, - raw_file_path=file_path, - tiff_file_path=tiff_file_path, - zarr_file_path=zarr_file_path - ) + # Transfer reconstructed data + # logger.info("Preparing transfer.") + # transfer_controller = get_transfer_controller( + # transfer_type=CopyMethod.GLOBUS, + # config=config + # ) + + # logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") + # transfer_controller.copy( + # file_path=tiff_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.nersc832_alsdev_scratch + # ) + + # transfer_controller.copy( + # file_path=zarr_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.nersc832_alsdev_scratch + # ) + + # logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") + # transfer_controller.copy( + # file_path=tiff_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + + # transfer_controller.copy( + # file_path=zarr_file_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + + # logger.info("Scheduling pruning tasks.") + # schedule_pruning( + # config=config, + # raw_file_path=file_path, + # tiff_file_path=tiff_file_path, + # zarr_file_path=zarr_file_path + # ) # TODO: Ingest into SciCat - if nersc_reconstruction_success and nersc_multi_res_success: + if nersc_reconstruction_success: return True else: return False @@ -549,10 +760,66 @@ def nersc_streaming_flow( if __name__ == "__main__": config = Config832() + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + num_nodes=8, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 8 nodes: {end - start} 
seconds") + print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + + start = time.time() nersc_recon_flow( file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + num_nodes=8, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") + print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20251218_111600_silkraw.h5", + num_nodes=8, config=config ) + end = time.time() + logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") + print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=4, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=2, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=1, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + # print(f"Total reconstruction time with 1 node: {end - start} seconds") # nersc_streaming_flow( # config=config, # walltime=datetime.timedelta(minutes=5) diff --git a/scripts/perlmutter/sfapi_reconstruction.py b/scripts/perlmutter/sfapi_reconstruction.py new file mode 100644 index 
00000000..d6a72081 --- /dev/null +++ b/scripts/perlmutter/sfapi_reconstruction.py @@ -0,0 +1,1755 @@ +from __future__ import print_function +import time +import h5py +import numpy as np +import numexpr as ne +import skimage.transform as st +import os +import sys +import scipy.ndimage.filters as snf +import concurrent.futures as cf +import warnings +import stat +from pathlib import Path + +#import xlrd # for importing excel spreadsheets +#from ast import literal_eval # For converting string to tuple + +try: + import tomopy + from tomopy.util import mproc +except: + print("warning: tomopy is not available") + +try: + import dxchange +except: + print("warning: dxchange is not available") + +# run this from the command line: +# python tomopy832.py +# it requires a separate file, which contains at minimum a list of filenames +# on separate lines. Default name of this file is input832.txt, but you can use any +# filename and run from the commandline as +# python tomopy832.py yourinputfile.txt +# If desired, on each line (separated by spaces) you can +# include parameters to override the defaults. +# to do this you need pairs, first the name of the variable, then the desired value +# For True/False, use 1/0. +# You can generate these input files in excel, in which case use tab-separated +# (or space separated). Some input overrides require multiple values, +# these should be comma-separated (with no spaces). Example is sinoused +# which would be e.g. 500,510,1 to get slices 500 through 509. For sinoused, +# you can use first value -1 and second value number of slices to get that number +# of slices from the middle of the stack. 
+# an example of the contents of the input file look like this: + +# filename.h5 cor 1196 sinoused "-1,10,1" doPhaseRetrieval 0 outputFilename c1196.0 +# filename.h5 cor 1196.5 sinoused "-1,10,1" doPhaseRetrieval 0 outputFilename c1196.5 + +# this was generated in excel and saved as txt tab separated, so the quotes were +# added automatically by excel. Note also that for parameters expecting strings as +# input (outputFilename for example), the program will choke if you put in a number. + +# if cor is not defined in the parameters file, automated cor detection will happen + +# chunk_proj and chunk_sino handle memory management. +# If you are running out of memory, make one or both of those smaller. + +slice_dir = { + 'write_raw': 'proj', + 'remove_outlier1d': 'sino', + 'remove_outlier2d': 'proj', + 'normalize_nf': 'sino', + 'normalize': 'both', + 'minus_log': 'both', + 'beam_hardening': 'both', + 'remove_stripe_fw': 'sino', + 'remove_stripe_ti': 'sino', + 'remove_stripe_sf': 'sino', + 'do_360_to_180': 'sino', + 'correcttilt': 'proj', + 'lensdistortion': 'proj', + 'phase_retrieval': 'proj', + 'recon_mask': 'sino', + 'polar_ring': 'sino', + 'polar_ring2': 'sino', + 'castTo8bit': 'both', + 'write_reconstruction': 'both', + 'write_normalized': 'proj', +} + + +def recon_setup( + filename, + filetype = 'dxfile', #other options are als, als1131, sls + timepoint = 0, + bffilename = None, + inputPath = './', # input path, location of the data set to reconstruct + outputPath=None, + # define an output path (default is inputPath), a sub-folder will be created based on file name + outputFilename=None, + # file name for output tif files (a number and .tiff will be added). 
default is based on input filename + fulloutputPath=None, # definte the full output path, no automatic sub-folder will be created + doOutliers1D=False, # outlier removal in 1d (along sinogram columns) + outlier_diff1D=750, # difference between good data and outlier data (outlier removal) + outlier_size1D=3, # radius around each pixel to look for outliers (outlier removal) + doOutliers2D=False, # outlier removal, standard 2d on each projection + outlier_diff2D=750, # difference between good data and outlier data (outlier removal) + outlier_size2D=3, # radius around each pixel to look for outliers (outlier removal) + doFWringremoval=True, # Fourier-wavelet ring removal + doTIringremoval=False, # Titarenko ring removal + doSFringremoval=False, # Smoothing filter ring removal + ringSigma=3, # damping parameter in Fourier space (Fourier-wavelet ring removal) + ringLevel=8, # number of wavelet transform levels (Fourier-wavelet ring removal) + ringWavelet='db5', # type of wavelet filter (Fourier-wavelet ring removal) + ringNBlock=0, # used in Titarenko ring removal (doTIringremoval) + ringAlpha=1.5, # used in Titarenko ring removal (doTIringremoval) + ringSize=5, # used in smoothing filter ring removal (doSFringremoval) + doPhaseRetrieval=False, # phase retrieval + alphaReg=0.00001, # smaller = smoother (used for phase retrieval) + propagation_dist=75.0, # sample-to-scintillator distance (phase retrieval) + kev=24.0, # energy level (phase retrieval) + butterworth_cutoff=0.25, # 0.1 would be very smooth, 0.4 would be very grainy (reconstruction) + butterworth_order=2, # for reconstruction + doTranslationCorrection=False, # correct for linear drift during scan + xshift=0, # undesired dx transation correction (from 0 degree to 180 degree proj) + yshift=0, # undesired dy transation correction (from 0 degree to 180 degree proj) + doPolarRing=False, # ring removal + Rarc=30, # min angle needed to be considered ring artifact (ring removal) + Rmaxwidth=100, # max width of rings 
to be filtered (ring removal) + Rtmax=3000.0, # max portion of image to filter (ring removal) + Rthr=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin=-3000.0, # min value of image to filter (ring removal) + doPolarRing2=False, # ring removal + Rarc2=30, # min angle needed to be considered ring artifact (ring removal) + Rmaxwidth2=100, # max width of rings to be filtered (ring removal) + Rtmax2=3000.0, # max portion of image to filter (ring removal) + Rthr2=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin2=-3000.0, # min value of image to filter (ring removal) + cor=None, # center of rotation (float). If not used then cor will be detected automatically + corFunction='pc', # center of rotation function to use - can be 'pc', 'vo', or 'nm', or use 'skip' to return tomo variable without having to do a calc. + corLoadMinimalBakDrk=True, #during cor detection, only load the first dark field and first flat field rather than all of them, to minimize file loading time for cor detection. + voInd=None, # index of slice to use for cor search (vo) + voSMin=-150, # min radius for searching in sinogram (vo) + voSMax=150, # max radius for searching in sinogram (vo) + voSRad=6, # search radius (vo) + voStep=0.25, # search step (vo) + voRatio=0.5, # ratio of field-of-view and object size (vo) + voDrop=20, # drop lines around vertical center of mask (vo) + nmInd=None, # index of slice to use for cor search (nm) + nmInit=None, # initial guess for center (nm) + nmTol=0.5, # desired sub-pixel accuracy (nm) + nmMask=True, # if True, limits analysis to circular region (nm) + nmRatio=1.0, # ratio of radius of circular mask to edge of reconstructed image (nm) + nmSinoOrder=False, # if True, analyzes in sinogram space. 
If False, analyzes in radiograph space + use360to180=False, # use 360 to 180 conversion + castTo8bit=False, # convert data to 8bit before writing + cast8bit_min=-10, # min value if converting to 8bit + cast8bit_max=30, # max value if converting to 8bit + useNormalize_nf=False, # normalize based on background intensity (nf) + chunk_proj=100, # chunk size in projection direction + chunk_sino=100, # chunk size in sinogram direction + npad=None, # amount to pad data before reconstruction + projused=None, + # should be slicing in projection dimension (start,end,step) Be sure to add one to the end as stop in python means the last value is omitted + sinoused=None, + # should be sliceing in sinogram dimension (start,end,step). If first value is negative, it takes the number of slices from the second value in the middle of the stack. + correcttilt=0, # tilt dataset + tiltcenter_slice=None, # tilt center (x direction) + tiltcenter_det=None, # tilt center (y direction) + angle_offset=0, + # this is the angle offset from our default (270) so that tomopy yields output in the same orientation as previous software (Octopus) + anglelist=None, + # if not set, will assume evenly spaced angles which will be calculated by the angular range and number of angles found in the file. if set to -1, will read individual angles from each image. alternatively, a list of angles can be passed. + doBeamHardening=False, + # turn on beam hardening correction, based on "Correction for beam hardening in computed tomography", Gabor Herman, 1979 Phys. Med. Biol. 24 81 + BeamHardeningCoefficients=None, # 6 values, tomo = a0 + a1*tomo + a2*tomo^2 + a3*tomo^3 + a4*tomo^4 + a5*tomo^5 + projIgnoreList=None, + # projections to be ignored in the reconstruction (for simplicity in the code, they will not be removed and will be processed as all other projections but will be set to zero absorption right before reconstruction. 
+ bfexposureratio=1, # ratio of exposure time of bf to exposure time of sample + dorecon=True, #do the tomographic reconstruction + writeraw = False, + writenormalized=False, + writereconstruction=True, + doNormalize=True, + dominuslog=True, + slsnumangles=1000, + slspxsize=0.00081, + verbose_printing=False, + recon_algorithm='gridrec', # choose from gridrec, fbp, and others in tomopy + dolensdistortion=False, + lensdistortioncenter=(1280,1080), + lensdistortionfactors = (1.00015076, 1.9289e-06, -2.4325e-08, 1.00439e-11, -3.99352e-15), + minimum_transmission = 0.01, + *args, **kwargs + ): + + + outputFilename = os.path.splitext(filename)[0] if outputFilename is None else outputFilename + # outputPath = inputPath + 'rec' + os.path.splitext(filename)[0] + '/' if outputPath is None else outputPath + 'rec' + os.path.splitext(filename)[0] + '/' + outputPath = os.path.join(inputPath, 'rec' + outputFilename) if outputPath is None else os.path.join(outputPath,'rec' + outputFilename) + fulloutputPath = outputPath if fulloutputPath is None else fulloutputPath + tempfilenames = [os.path.join(fulloutputPath,'tmp0.h5'), os.path.join(fulloutputPath, 'tmp1.h5')] + + if verbose_printing: + print("cleaning up previous temp files", end="") + for tmpfile in tempfilenames: + try: + os.remove(tmpfile) + except OSError: + pass + if verbose_printing: + print(", reading metadata") + + if (filetype == 'als') or (filetype == 'als1131'): + datafile = h5py.File(os.path.join(inputPath,filename), 'r') + gdata = dict(dxchange.reader._find_dataset_group(datafile).attrs) + pxsize = float(gdata['pxsize']) / 10 # /10 to convert units from mm to cm + numslices = int(gdata['nslices']) + numangles = int(gdata['nangles']) + angularrange = float(gdata['arange']) + numrays = int(gdata['nrays']) + inter_bright = int(gdata['i0cycle']) + + + + dgroup = dxchange.reader._find_dataset_group(datafile) + keys = list(gdata.keys()) + if 'num_dark_fields' in keys: + ndark = int(gdata['num_dark_fields']) + else: + 
ndark = dxchange.reader._count_proj(dgroup, dgroup.name.split('/')[-1] + 'drk_0000.tif', numangles, inter_bright=-1) #for darks, don't want to divide out inter_bright for counting projections + ind_dark = list(range(0, ndark)) + group_dark = [numangles - 1] + + if 'num_bright_field' in keys: + nflat = int(gdata['num_bright_field']) + else: + nflat = dxchange.reader._count_proj(dgroup, dgroup.name.split('/')[-1] + 'bak_0000.tif', numangles, inter_bright=inter_bright) + ind_flat = list(range(0, nflat)) + + # figure out the angle list (a list of angles, one per projection image) + dtemp = datafile[list(datafile.keys())[0]] + fltemp = list(dtemp.keys()) + firstangle = float(dtemp[fltemp[0]].attrs.get('rot_angle', 0)) + if anglelist is None: + # the offset angle should offset from the angle of the first image, which is usually 0, but in the case of timbir data may not be. + # we add the 270 to be inte same orientation as previous software used at bl832 + angle_offset = 270 + angle_offset - firstangle + anglelist = tomopy.angles(numangles, angle_offset, angle_offset - angularrange) + elif anglelist == -1: + anglelist = np.zeros(shape=numangles) + for icount in range(0, numangles): + anglelist[icount] = np.pi / 180 * (270 + angle_offset - float(dtemp[fltemp[icount]].attrs['rot_angle'])) + if inter_bright > 0: + group_flat = list(range(0, numangles, inter_bright)) + if group_flat[-1] != numangles - 1: + group_flat.append(numangles - 1) + elif inter_bright == 0: + group_flat = [0, numangles - 1] + else: + group_flat = None + elif filetype == 'dxfile': + numangles = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/rotation/num_angles")[0]) + angularrange = dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/rotation/range")[0] + anglelist = dxchange.read_hdf5(os.path.join(inputPath, filename), '/exchange/theta', slc=None) + if anglelist is None: + try: + # See if the rotation start, step, num_angles are in the file + 
rotation_start = dxchange.read_hdf5(os.path.join(inputPath, filename), + '/process/acquisition/rotation/rotation_start')[0] + rotation_step = dxchange.read_hdf5(os.path.join(inputPath, filename), + '/process/acquisition/rotation/rotation_step')[0] + anglelist = rotation_start + rotation_step * range(numangles) + except: + anglelist = np.linspace(0. - angle_offset, angularrange, numangles) + anglelist = anglelist - angle_offset + anglelist = np.deg2rad(anglelist) + anglelist = -anglelist + numslices = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/detector/dimension_y")[0]) + numrays = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/detector/dimension_x")[0]) + pxsize = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/detector/pixel_size")[0] / 10.0 # /10 to convert units from mm to cm + inter_bright = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/flat_fields/i0cycle")[0]) + group_flat = [0, numangles - 1] + nflat = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/flat_fields/num_flat_fields")[0]) + ind_flat = list(range(0, nflat)) + ndark = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/dark_fields/num_dark_fields")[0]) + ind_dark = list(range(0, ndark)) + propagation_dist = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/camera_motor_stack/setup/camera_distance")[0] + if (propagation_dist == 0): + propagation_dist = dxchange.read_hdf5(os.path.join(inputPath, filename), + "/measurement/instrument/camera_motor_stack/setup/camera_distance")[1] + kev = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/monochromator/energy")[0] / 1000 + if (kev == 0): + kev = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/monochromator/energy")[ + 1] / 1000 + + if (isinstance(kev, int) or 
isinstance(kev, float)): + if kev > 1000: + kev = 30.0 + else: + kev = 30.0 + elif filetype == 'sls': + datafile = h5py.File(os.path.join(inputPath, filename), 'r') + slsdata = datafile["exchange/data"] + numslices = slsdata.shape[1] + numrays = slsdata.shape[2] + pxsize = slspxsize + numangles = slsnumangles + _, _, _, anglelist = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles,(timepoint+1)*numangles,1), sino=(0,1,1)) #dtype=None, , ) + angularrange = np.abs(anglelist[-1]-anglelist[0]) + inter_bright = 0 + group_flat = [0, numangles - 1] + nflat = 1 #this variable is not used for sls data + ind_flat = list(range(0, nflat)) + else: + print("Not sure what file type, gotta break.") + return + + npad = int(np.ceil(numrays * np.sqrt(2)) - numrays) // 2 if npad is None else npad + if projused is not None and (projused[1] > numangles - 1 or projused[0] < 0): # allows program to deal with out of range projection values + if projused[1] > numangles: + print("End Projection value greater than number of angles. Value has been lowered to the number of angles " + str(numangles)) + projused = (projused[0], numangles, projused[2]) + if projused[0] < 0: + print("Start Projection value less than zero. Value raised to 0") + projused = (0, projused[1], projused[2]) + if projused is None: + projused = (0, numangles, 1) + else: + # if projused is different than default, need to change numangles and angularrange; dula attempting to do this with these two lines, we'll see if it works! 11/16/17 + angularrange = (angularrange / (numangles - 1)) * (projused[1] - projused[0]) + #dula updated to use anglelist to find angular rage, 11 june 2020, not sure if it will work?? 
+ angularrange = np.abs(anglelist[projused[1]] - anglelist[projused[0]]) + # want angular range to stay constant if we keep the end values consistent + numangles = len(range(projused[0], projused[1], projused[2])) + + ind_tomo = list(range(0, numangles)) + floc_independent = dxchange.reader._map_loc(ind_tomo, group_flat) + + # figure out how user can pass to do central x number of slices, or set of slices dispersed throughout (without knowing a priori the value of numslices) + if sinoused is None: + sinoused = (0, numslices, 1) + elif sinoused[0] < 0: + sinoused = (int(np.floor(numslices / 2.0) - np.ceil(sinoused[1] / 2.0)), int(np.floor(numslices / 2.0) + np.floor(sinoused[1] / 2.0)), 1) + + if verbose_printing: + print('There are ' + str(numslices) + ' sinograms, ' + str(numrays) + ' rays, and ' + str(numangles) + ' projections, with an angular range of ' +str(angularrange) + '.') + print('Looking at sinograms ' + str(sinoused[0]) + ' through ' + str(sinoused[1]-1) + ' (inclusive) in steps of ' + str(sinoused[2])) + + BeamHardeningCoefficients = (0, 1, 0, 0, 0, .004) if BeamHardeningCoefficients is None else BeamHardeningCoefficients + + if cor is None: + if verbose_printing: + print("Detecting center of rotation", end="") + + if angularrange > 300: + lastcor = int(np.floor(numangles / 2) - 1) + else: + lastcor = numangles - 1 + # I don't want to see the warnings about the reader using a deprecated variable in dxchange + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als') or (filetype == 'als1131'): + if corLoadMinimalBakDrk: + ind_dark = 0 + ind_flat = 0 + if filetype == 'als': + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath, filename), ind_tomo=(0, lastcor),ind_dark=ind_dark,ind_flat=ind_flat) + else: + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath, filename), + ind_tomo=(0, lastcor), ind_dark=ind_dark, + ind_flat=ind_flat) + elif filetype == 'dxfile': + # if 
corLoadMinimalBakDrk: + # ind_dark = 0 + # ind_flat = 0 + # tomo, flat, dark, coranglelist, _ = dxchange.exchange.read_dx(os.path.join(inputPath, filename), proj=(0,numangles-1),ind_dark=ind_dark,ind_flat=ind_flat) + # tomo, flat, dark, coranglelist, _ = dxchange.exchange.read_dx(os.path.join(inputPath, filename), proj=(0,lastcor,lastcor-1)) + tomo, flat, dark, coranglelist = dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), proj=(0,lastcor,lastcor-1)) + elif (filetype == 'sls'): + tomo, flat, dark, coranglelist = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=( + timepoint * numangles, (timepoint + 1) * numangles, numangles - 1)) # dtype=None, , ) + else: + return + if bffilename is not None and (filetype == 'als'): + tomobf, flatbf, darkbf, flocbf = dxchange.read_als_832h5(os.path.join(inputPath, bffilename)) + flat = tomobf + tomo = tomo.astype(np.float32) + if useNormalize_nf and ((filetype == 'als') or (filetype == 'als1131')): + tomopy.normalize_nf(tomo, flat, dark, floc, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + else: + tomopy.normalize(tomo, flat, dark, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + + if corFunction == 'vo': + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als'): + tomovo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath, filename), sino=(sinoused[0],sinoused[0]+1,1)) + elif (filetype == 'sls'): + tomovo, flat, dark, coranglelist = read_sls(os.path.join(inputPath, filename), exchange_rank=0, sino=(sinoused[0],sinoused[0]+1,1), proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2])) # dtype=None, , ) + else: + return + if bffilename is not None: + if (filetype == 'als'): + tomobf, flatbf, darkbf, flocbf = dxchange.read_als_832h5(os.path.join(inputPath, bffilename), sino=(sinoused[0],sinoused[0]+1,1)) + flat = tomobf + elif (filetype == 'als1131'): + tomobf, 
flatbf, darkbf, flocbf = read_als_1131h5(os.path.join(inputPath, bffilename), + sino=(sinoused[0], sinoused[0] + 1, 1)) + flat = tomobf + tomovo = tomovo.astype(np.float32) + + if useNormalize_nf and ((filetype == 'als') or (filetype == 'als1131')): + tomopy.normalize_nf(tomovo, flat, dark, floc, out=tomovo) + if bfexposureratio != 1: + tomovo = tomovo * bfexposureratio + else: + tomopy.normalize(tomovo, flat, dark, out=tomovo) + if bfexposureratio != 1: + tomovo = tomovo * bfexposureratio + + cor = tomopy.find_center_vo(tomovo, ind=voInd, smin=voSMin, smax=voSMax, srad=voSRad, step=voStep, + ratio=voRatio, drop=voDrop) + + + elif corFunction == 'nm': + cor = tomopy.find_center(tomo, tomopy.angles(numangles, angle_offset, angle_offset - angularrange), + ind=nmInd, init=nmInit, tol=nmTol, mask=nmMask, ratio=nmRatio, + sinogram_order=nmSinoOrder) + elif corFunction == 'pc': + if angularrange > 300: + lastcor = int(np.floor(numangles / 2) - 1) + else: + lastcor = numangles - 1 + # I don't want to see the warnings about the reader using a deprecated variable in dxchange + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als'): + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath, filename), ind_tomo=(0, lastcor)) + elif (filetype == 'als1131'): + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath, filename), ind_tomo=(0, lastcor)) + elif (filetype == 'dxfile'): + tomo, flat, dark, coranglelist = dxchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, proj=( + 0, lastcor, lastcor-1)) # dtype=None, , ) + elif (filetype == 'sls'): + tomo, flat, dark, coranglelist = read_sls(os.path.join(inputPath, filename), exchange_rank=0, proj=( + timepoint * numangles, (timepoint + 1) * numangles, numangles - 1)) # dtype=None, , ) + else: + return + if bffilename is not None: + if (filetype == 'als'): + tomobf, flatbf, darkbf, flocbf = dxchange.read_als_832h5(os.path.join(inputPath, 
bffilename)) + flat = tomobf + elif (filetype == 'als1131'): + tomobf, flatbf, darkbf, flocbf = read_als_1131h5(os.path.join(inputPath, bffilename)) + flat = tomobf + tomo = tomo.astype(np.float32) + if useNormalize_nf and ((filetype == 'als') or (filetype == 'als1131')): + tomopy.normalize_nf(tomo, flat, dark, floc, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + else: + tomopy.normalize(tomo, flat, dark, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + cor = tomopy.find_center_pc(tomo[0], tomo[-1], tol=0.25) + elif corFunction == 'skip': #use this to get back the tomo variable without running processing + cor = numrays/2 + else: + raise ValueError("\'corFunction\' must be one of: [ pc, vo, nm ].") + if verbose_printing: + print(", {}".format(cor)) + else: + tomo = 0 + if verbose_printing: + print("using user input center of {}".format(cor)) + + function_list = [] + + if writeraw: + function_list.append('write_raw') + if doOutliers1D: + function_list.append('remove_outlier1d') + if doOutliers2D: + function_list.append('remove_outlier2d') + if doNormalize: + if useNormalize_nf: + function_list.append('normalize_nf') + else: + function_list.append('normalize') + if dominuslog: + function_list.append('minus_log') + if doBeamHardening: + function_list.append('beam_hardening') + if doFWringremoval: + function_list.append('remove_stripe_fw') + if doTIringremoval: + function_list.append('remove_stripe_ti') + if doSFringremoval: + function_list.append('remove_stripe_sf') + if correcttilt: + function_list.append('correcttilt') + if dolensdistortion: + function_list.append('lensdistortion') + if use360to180: + function_list.append('do_360_to_180') + if doPhaseRetrieval: + function_list.append('phase_retrieval') + if dorecon: + function_list.append('recon_mask') + if doPolarRing: + if dorecon: + function_list.append('polar_ring') + if doPolarRing2: + if dorecon: + function_list.append('polar_ring2') + if castTo8bit: + if dorecon: 
+ function_list.append('castTo8bit') + if writereconstruction: + if dorecon: + function_list.append('write_reconstruction') + if writenormalized: + function_list.append('write_normalized') + + recon_dict = { + "inputPath": inputPath, #input file path + "filename": filename, #input file name + "filetype": filetype, #other options are als, als1131, sls + "timepoint": timepoint, + "fulloutputPath": fulloutputPath, + "outputFilename": outputFilename, + "bffilename": bffilename, #if there is a separate file with the bright fields + "doOutliers1D": doOutliers1D, # outlier removal in 1d (along sinogram columns) + "outlier_diff1D": outlier_diff1D, # difference between good data and outlier data (outlier removal) + "outlier_size1D": outlier_size1D, # radius around each pixel to look for outliers (outlier removal) + "doOutliers2D": doOutliers2D, # outlier removal, standard 2d on each projection + "outlier_diff2D": outlier_diff2D, # difference between good data and outlier data (outlier removal) + "outlier_size2D": outlier_size2D, # radius around each pixel to look for outliers (outlier removal) + "doFWringremoval": doFWringremoval, # Fourier-wavelet ring removal + "doTIringremoval": doTIringremoval, # Titarenko ring removal + "doSFringremoval": doSFringremoval, # Smoothing filter ring removal + "ringSigma": ringSigma, # damping parameter in Fourier space (Fourier-wavelet ring removal) + "ringLevel": ringLevel, # number of wavelet transform levels (Fourier-wavelet ring removal) + "ringWavelet": ringWavelet, # type of wavelet filter (Fourier-wavelet ring removal) + "ringNBlock": ringNBlock, # used in Titarenko ring removal (doTIringremoval) + "ringAlpha": ringAlpha, # used in Titarenko ring removal (doTIringremoval) + "ringSize": ringSize, # used in smoothing filter ring removal (doSFringremoval) + "doPhaseRetrieval": doPhaseRetrieval, # phase retrieval + "alphaReg": alphaReg, # smaller = smoother (used for phase retrieval) + "propagation_dist": propagation_dist, # 
sample-to-scintillator distance (phase retrieval) + "kev": kev, # energy level (phase retrieval) + "butterworth_cutoff": butterworth_cutoff, # 0.1 would be very smooth, 0.4 would be very grainy (reconstruction) + "butterworth_order": butterworth_order, # for reconstruction + "doTranslationCorrection": doTranslationCorrection, # correct for linear drift during scan + "xshift": xshift, # undesired dx transation correction (from 0 degree to 180 degree proj) + "yshift": yshift, # undesired dy transation correction (from 0 degree to 180 degree proj) + "doPolarRing": doPolarRing, # ring removal + "Rarc": Rarc, # min angle needed to be considered ring artifact (ring removal) + "Rmaxwidth": Rmaxwidth, # max width of rings to be filtered (ring removal) + "Rtmax": Rtmax, # max portion of image to filter (ring removal) + "Rthr": Rthr, # max value of offset due to ring artifact (ring removal) + "Rtmin": Rtmin, # min value of image to filter (ring removal) + "doPolarRing2": doPolarRing2, # ring removal + "Rarc2": Rarc2, # min angle needed to be considered ring artifact (ring removal) + "Rmaxwidth2": Rmaxwidth2, # max width of rings to be filtered (ring removal) + "Rtmax2": Rtmax2, # max portion of image to filter (ring removal) + "Rthr2": Rthr2, # max value of offset due to ring artifact (ring removal) + "Rtmin2": Rtmin2, # min value of image to filter (ring removal) + "cor": cor, # center of rotation (float). 
If not used then cor will be detected automatically + "corFunction": corFunction, # center of rotation function to use - can be 'pc', 'vo', or 'nm' + "voInd": voInd, # index of slice to use for cor search (vo) + "voSMin": voSMin, # min radius for searching in sinogram (vo) + "voSMax": voSMax, # max radius for searching in sinogram (vo) + "voSRad": voSRad, # search radius (vo) + "voStep": voStep, # search step (vo) + "voRatio": voRatio, # ratio of field-of-view and object size (vo) + "voDrop": voDrop, # drop lines around vertical center of mask (vo) + "nmInd": nmInd, # index of slice to use for cor search (nm) + "nmInit": nmInit, # initial guess for center (nm) + "nmTol": nmTol, # desired sub-pixel accuracy (nm) + "nmMask": nmMask, # if True, limits analysis to circular region (nm) + "nmRatio": nmRatio, # ratio of radius of circular mask to edge of reconstructed image (nm) + "nmSinoOrder": nmSinoOrder, # if True, analyzes in sinogram space. If False, analyzes in radiograph space + "use360to180": use360to180, # use 360 to 180 conversion + "castTo8bit": castTo8bit, # convert data to 8bit before writing + "cast8bit_min": cast8bit_min, # min value if converting to 8bit + "cast8bit_max": cast8bit_max, # max value if converting to 8bit + "useNormalize_nf": useNormalize_nf, # normalize based on background intensity (nf) + "chunk_proj": chunk_proj, # chunk size in projection direction + "chunk_sino": chunk_sino, # chunk size in sinogram direction + "npad": npad, # amount to pad data before reconstruction + "projused": projused, # should be slicing in projection dimension (start,end,step) Be sure to add one to the end as stop in python means the last value is omitted + "sinoused": sinoused, # should be sliceing in sinogram dimension (start,end,step). If first value is negative, it takes the number of slices from the second value in the middle of the stack. 
+ "correcttilt": correcttilt, # tilt dataset + "tiltcenter_slice": tiltcenter_slice, # tilt center (x direction) + "tiltcenter_det": tiltcenter_det, # tilt center (y direction) + "angle_offset": angle_offset, # this is the angle offset from our default (270) so that tomopy yields output in the same orientation as previous software (Octopus) + "anglelist": anglelist, # if not set, will assume evenly spaced angles which will be calculated by the angular range and number of angles found in the file. if set to -1, will read individual angles from each image. alternatively, a list of angles can be passed. + "doBeamHardening": doBeamHardening, # turn on beam hardening correction, based on "Correction for beam hardening in computed tomography", Gabor Herman, 1979 Phys. Med. Biol. 24 81 + "BeamHardeningCoefficients": BeamHardeningCoefficients, # 6 values, tomo = a0 + a1*tomo + a2*tomo^2 + a3*tomo^3 + a4*tomo^4 + a5*tomo^5 + "projIgnoreList": projIgnoreList, # projections to be ignored in the reconstruction (for simplicity in the code, they will not be removed and will be processed as all other projections but will be set to zero absorption right before reconstruction. 
+ "bfexposureratio": bfexposureratio, # ratio of exposure time of bf to exposure time of sample + "pxsize": pxsize, + "numslices": numslices, + "numangles": numangles, + "angularrange": angularrange, + "numrays": numrays, + "npad": npad, + "projused": projused, + "inter_bright": inter_bright, + "nflat": nflat, + "ind_flat": ind_flat, + "ndark": nflat, + "ind_dark": ind_flat, + "group_flat": group_flat, + "ind_tomo": ind_tomo, + "floc_independent": floc_independent, + "sinoused": sinoused, + "BeamHardeningCoefficients": BeamHardeningCoefficients, + "function_list": function_list, + "dorecon": dorecon, + "doNormalize": doNormalize, + "writeraw": writeraw, + "writenormalized": writenormalized, + "writereconstruction": writereconstruction, + "dominuslog": dominuslog, + "verbose_printing": verbose_printing, + "recon_algorithm": recon_algorithm, + "dolensdistortion": dolensdistortion, + "lensdistortioncenter": lensdistortioncenter, + "lensdistortionfactors": lensdistortionfactors, + "minimum_transmission": minimum_transmission, + } + + #return second variable tomo, (first and last normalized image), to use it for manual COR checking + return recon_dict, tomo + + + +# to profile memory, uncomment the following line +# and then run program from command line as +# python -m memory_profiler tomopy832.py +# (you have to have memory_profiler installed) +# @profile +def recon( + filename, + filetype = 'als', #other options are als, als1131, sls + timepoint = 0, + bffilename = None, + inputPath = './', #input path, location of the data set to reconstruct + outputFilename = None, #file name for output tif files (a number and .tiff will be added). 
default is based on input filename + fulloutputPath = None, # definte the full output path, no automatic sub-folder will be created + doOutliers1D = False, # outlier removal in 1d (along sinogram columns) + outlier_diff1D = 750, # difference between good data and outlier data (outlier removal) + outlier_size1D = 3, # radius around each pixel to look for outliers (outlier removal) + doOutliers2D = False, # outlier removal, standard 2d on each projection + outlier_diff2D = 750, # difference between good data and outlier data (outlier removal) + outlier_size2D = 3, # radius around each pixel to look for outliers (outlier removal) + doFWringremoval = True, # Fourier-wavelet ring removal + doTIringremoval = False, # Titarenko ring removal + doSFringremoval = False, # Smoothing filter ring removal + ringSigma = 3, # damping parameter in Fourier space (Fourier-wavelet ring removal) + ringLevel = 8, # number of wavelet transform levels (Fourier-wavelet ring removal) + ringWavelet = 'db5', # type of wavelet filter (Fourier-wavelet ring removal) + ringNBlock = 0, # used in Titarenko ring removal (doTIringremoval) + ringAlpha = 1.5, # used in Titarenko ring removal (doTIringremoval) + ringSize = 5, # used in smoothing filter ring removal (doSFringremoval) + doPhaseRetrieval = False, # phase retrieval + alphaReg = 0.00001, # smaller = smoother (used for phase retrieval) + propagation_dist = 75.0, # sample-to-scintillator distance (phase retrieval) + kev = 24.0, # energy level (phase retrieval) + butterworth_cutoff = 0.25, #0.1 would be very smooth, 0.4 would be very grainy (reconstruction) + butterworth_order = 2, # for reconstruction + doTranslationCorrection = False, # correct for linear drift during scan + xshift = 0, # undesired dx transation correction (from 0 degree to 180 degree proj) + yshift = 0, # undesired dy transation correction (from 0 degree to 180 degree proj) + doPolarRing = False, # ring removal + Rarc=30, # min angle needed to be considered ring artifact 
(ring removal) + Rmaxwidth=100, # max width of rings to be filtered (ring removal) + Rtmax=3000.0, # max portion of image to filter (ring removal) + Rthr=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin=-3000.0, # min value of image to filter (ring removal) + doPolarRing2 = False, # ring removal + Rarc2=30, # min angle needed to be considered ring artifact (ring removal) + Rmaxwidth2=100, # max width of rings to be filtered (ring removal) + Rtmax2=3000.0, # max portion of image to filter (ring removal) + Rthr2=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin2=-3000.0, # min value of image to filter (ring removal) + cor=None, # center of rotation (float). If not used then cor will be detected automatically + corFunction = 'pc', # center of rotation function to use - can be 'pc', 'vo', or 'nm' + voInd = None, # index of slice to use for cor search (vo) + voSMin = -40, # min radius for searching in sinogram (vo) + voSMax = 40, # max radius for searching in sinogram (vo) + voSRad = 10, # search radius (vo) + voStep = 0.5, # search step (vo) + voRatio = 2.0, # ratio of field-of-view and object size (vo) + voDrop = 20, # drop lines around vertical center of mask (vo) + nmInd = None, # index of slice to use for cor search (nm) + nmInit = None, # initial guess for center (nm) + nmTol = 0.5, # desired sub-pixel accuracy (nm) + nmMask = True, # if True, limits analysis to circular region (nm) + nmRatio = 1.0, # ratio of radius of circular mask to edge of reconstructed image (nm) + nmSinoOrder = False, # if True, analyzes in sinogram space. 
If False, analyzes in radiograph space + use360to180 = False, # use 360 to 180 conversion + castTo8bit = False, # convert data to 8bit before writing + cast8bit_min=-10, # min value if converting to 8bit + cast8bit_max=30, # max value if converting to 8bit + useNormalize_nf = False, # normalize based on background intensity (nf) + chunk_proj = 100, # chunk size in projection direction + chunk_sino = 100, # chunk size in sinogram direction + npad = None, # amount to pad data before reconstruction + projused = None, # should be slicing in projection dimension (start,end,step) Be sure to add one to the end as stop in python means the last value is omitted + sinoused = None, # should be sliceing in sinogram dimension (start,end,step). If first value is negative, it takes the number of slices from the second value in the middle of the stack. + correcttilt = 0, # tilt dataset + tiltcenter_slice = None, # tilt center (x direction) + tiltcenter_det = None, # tilt center (y direction) + angle_offset = 0, # this is the angle offset from our default (270) so that tomopy yields output in the same orientation as previous software (Octopus) + anglelist = None, # if not set, will assume evenly spaced angles which will be calculated by the angular range and number of angles found in the file. if set to -1, will read individual angles from each image. alternatively, a list of angles can be passed. + doBeamHardening = False, # turn on beam hardening correction, based on "Correction for beam hardening in computed tomography", Gabor Herman, 1979 Phys. Med. Biol. 24 81 + BeamHardeningCoefficients = (0, 1, 0, 0, 0, .1), # 6 values, tomo = a0 + a1*tomo + a2*tomo^2 + a3*tomo^3 + a4*tomo^4 + a5*tomo^5 + projIgnoreList = None, # projections to be ignored in the reconstruction (for simplicity in the code, they will not be removed and will be processed as all other projections but will be set to zero absorption right before reconstruction. 
+ bfexposureratio = 1, #ratio of exposure time of bf to exposure time of sample + pxsize = .001, + numslices= 100, + numangles= 3, + angularrange= 180, + numrays= 2560, + inter_bright= 0, + nflat= 15, + ind_flat=1, + group_flat= None, + ndrk=10, + ind_dark=1, + ind_tomo= [0,1,2], + floc_independent= 1, + function_list= ['normalize','minus_log','recon_mask','write_output'], + dorecon=True, + doNormalize=True, + writeraw=False, + writenormalized=False, + writereconstruction=True, + dominuslog=True, + verbose_printing=False, + recon_algorithm='gridrec', #choose from gridrec, fbp, and others in tomopy + dolensdistortion=False, + lensdistortioncenter = (1280,1080), + lensdistortionfactors = (1.00015076, 1.9289e-06, -2.4325e-08, 1.00439e-11, -3.99352e-15), + minimum_transmission = 0.01, + *args, **kwargs + ): + + start_time = time.time() + if verbose_printing: + print("Start {} at:".format(filename)+time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.localtime())) + + filenametowrite = os.path.join(fulloutputPath,outputFilename) + if verbose_printing: + print("Time point: {}".format(timepoint)) + + tempfilenames = [os.path.join(fulloutputPath,'tmp0.h5'),os.path.join(fulloutputPath,'tmp1.h5')] + if verbose_printing: + print("cleaning up previous temp files") #, end="") + for tmpfile in tempfilenames: + try: + os.remove(tmpfile) + except OSError: + pass + + numprojused = len(range(projused[0], projused[1],projused[2])) # number of total projections. We add 1 to include the last projection + numsinoused = len(range(sinoused[0], sinoused[1],sinoused[2])) # number of total sinograms. We add 1 to include the last projection + num_proj_per_chunk = np.minimum(chunk_proj,numprojused) # sets the chunk size to either all of the projections used or the chunk size + numprojchunks = (numprojused - 1) // num_proj_per_chunk + 1 # adding 1 fixes the case of the number of projections not being a factor of the chunk size. 
Subtracting 1 fixes the edge case where the number of projections is a multiple of the chunk size + num_sino_per_chunk = np.minimum(chunk_sino, numsinoused) # same as num_proj_per_chunk + numsinochunks = (numsinoused - 1) // num_sino_per_chunk + 1 # adding 1 fixes the case of the number of sinograms not being a factor of the chunk size. Subtracting 1 fixes the edge case where the number of sinograms is a multiple of the chunk size + + # Ensure the output directory exists + Path(fulloutputPath).mkdir(parents=True, exist_ok=True) + set_directory_permissions(fulloutputPath) # Set permissions for the output directory + + # Figure out first direction to slice + for func in function_list: + if slice_dir[func] != 'both': + axis = slice_dir[func] + break + else: + axis = 'sino' + + done = False + curfunc = 0 + curtemp = 0 + + if not dorecon: + rec = 0 + + while True: # Loop over reading data in certain chunking direction + if axis=='proj': + niter = numprojchunks + else: + niter = numsinochunks + for y in range(niter): # Loop over chunks + if verbose_printing: + print("{} chunk {} of {}".format(axis, y+1, niter)) + # The standard case. Unless the combinations below are in our function list, we read darks and flats normally, and on next chunck proceed to "else." 
+ if curfunc == 0 and not (('normalize_nf' in function_list and 'remove_outlier2d' in function_list) or ('remove_outlier1d' in function_list and 'remove_outlier2d' in function_list)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if axis=='proj': + if (filetype=='als'): + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1],sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = dxchange.read_als_832h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1],sinoused[2])) #I don't think we need this for separate bf: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), + flat = tomobf + elif (filetype=='als1131'): + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1],sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = read_als_1131h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1],sinoused[2])) #I don't think we need this for separate bf: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), + flat = tomobf + elif (filetype == 'dxfile'): + tomo, flat, dark, _= dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, + proj=( y * projused[2] * num_proj_per_chunk + projused[0], + + np.minimum((y + 1) * projused[2] * num_proj_per_chunk + projused[0], projused[1]), projused[2]), + sino=sinoused) # dtype=None, , ) + elif (filetype=='sls'): + tomo, flat, dark, _ = 
read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+y*projused[2]*num_proj_per_chunk+projused[0],timepoint*numangles+np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), sino=sinoused) #dtype=None, , ) + else: + break + else: + if (filetype == 'als'): + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = dxchange.read_als_832h5(os.path.join(inputPath, bffilename),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) # I don't think we need this for separate bf: ind_tomo=range(projused[0],projused[1],projused[2]), + flat = tomobf + elif (filetype == 'als1131'): + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = read_als_1131h5(os.path.join(inputPath, bffilename),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) # I don't think we need this for separate bf: ind_tomo=range(projused[0],projused[1],projused[2]), + flat = tomobf + elif (filetype == 'dxfile'): + tomo, flat, dark, _ = dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, + proj=( projused[0], + projused[1], projused[2]), + sino=(y * sinoused[2] * num_sino_per_chunk + sinoused[0], + np.minimum( + (y + 1) * sinoused[2] * num_sino_per_chunk + + sinoused[0], sinoused[1]), + sinoused[2])) # dtype=None, , 
) + elif (filetype=='sls'): + tomo, flat, dark, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2]), sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) #dtype=None, , ) + else: + break + # Handles the initial reading of scans. Flats and darks are not read in, because the chunking direction will swap before we normalize. We read in darks when we normalize. + elif curfunc == 0: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if axis=='proj': + if (filetype == 'als') or (filetype == 'als1131'): + tomo = read_als_h5_tomo_only(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1], sinoused[2]), bl=filetype) + elif (filetype == 'dxfile'): + tomo, _, _, _, _ = read_sls(os.path.join(inputPath, filename), exchange_rank=0, proj=( + y * projused[2] * num_proj_per_chunk + projused[0], + np.minimum((y + 1) * projused[2] * num_proj_per_chunk + projused[0], + projused[1]), projused[2]), + sino=sinoused) # dtype=None, , ) + elif (filetype=='sls'): + tomo, _, _, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+y*projused[2]*num_proj_per_chunk+projused[0],timepoint*numangles+np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), sino=sinoused) #dtype=None, , ) + else: + break + else: + if (filetype == 'als') or (filetype == 'als1131'): + tomo = read_als_h5_tomo_only(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2]),bl=filetype) + elif (filetype == 'dxfile'): + tomo, _, _, _ = 
dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, proj=( + projused[0], projused[1], projused[2]), + sino=(y * sinoused[2] * num_sino_per_chunk + sinoused[0], + np.minimum( + (y + 1) * sinoused[2] * num_sino_per_chunk + sinoused[0], + sinoused[1]), sinoused[2])) # dtype=None, , ) + elif (filetype=='sls'): + tomo, _, _, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2]), sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) #dtype=None, , ) + else: + break + # Handles the reading of darks and flats, once we know the chunking direction will not change before normalizing. + elif ('remove_outlier2d' == function_list[curfunc] and 'normalize' in function_list) or 'normalize_nf' == function_list[curfunc]: + if axis == 'proj': + start, end = y * num_proj_per_chunk, np.minimum((y + 1) * num_proj_per_chunk,numprojused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((start,end,1),(0,numslices,1),(0,numrays,1))) #read in intermediate file + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als') or (filetype == 'als1131'): + flat, dark, floc = read_als_h5_non_tomo(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1], sinoused[2]),bl=filetype) + if bffilename is not None: + if filetype == 'als': + tomobf, _, _, _ = dxchange.read_als_832h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1], sinoused[2])) #I don't think we need this since it is full tomo in separate file: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]) + flat = 
tomobf + elif filetype == 'als1131': + tomobf, _, _, _ = read_als_1131h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1], sinoused[2])) #I don't think we need this since it is full tomo in separate file: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]) + flat = tomobf + elif (filetype == 'dxfile'): + _, flat, dark, _ = dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, proj=( + y * projused[2] * num_proj_per_chunk + projused[0], + np.minimum( + (y + 1) * projused[2] * num_proj_per_chunk + projused[0], projused[1]), + projused[2]), sino=sinoused) # dtype=None, , ) + elif (filetype=='sls'): + _, flat, dark, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+y*projused[2]*num_proj_per_chunk+projused[0],timepoint*numangles+np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), sino=sinoused) #dtype=None, , ) + else: + break + else: + start, end = y * num_sino_per_chunk, np.minimum((y + 1) * num_sino_per_chunk,numsinoused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((0,numangles,1),(start,end,1),(0,numrays,1))) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als') or (filetype == 'als1131'): + flat, dark, floc = read_als_h5_non_tomo(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2]),bl=filetype) + elif (filetype == 'dxfile'): + _, flat, dark, _, _ = read_sls(os.path.join(inputPath, filename), exchange_rank=0, proj=( + projused[0], projused[1], projused[2]), + sino=(y * sinoused[2] * num_sino_per_chunk + sinoused[0], + np.minimum( + (y + 1) * sinoused[2] * num_sino_per_chunk + sinoused[ + 0], sinoused[1]), 
sinoused[2])) # dtype=None, , ) + elif (filetype=='sls'): + _, flat, dark, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2]), sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) #dtype=None, , ) + else: + break + # Anything after darks and flats have been read or the case in which remove_outlier2d is the current/2nd function and the previous case fails. + else: + if axis=='proj': + start, end = y * num_proj_per_chunk, np.minimum((y + 1) * num_proj_per_chunk,numprojused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((start,end,1),(0,numslices,1),(0,numrays,1))) #read in intermediate file + else: + start, end = y * num_sino_per_chunk, np.minimum((y + 1) * num_sino_per_chunk,numsinoused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((0,numangles,1),(start,end,1),(0,numrays,1))) + dofunc = curfunc + keepvalues = None + while True: # Loop over operations to do in current chunking direction + func_name = function_list[dofunc] + newaxis = slice_dir[func_name] + if newaxis != 'both' and newaxis != axis: + # We have to switch axis, so flush to disk + if y==0: + try: + os.remove(tempfilenames[1-curtemp]) + except OSError: + pass + appendaxis = 1 if axis=='sino' else 0 + dxchange.writer.write_hdf5(tomo,fname=tempfilenames[1-curtemp],gname='tmp',dname='tmp',overwrite=False,appendaxis=appendaxis) #writing intermediate file... 
+ break + if verbose_printing: + print(func_name, end=" ") + curtime = time.time() + if func_name == 'write_raw': + dxchange.write_tiff_stack(tomo, fname=filenametowrite,start=y * num_proj_per_chunk + projused[0]) + if y == 0: + dxchange.write_tiff_stack(flat, fname=filenametowrite+'bak',start=0) + dxchange.write_tiff_stack(dark, fname=filenametowrite + 'drk', start=0) + elif func_name == 'remove_outlier1d': + tomo = tomo.astype(np.float32,copy=False) + remove_outlier1d(tomo, outlier_diff1D, size=outlier_size1D, out=tomo) + elif func_name == 'remove_outlier2d': + tomo = tomo.astype(np.float32,copy=False) + tomopy.remove_outlier(tomo, outlier_diff2D, size=outlier_size2D, axis=0, out=tomo) + elif func_name == 'normalize_nf': + tomo = tomo.astype(np.float32,copy=False) + tomopy.normalize_nf(tomo, flat, dark, floc_independent, out=tomo) #use floc_independent b/c when you read file in proj chunks, you don't get the correct floc returned right now to use here. + if bfexposureratio != 1: + if verbose_printing: + print("correcting bfexposureratio") + tomo = tomo * bfexposureratio + elif func_name == 'normalize': + tomo = tomo.astype(np.float32,copy=False) + tomopy.normalize(tomo, flat, dark, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + if verbose_printing: + print("correcting bfexposureratio") + elif func_name == 'minus_log': + mx = np.float32(minimum_transmission) #setting min %transmission to 1% helps avoid streaking from very high absorbing areas + ne.evaluate('where(tomo>mx, tomo, mx)', out=tomo) + tomopy.minus_log(tomo, out=tomo) + elif func_name == 'beam_hardening': + loc_dict = {'a{}'.format(i):np.float32(val) for i,val in enumerate(BeamHardeningCoefficients)} + loc_dict['tomo'] = tomo + tomo = ne.evaluate('a0 + a1*tomo + a2*tomo**2 + a3*tomo**3 + a4*tomo**4 + a5*tomo**5', local_dict=loc_dict, out=tomo) + elif func_name == 'remove_stripe_fw': + tomo = tomopy.remove_stripe_fw(tomo, sigma=ringSigma, level=ringLevel, pad=True, 
wname=ringWavelet) + elif func_name == 'remove_stripe_ti': + tomo = tomopy.remove_stripe_ti(tomo, nblock=ringNBlock, alpha=ringAlpha) + elif func_name == 'remove_stripe_sf': + tomo = tomopy.remove_stripe_sf(tomo, size=ringSize) + elif func_name == 'correcttilt': + if tiltcenter_slice is None: + tiltcenter_slice = numslices/2. + if tiltcenter_det is None: + tiltcenter_det = tomo.shape[2]/2 + new_center = tiltcenter_slice - 0.5 - sinoused[0] + center_det = tiltcenter_det - 0.5 + + # add padding of 10 pixels, to be unpadded right after tilt correction. + # This makes the tilted image not have zeros at certain edges, + # which matters in cases where sample is bigger than the field of view. + # For the small amounts we are generally tilting the images, 10 pixels is sufficient. + # tomo = tomopy.pad(tomo, 2, npad=10, mode='edge') + # center_det = center_det + 10 + + cntr = (center_det, new_center) + for b in range(tomo.shape[0]): + tomo[b] = st.rotate(tomo[b], correcttilt, center=cntr, preserve_range=True, order=1, mode='edge', clip=True) # center=None means image is rotated around its center; order=1 is default, order of spline interpolation +# tomo = tomo[:, :, 10:-10] + elif func_name == 'lensdistortion': + if verbose_printing: + print(lensdistortioncenter[0]) + print(lensdistortioncenter[1]) + print(lensdistortionfactors[0]) + print(lensdistortionfactors[1]) + print(lensdistortionfactors[2]) + print(lensdistortionfactors[3]) + print(lensdistortionfactors[4]) + print(type(lensdistortionfactors[0])) + print(type(lensdistortionfactors)) + tomo = tomopy.prep.alignment.distortion_correction_proj(tomo, lensdistortioncenter[0], lensdistortioncenter[1], lensdistortionfactors, ncore=None,nchunk=None) + elif func_name == 'do_360_to_180': + + # Keep values around for processing the next chunk in the list + keepvalues = [angularrange, numangles, projused, num_proj_per_chunk, numprojchunks, numprojused, numrays, anglelist] + + # why -.5 on one and not on the other? 
+ if tomo.shape[0]%2>0: + tomo = sino_360_to_180(tomo[0:-1,:,:], overlap=int(np.round((tomo.shape[2]-cor-.5))*2), rotation='right') + angularrange = angularrange/2 - angularrange/(tomo.shape[0]-1) + else: + tomo = sino_360_to_180(tomo[:,:,:], overlap=int(np.round((tomo.shape[2]-cor))*2), rotation='right') + angularrange = angularrange/2 + numangles = int(numangles/2) + projused = (0,numangles-1,1) + numprojused = len(range(projused[0],projused[1],projused[2])) + num_proj_per_chunk = np.minimum(chunk_proj,numprojused) + numprojchunks = (numprojused-1)//num_proj_per_chunk+1 + numrays = tomo.shape[2] + + anglelist = anglelist[:numangles] + + elif func_name == 'phase_retrieval': + tomo = tomopy.retrieve_phase(tomo, pixel_size=pxsize, dist=propagation_dist, energy=kev, alpha=alphaReg, pad=True) + + elif func_name == 'translation_correction': + tomo = linear_translation_correction(tomo,dx=xshift,dy=yshift,interpolation=False) + + elif func_name == 'recon_mask': + tomo = tomopy.pad(tomo, 2, npad=npad, mode='edge') + + if projIgnoreList is not None: + for badproj in projIgnoreList: + tomo[badproj] = 0 + rec = tomopy.recon(tomo, anglelist, center=cor+npad, algorithm=recon_algorithm, filter_name='butterworth', filter_par=[butterworth_cutoff, butterworth_order], ncore=64) + rec = rec[:, npad:-npad, npad:-npad] + rec /= pxsize # convert reconstructed voxel values from 1/pixel to 1/cm + rec = tomopy.circ_mask(rec, 0) + tomo = tomo[:, :, npad:-npad] + elif func_name == 'polar_ring': + rec = np.ascontiguousarray(rec, dtype=np.float32) + rec = tomopy.remove_ring(rec, theta_min=Rarc, rwidth=Rmaxwidth, thresh_max=Rtmax, thresh=Rthr, thresh_min=Rtmin,out=rec) + elif func_name == 'polar_ring2': + rec = np.ascontiguousarray(rec, dtype=np.float32) + rec = tomopy.remove_ring(rec, theta_min=Rarc2, rwidth=Rmaxwidth2, thresh_max=Rtmax2, thresh=Rthr2, thresh_min=Rtmin2,out=rec) + elif func_name == 'castTo8bit': + rec = convert8bit(rec, cast8bit_min, cast8bit_max) + elif func_name == 
'write_reconstruction': + if dorecon: + if sinoused[2] == 1: + dxchange.write_tiff_stack(rec, fname=filenametowrite, start=y*num_sino_per_chunk + sinoused[0]) + else: + num = y*sinoused[2]*num_sino_per_chunk+sinoused[0] + for sinotowrite in rec: #fixes issue where dxchange only writes for step sizes of 1 + dxchange.writer.write_tiff(sinotowrite, fname=filenametowrite + '_' + '{0:0={1}d}'.format(num, 5)) + set_file_permissions(filenametowrite + '_' + '{0:0={1}d}_norm'.format(num, 5)) # Set permissions for individual files + num += sinoused[2] + else: + if verbose_printing: + print('Reconstruction was not done because dorecon was set to False.') + elif func_name == 'write_normalized': + if projused[2] == 1: + dxchange.write_tiff_stack(tomo, fname=filenametowrite+'_norm', start=y * num_proj_per_chunk + projused[0]) + else: + num = y * projused[2] * num_proj_per_chunk + projused[0] + for projtowrite in tomo: # fixes issue where dxchange only writes for step sizes of 1 + dxchange.writer.write_tiff(projtowrite,fname=filenametowrite + '_' + '{0:0={1}d}_norm'.format(num, 5)) + set_file_permissions(filenametowrite + '_' + '{0:0={1}d}_norm'.format(num, 5)) # Set permissions for individual files + num += projused[2] + if verbose_printing: + print('(took {:.2f} seconds)'.format(time.time()-curtime)) + dofunc+=1 + if dofunc==len(function_list): + break + if y255,255,scl)',out=scl) + return scl.astype(np.uint8) + + +def sino_360_to_180(data, overlap=0, rotation='left'): + """ + Converts 0-360 degrees sinogram to a 0-180 sinogram. + + Parameters + ---------- + data : ndarray + Input 3D data. + + overlap : scalar, optional + Overlapping number of pixels. + + rotation : string, optional + Left if rotation center is close to the left of the + field-of-view, right otherwise. + + Returns + ------- + ndarray + Output 3D data. 
+ """ + dx, dy, dz = data.shape + lo = overlap//2 + ro = overlap - lo + n = dx//2 + out = np.zeros((n, dy, 2*dz-overlap), dtype=data.dtype) + if rotation == 'left': + weights = (np.arange(overlap)+0.5)/overlap + out[:, :, -dz+overlap:] = data[:n, :, overlap:] + out[:, :, :dz-overlap] = data[n:2*n, :, overlap:][:, :, ::-1] + out[:, :, dz-overlap:dz] = weights*data[:n, :, :overlap] + (weights*data[n:2*n, :, :overlap])[:, :, ::-1] + elif rotation == 'right': + weights = (np.arange(overlap)[::-1]+0.5)/overlap + out[:, :, :dz-overlap] = data[:n, :, :-overlap] + out[:, :, -dz+overlap:] = data[n:2*n, :, :-overlap][:, :, ::-1] + out[:, :, dz-overlap:dz] = weights*data[:n, :, -overlap:] + (weights*data[n:2*n, :, -overlap:])[:, :, ::-1] + return out + + + + +def remove_outlier1d(arr, dif, size=3, axis=0, ncore=None, out=None): + """ + Remove high intensity bright spots from an array, using a one-dimensional + median filter along the specified axis. + + Dula: also removes dark spots + + Parameters + ---------- + arr : ndarray + Input array. + dif : float + Expected difference value between outlier value and + the median value of the array. + size : int + Size of the median filter. + axis : int, optional + Axis along which median filtering is performed. + ncore : int, optional + Number of cores that will be assigned to jobs. + out : ndarray, optional + Output array for result. If same as arr, process will be done in-place. + Returns + ------- + ndarray + Corrected array. 
+ """ + arr = arr.astype(np.float32,copy=False) + dif = np.float32(dif) + + tmp = np.empty_like(arr) + + other_axes = [i for i in range(arr.ndim) if i != axis] + largest = np.argmax([arr.shape[i] for i in other_axes]) + lar_axis = other_axes[largest] + ncore, chnk_slices = mproc.get_ncore_slices(arr.shape[lar_axis],ncore=ncore) + filt_size = [1]*arr.ndim + filt_size[axis] = size + + with cf.ThreadPoolExecutor(ncore) as e: + slc = [slice(None)]*arr.ndim + for i in range(ncore): + slc[lar_axis] = chnk_slices[i] + e.submit(snf.median_filter, arr[slc], size=filt_size,output=tmp[slc], mode='mirror') + + with mproc.set_numexpr_threads(ncore): + out = ne.evaluate('where(abs(arr-tmp)>=dif,tmp,arr)', out=out) + + return out + + +def translate(data,dx=0,dy=0,interpolation=True): + """ + Shifts all projections in an image stack by dx (horizontal) and dy (vertical) pixels. Translation with subpixel resolution is possible with interpolation==True + + Parameters + ---------- + data: ndarray + Input array, stack of 2D (x,y) images, angle in z + + dx: int or float + desored horizontal pixel shift + + dy: int or float + desired vertical pixel shift + + interpolation: boolean + True calls funtion from sckimage to interpolate image when subpixel shifts are applied + + Returns + ------- + ndarray + Corrected array. 
+ """ + + Nproj, Nrow, Ncol = data.shape + dataOut = np.zeros(data.shape) + + if interpolation == True: + #translateFunction = st.SimilarityTransform(translation=(-dx,dy)) + M=np.matrix([[1,0,-dx],[0,1,dy],[0,0,1]]) + translateFunction = st.SimilarityTransform(matrix=M) + for n in range(Nproj): + dataOut[n,:,:] = st.warp(data[n,:,:], translateFunction) + + if interpolation == False: + Npad = max(dx,dy) + drow = int(-dy) # negative matrix row increments = dy + dcol = int(dx) # matrix column increments = dx + for n in range(Nproj): + PaddedImage = np.pad(data[n,:,:],Npad,'constant') + dataOut[n,:,:] = PaddedImage[Npad-drow:Npad+Nrow-drow,Npad-dcol:Npad+Ncol-dcol] # shift image by dx and dy, replace original without padding + + return dataOut + + +def linear_translation_correction(data,dx=100.5,dy=700.1,interpolation=True): + + """ + Corrects for a linear drift in field of view (horizontal dx, vertical dy) over time. The first index indicaties time data[time,:,:] in the time series of projections. dx and dy are the final shifts in FOV position. + + Parameters + ---------- + data: ndarray + Input array, stack of 2D (x,y) images, angle in z + + dx: int or float + total horizontal pixel offset from first (0 deg) to last (180 deg) projection + + dy: int or float + total horizontal pixel offset from first (0 deg) to last (180 deg) projection + + interpolation: boolean + True calls funtion from sckimage to interpolate image when subpixel shifts are applied + + Returns + ------- + ndarray + Corrected array. + """ + + Nproj, Nrow, Ncol = data.shape + Nproj=10 + + dataOut = np.zeros(data0.shape) + + dx_n = np.linspace(0,dx,Nproj) # generate array dx[n] of pixel shift for projection n = 0, 1, ... Nproj + + dy_n = np.linspace(0,dy,Nproj) # generate array dy[n] of pixel shift for projection n = 0, 1, ... 
def linear_translation_correction(data, dx=100.5, dy=700.1, interpolation=True):
    """
    Corrects for a linear drift in field of view (horizontal dx, vertical dy)
    over time. The first index indicates time, data[time, :, :], in the time
    series of projections. dx and dy are the final shifts in FOV position.

    Parameters
    ----------
    data : ndarray
        Input array, stack of 2D (x, y) images, angle in z.
    dx : int or float
        Total horizontal pixel offset from first (0 deg) to last (180 deg)
        projection.
    dy : int or float
        Total vertical pixel offset from first (0 deg) to last (180 deg)
        projection.
    interpolation : bool
        True calls scikit-image to interpolate when subpixel shifts are
        applied; False shifts by whole pixels via pad-and-slice.

    Returns
    -------
    ndarray
        Corrected array.
    """
    Nproj, Nrow, Ncol = data.shape
    # BUGFIX: removed a leftover debug override (Nproj = 10) that silently
    # processed only the first 10 projections, and allocate the output from
    # data.shape (the original referenced an undefined name ``data0``).
    dataOut = np.zeros(data.shape)

    # Per-projection shift ramps dx_n[n], dy_n[n] for n = 0, 1, ... Nproj-1.
    dx_n = np.linspace(0, dx, Nproj)
    dy_n = np.linspace(0, dy, Nproj)

    if interpolation:
        for n in range(Nproj):
            # Translation matrix for projection n (subpixel, interpolated).
            M = np.array([[1, 0, -dx_n[n]], [0, 1, dy_n[n]], [0, 0, 1]],
                         dtype=float)
            translateFunction = st.SimilarityTransform(matrix=M)
            dataOut[n, :, :] = st.warp(data[n, :, :], translateFunction)
    else:
        # Whole-pixel shifts; pad by the (rounded-up) magnitude so negative
        # and fractional drifts also get enough padding.
        Npad = int(np.ceil(max(abs(dx), abs(dy))))
        for n in range(Nproj):
            PaddedImage = np.pad(data[n, :, :], Npad, 'constant')
            drow = int(round(-dy_n[n]))  # negative matrix row increments = dy
            dcol = int(round(dx_n[n]))   # matrix column increments = dx
            dataOut[n, :, :] = PaddedImage[Npad - drow:Npad + Nrow - drow,
                                           Npad - dcol:Npad + Ncol - dcol]

    return dataOut
+ """ + + +"""Hi Dula, +This is roughly what I am doing in the script to 'unspiral' the superweave data: +spd = float(int(sys.argv[2])/2048) +x = np.zeros((2049,200,2560), dtype=np.float32) +blks = np.round(np.linspace(0,2049,21)).astype(np.int) +for i in range(0,20): + dat = dxchange.read_als_832h5(fn, ind_tomo=range(blks[i],blks[i+1])) + prj = tomopy.normalize_nf(dat[0],dat[1],dat[2],dat[3]) + for ik,j in enumerate(range(blks[i],blks[i+1])): + l = prj.shape[1]//2-j*spd + li = int(l) + ri = li+200 + fc = l-li + x[j] = (1-fc)*prj[ik,li:ri] + x[j] += fc*prj[ik,li+1:ri+1] +dxchange.writer.write_hdf5(x, fname=fn[:-3]+'_unspiral.h5', overwrite=True, gname='tmp', dname='tmp', appendaxis=0) + +This processes the (roughly) central 200 slices, and saves it to a new file. The vertical speed is one of the input arguments, and I simply estimate it manually by looking at the first and last projection, shifting them by 'np.roll'. The input argument is the total number of pixels that are shifted over the whole scan (which is then converted to pixels per projection by dividing by the number of projections-1). +I don't really remember why I wrote my own code, but maybe I was running into problems using scikit-image as well. The current code uses linear interpolation, and gives pretty good results for the data I tested. 
def convertthetype(val):
    """Coerce a string to int or float when possible, else leave it a string.

    Constructors are tried from strictest to loosest; the first that does not
    raise ValueError wins. str() always succeeds, so non-numeric input is
    returned unchanged (as a string).
    """
    for convert in (int, float, str):
        try:
            return convert(val)
        except ValueError:
            continue
def read_als_h5_non_tomo(fname, ind_tomo=None, ind_flat=None, ind_dark=None,
                         proj=None, sino=None, whichbeamline='als'):
    """
    Read ALS 8.3.2 hdf5 file with stacked datasets.

    Reads only the flat/dark (non-tomo) stacks, so the darks and flats do
    not have to be loaded together with the projection data.

    Parameters
    ----------
    See docs for read_als_832h5

    Returns
    -------
    tuple
        (flat, dark, floc) where floc maps the flat-field groups into the
        tomography projection index list.

    Notes
    -----
    NOTE(review): if the file has no 'i0cycle' attribute, ``inter_bright``
    below is referenced before assignment -- confirm all files carry it.
    NOTE(review): a caller in this file passes the beamline as ``bl=``,
    which does not match this parameter name (``whichbeamline``) -- verify.
    """

    with dxchange.reader.find_dataset_group(fname) as dgroup:
        dname = dgroup.name.split('/')[-1]

        # Flats/darks are stored as tif-named stacks alongside the dataset.
        flat_name = dname + 'bak_0000.tif'
        dark_name = dname + 'drk_0000.tif'

        # Read metadata from dataset group attributes
        keys = list(dgroup.attrs.keys())
        if 'nangles' in keys:
            nproj = int(dgroup.attrs['nangles'])
        if 'i0cycle' in keys:
            inter_bright = int(dgroup.attrs['i0cycle'])
        if 'num_bright_field' in keys:
            nflat = int(dgroup.attrs['num_bright_field'])
        else:
            nflat = dxchange.reader._count_proj(dgroup, flat_name, nproj,
                                                inter_bright=inter_bright)
        if 'num_dark_fields' in keys:
            ndark = int(dgroup.attrs['num_dark_fields'])
        else:
            ndark = dxchange.reader._count_proj(dgroup, dark_name, nproj)

        # Create arrays of indices to read projections, flats and darks
        if ind_tomo is None:
            ind_tomo = list(range(0, nproj))
        if proj is not None:
            ind_tomo = ind_tomo[slice(*proj)]
        ind_dark = list(range(0, ndark))
        group_dark = [nproj - 1]
        ind_flat = list(range(0, nflat))

        # Positions within the projection sequence where flats were taken.
        if inter_bright > 0:
            group_flat = list(range(0, nproj, inter_bright))
            if group_flat[-1] != nproj - 1:
                group_flat.append(nproj - 1)
        elif inter_bright == 0:
            # 11.3.1 stores a single flat group at the end of the scan;
            # 8.3.2 has flats at both the start and the end.
            if whichbeamline == 'als1131':
                group_flat = [nproj - 1]
            else:
                group_flat = [0, nproj - 1]
        else:
            group_flat = None

        flat = dxchange.reader.read_hdf5_stack(
            dgroup, flat_name, ind_flat, slc=(None, sino), out_ind=group_flat)

        dark = dxchange.reader.read_hdf5_stack(
            dgroup, dark_name, ind_dark, slc=(None, sino), out_ind=group_dark)

    return flat, dark, dxchange.reader._map_loc(ind_tomo, group_flat)
def read_als_1131h5(fname, ind_tomo=None, ind_flat=None, ind_dark=None,
                    proj=None, sino=None):
    """
    Read ALS 11.3.1 hdf5 file with stacked datasets.

    Parameters
    ----------
    fname : str
        Path to hdf5 file.

    ind_tomo : list of int, optional
        Indices of the projection files to read.

    ind_flat : list of int, optional
        Indices of the flat field files to read.

    ind_dark : list of int, optional
        Indices of the dark field files to read.

    proj : {sequence, int}, optional
        Specify projections to read. (start, end, step)

    sino : {sequence, int}, optional
        Specify sinograms to read. (start, end, step)

    Returns
    -------
    ndarray
        3D tomographic data.

    ndarray
        3D flat field data.

    ndarray
        3D dark field data.

    list of int
        Indices of flat field data within tomography projection list.

    Notes
    -----
    NOTE(review): despite the docstring above, the final returned element is
    the constant 0, not the flat-location map (the _map_loc call is commented
    out below) -- confirm callers expect this.
    NOTE(review): if the file has no 'i0cycle' attribute, ``inter_bright``
    below is referenced before assignment -- confirm all 11.3.1 files
    carry it.
    """

    with dxchange.reader.find_dataset_group(fname) as dgroup:
        dname = dgroup.name.split('/')[-1]

        # Projections, flats and darks are tif-named stacks in the group.
        tomo_name = dname + '_0000_0000.tif'
        flat_name = dname + 'bak_0000.tif'
        dark_name = dname + 'drk_0000.tif'

        # Read metadata from dataset group attributes
        keys = list(dgroup.attrs.keys())
        if 'nangles' in keys:
            nproj = int(dgroup.attrs['nangles'])
        if 'i0cycle' in keys:
            inter_bright = int(dgroup.attrs['i0cycle'])
        if 'num_bright_field' in keys:
            nflat = int(dgroup.attrs['num_bright_field'])
        else:
            nflat = dxchange.reader._count_proj(dgroup, flat_name, nproj,
                                                inter_bright=inter_bright)
        if 'num_dark_fields' in keys:
            ndark = int(dgroup.attrs['num_dark_fields'])
        else:
            ndark = dxchange.reader._count_proj(dgroup, dark_name, nproj)

        # Create arrays of indices to read projections, flats and darks
        if ind_tomo is None:
            ind_tomo = list(range(0, nproj))
        if proj is not None:
            ind_tomo = ind_tomo[slice(*proj)]
        ind_dark = list(range(0, ndark))
        group_dark = [nproj - 1]
        ind_flat = list(range(0, nflat))

        # Positions within the projection sequence where flats were taken.
        if inter_bright > 0:
            group_flat = list(range(0, nproj, inter_bright))
            if group_flat[-1] != nproj - 1:
                group_flat.append(nproj - 1)
        elif inter_bright == 0:
            # 11.3.1 records flats only at the end of the scan.
            #group_flat = [0, nproj - 1]
            group_flat = [nproj - 1]
        else:
            group_flat = None

        tomo = dxchange.reader.read_hdf5_stack(
            dgroup, tomo_name, ind_tomo, slc=(None, sino))

        flat = dxchange.reader.read_hdf5_stack(
            dgroup, flat_name, ind_flat, slc=(None, sino), out_ind=group_flat)

        dark = dxchange.reader.read_hdf5_stack(
            dgroup, dark_name, ind_dark, slc=(None, sino), out_ind=group_dark)

#    return tomo, flat, dark, dxchange.reader._map_loc(ind_tomo, group_flat)
    return tomo, flat, dark, 0
+ """ + if exchange_rank > 0: + exchange_base = 'exchange{:d}'.format(int(exchange_rank)) + else: + exchange_base = "exchange" + + tomo_grp = '/'.join([exchange_base, 'data']) + flat_grp = '/'.join([exchange_base, 'data_white']) + dark_grp = '/'.join([exchange_base, 'data_dark']) + theta_grp = '/'.join([exchange_base, 'theta']) + + tomo = dxchange.read_hdf5(fname, tomo_grp, slc=(proj, sino), dtype=dtype) + flat = dxchange.read_hdf5(fname, flat_grp, slc=(None, sino), dtype=dtype) + dark = dxchange.read_hdf5(fname, dark_grp, slc=(None, sino), dtype=dtype) + theta = dxchange.read_hdf5(fname, theta_grp) + + if (theta is None): + theta_grp = '/'.join([exchange_base, 'theta_aborted']) + theta = dxchange.read_hdf5(fname, theta_grp) + if (theta is None): + if verbose_printing: + print('could not find thetas, generating them based on 180 degree rotation') + theta_size = dxchange.read_dx_dims(fname, 'data')[0] + logger.warn('Generating "%s" [0-180] deg angles for missing "exchange/theta" dataset' % (str(theta_size))) + theta = np.linspace(0., 180., theta_size) + + theta = theta * np.pi / 180. + + if proj is not None: + theta = theta[proj[0]:proj[1]:proj[2]] + + return tomo, flat, dark, theta + +#Converts spreadsheet.xlsx file with headers into dictionaries +# def read_spreadsheet(filepath): +# workbook=xlrd.open_workbook(filepath) +# worksheet = workbook.sheet_by_index(0) +# +# # imports first row and converts to a list of header strings +# headerList = [] +# for col_index in range(worksheet.ncols): +# headerList.append(str(worksheet.cell_value(0,col_index))) +# +# dataList = [] +# # For each row, create a dictionary and like header name to data +# # converts each row to following format rowDictionary1 ={'header1':colvalue1,'header2':colvalue2,... } +# # compiles rowDictinaries into a list: dataList = [rowDictionary1, rowDictionary2,...] 
+# for row_index in range(1,worksheet.nrows): +# rowDictionary = {} +# for col_index in range(worksheet.ncols): +# cellValue = worksheet.cell_value(row_index,col_index) +# +# if type(cellValue)==unicode: +# cellValue = str(cellValue) +# +# # if cell contains string that looks like a tuple, convert to tuple +# if '(' in str(cellValue): +# cellValue = literal_eval(cellValue) +# +# # if cell contains string or int that looks like 'True', convert to boolean True +# if str(cellValue).lower() =='true' or (type(cellValue)==int and cellValue==1): +# cellValue = True +# +# # if cell contains string or int that looks like 'False', convert to boolean False +# if str(cellValue).lower() =='false' or (type(cellValue)==int and cellValue==0): +# cellValue = False +# +# if cellValue != '': # create dictionary element if cell value is not empty +# rowDictionary[headerList[col_index]] = cellValue +# dataList.append(rowDictionary) +# +# return(dataList) + + +# D.Y.Parkinson's interpreter for text input files +def main(parametersfile): + + if parametersfile.split('.')[-1] == 'txt': + with open(parametersfile,'r') as theinputfile: + theinput = theinputfile.read() + inputlist = theinput.splitlines() + for reconcounter in range(0,len(inputlist)): + inputlisttabsplit = inputlist[reconcounter].split() + if inputlisttabsplit: + functioninput = {'filename': inputlisttabsplit[0]} + for inputcounter in range(0,(len(inputlisttabsplit)-1)//2): + inputlisttabsplit[inputcounter*2+2] = inputlisttabsplit[inputcounter*2+2].replace('\"','') + inputcommasplit = inputlisttabsplit[inputcounter*2+2].split(',') + if len(inputcommasplit)>1: + inputcommasplitconverted = [] + for jk in range(0,len(inputcommasplit)): + inputcommasplitconverted.append(convertthetype(inputcommasplit[jk])) + else: + inputcommasplitconverted = convertthetype(inputlisttabsplit[inputcounter*2+2]) + functioninput[inputlisttabsplit[inputcounter*2+1]] = inputcommasplitconverted + else: + print("Ending at blank line in input.") + break + 
def set_directory_permissions(path):
    """Make *path* group-writable (rwxrwxr-x) with the setgid bit set, so
    files created inside inherit the directory's group."""
    mode = (stat.S_IRWXU | stat.S_IRWXG
            | stat.S_IROTH | stat.S_IXOTH
            | stat.S_ISGID)
    os.chmod(path, mode)

def set_file_permissions(path):
    """Make *path* owner/group read-write and world-readable (rw-rw-r--)."""
    mode = (stat.S_IRUSR | stat.S_IWUSR
            | stat.S_IRGRP | stat.S_IWGRP
            | stat.S_IROTH)
    os.chmod(path, mode)
import dxchange +except: + print("warning: dxchange is not available") + +# run this from the command line: +# python tomopy832.py +# it requires a separate file, which contains at minimum a list of filenames +# on separate lines. Default name of this file is input832.txt, but you can use any +# filename and run from the commandline as +# python tomopy832.py yourinputfile.txt +# If desired, on each line (separated by spaces) you can +# include parameters to override the defaults. +# to do this you need pairs, first the name of the variable, then the desired value +# For True/False, use 1/0. +# You can generate these input files in excel, in which case use tab-separated +# (or space separated). Some input overrides require multiple values, +# these should be comma-separated (with no spaces). Example is sinoused +# which would be e.g. 500,510,1 to get slices 500 through 509. For sinoused, +# you can use first value -1 and second value number of slices to get that number +# of slices from the middle of the stack. +# an example of the contents of the input file look like this: + +# filename.h5 cor 1196 sinoused "-1,10,1" doPhaseRetrieval 0 outputFilename c1196.0 +# filename.h5 cor 1196.5 sinoused "-1,10,1" doPhaseRetrieval 0 outputFilename c1196.5 + +# this was generated in excel and saved as txt tab separated, so the quotes were +# added automatically by excel. Note also that for parameters expecting strings as +# input (outputFilename for example), the program will choke if you put in a number. + +# if cor is not defined in the parameters file, automated cor detection will happen + +# chunk_proj and chunk_sino handle memory management. +# If you are running out of memory, make one or both of those smaller. 
+ +slice_dir = { + 'write_raw': 'proj', + 'remove_outlier1d': 'sino', + 'remove_outlier2d': 'proj', + 'normalize_nf': 'sino', + 'normalize': 'both', + 'minus_log': 'both', + 'beam_hardening': 'both', + 'remove_stripe_fw': 'sino', + 'remove_stripe_ti': 'sino', + 'remove_stripe_sf': 'sino', + 'do_360_to_180': 'sino', + 'correcttilt': 'proj', + 'lensdistortion': 'proj', + 'phase_retrieval': 'proj', + 'recon_mask': 'sino', + 'polar_ring': 'sino', + 'polar_ring2': 'sino', + 'castTo8bit': 'both', + 'write_reconstruction': 'both', + 'write_normalized': 'proj', +} + + +def recon_setup( + filename, + filetype = 'dxfile', #other options are als, als1131, sls + timepoint = 0, + bffilename = None, + inputPath = './', # input path, location of the data set to reconstruct + outputPath=None, + # define an output path (default is inputPath), a sub-folder will be created based on file name + outputFilename=None, + # file name for output tif files (a number and .tiff will be added). default is based on input filename + fulloutputPath=None, # definte the full output path, no automatic sub-folder will be created + doOutliers1D=False, # outlier removal in 1d (along sinogram columns) + outlier_diff1D=750, # difference between good data and outlier data (outlier removal) + outlier_size1D=3, # radius around each pixel to look for outliers (outlier removal) + doOutliers2D=False, # outlier removal, standard 2d on each projection + outlier_diff2D=750, # difference between good data and outlier data (outlier removal) + outlier_size2D=3, # radius around each pixel to look for outliers (outlier removal) + doFWringremoval=True, # Fourier-wavelet ring removal + doTIringremoval=False, # Titarenko ring removal + doSFringremoval=False, # Smoothing filter ring removal + ringSigma=3, # damping parameter in Fourier space (Fourier-wavelet ring removal) + ringLevel=8, # number of wavelet transform levels (Fourier-wavelet ring removal) + ringWavelet='db5', # type of wavelet filter (Fourier-wavelet ring 
removal) + ringNBlock=0, # used in Titarenko ring removal (doTIringremoval) + ringAlpha=1.5, # used in Titarenko ring removal (doTIringremoval) + ringSize=5, # used in smoothing filter ring removal (doSFringremoval) + doPhaseRetrieval=False, # phase retrieval + alphaReg=0.00001, # smaller = smoother (used for phase retrieval) + propagation_dist=75.0, # sample-to-scintillator distance (phase retrieval) + kev=24.0, # energy level (phase retrieval) + butterworth_cutoff=0.25, # 0.1 would be very smooth, 0.4 would be very grainy (reconstruction) + butterworth_order=2, # for reconstruction + doTranslationCorrection=False, # correct for linear drift during scan + xshift=0, # undesired dx transation correction (from 0 degree to 180 degree proj) + yshift=0, # undesired dy transation correction (from 0 degree to 180 degree proj) + doPolarRing=False, # ring removal + Rarc=30, # min angle needed to be considered ring artifact (ring removal) + Rmaxwidth=100, # max width of rings to be filtered (ring removal) + Rtmax=3000.0, # max portion of image to filter (ring removal) + Rthr=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin=-3000.0, # min value of image to filter (ring removal) + doPolarRing2=False, # ring removal + Rarc2=30, # min angle needed to be considered ring artifact (ring removal) + Rmaxwidth2=100, # max width of rings to be filtered (ring removal) + Rtmax2=3000.0, # max portion of image to filter (ring removal) + Rthr2=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin2=-3000.0, # min value of image to filter (ring removal) + cor=None, # center of rotation (float). If not used then cor will be detected automatically + corFunction='pc', # center of rotation function to use - can be 'pc', 'vo', or 'nm', or use 'skip' to return tomo variable without having to do a calc. 
+ corLoadMinimalBakDrk=True, #during cor detection, only load the first dark field and first flat field rather than all of them, to minimize file loading time for cor detection. + voInd=None, # index of slice to use for cor search (vo) + voSMin=-150, # min radius for searching in sinogram (vo) + voSMax=150, # max radius for searching in sinogram (vo) + voSRad=6, # search radius (vo) + voStep=0.25, # search step (vo) + voRatio=0.5, # ratio of field-of-view and object size (vo) + voDrop=20, # drop lines around vertical center of mask (vo) + nmInd=None, # index of slice to use for cor search (nm) + nmInit=None, # initial guess for center (nm) + nmTol=0.5, # desired sub-pixel accuracy (nm) + nmMask=True, # if True, limits analysis to circular region (nm) + nmRatio=1.0, # ratio of radius of circular mask to edge of reconstructed image (nm) + nmSinoOrder=False, # if True, analyzes in sinogram space. If False, analyzes in radiograph space + use360to180=False, # use 360 to 180 conversion + castTo8bit=False, # convert data to 8bit before writing + cast8bit_min=-10, # min value if converting to 8bit + cast8bit_max=30, # max value if converting to 8bit + useNormalize_nf=False, # normalize based on background intensity (nf) + chunk_proj=100, # chunk size in projection direction + chunk_sino=100, # chunk size in sinogram direction + npad=None, # amount to pad data before reconstruction + projused=None, + # should be slicing in projection dimension (start,end,step) Be sure to add one to the end as stop in python means the last value is omitted + sinoused=None, + # should be sliceing in sinogram dimension (start,end,step). If first value is negative, it takes the number of slices from the second value in the middle of the stack. 
+ correcttilt=0, # tilt dataset + tiltcenter_slice=None, # tilt center (x direction) + tiltcenter_det=None, # tilt center (y direction) + angle_offset=0, + # this is the angle offset from our default (270) so that tomopy yields output in the same orientation as previous software (Octopus) + anglelist=None, + # if not set, will assume evenly spaced angles which will be calculated by the angular range and number of angles found in the file. if set to -1, will read individual angles from each image. alternatively, a list of angles can be passed. + doBeamHardening=False, + # turn on beam hardening correction, based on "Correction for beam hardening in computed tomography", Gabor Herman, 1979 Phys. Med. Biol. 24 81 + BeamHardeningCoefficients=None, # 6 values, tomo = a0 + a1*tomo + a2*tomo^2 + a3*tomo^3 + a4*tomo^4 + a5*tomo^5 + projIgnoreList=None, + # projections to be ignored in the reconstruction (for simplicity in the code, they will not be removed and will be processed as all other projections but will be set to zero absorption right before reconstruction. 
+ bfexposureratio=1, # ratio of exposure time of bf to exposure time of sample + dorecon=True, #do the tomographic reconstruction + writeraw = False, + writenormalized=False, + writereconstruction=True, + doNormalize=True, + dominuslog=True, + slsnumangles=1000, + slspxsize=0.00081, + verbose_printing=False, + recon_algorithm='gridrec', # choose from gridrec, fbp, and others in tomopy + dolensdistortion=False, + lensdistortioncenter=(1280,1080), + lensdistortionfactors = (1.00015076, 1.9289e-06, -2.4325e-08, 1.00439e-11, -3.99352e-15), + minimum_transmission = 0.01, + *args, **kwargs + ): + + + outputFilename = os.path.splitext(filename)[0] if outputFilename is None else outputFilename + # outputPath = inputPath + 'rec' + os.path.splitext(filename)[0] + '/' if outputPath is None else outputPath + 'rec' + os.path.splitext(filename)[0] + '/' + outputPath = os.path.join(inputPath, 'rec' + outputFilename) if outputPath is None else os.path.join(outputPath,'rec' + outputFilename) + fulloutputPath = outputPath if fulloutputPath is None else fulloutputPath + tempfilenames = [os.path.join(fulloutputPath,'tmp0.h5'), os.path.join(fulloutputPath, 'tmp1.h5')] + + if verbose_printing: + print("cleaning up previous temp files", end="") + for tmpfile in tempfilenames: + try: + os.remove(tmpfile) + except OSError: + pass + if verbose_printing: + print(", reading metadata") + + if (filetype == 'als') or (filetype == 'als1131'): + datafile = h5py.File(os.path.join(inputPath,filename), 'r') + gdata = dict(dxchange.reader._find_dataset_group(datafile).attrs) + pxsize = float(gdata['pxsize']) / 10 # /10 to convert units from mm to cm + numslices = int(gdata['nslices']) + numangles = int(gdata['nangles']) + angularrange = float(gdata['arange']) + numrays = int(gdata['nrays']) + inter_bright = int(gdata['i0cycle']) + + + + dgroup = dxchange.reader._find_dataset_group(datafile) + keys = list(gdata.keys()) + if 'num_dark_fields' in keys: + ndark = int(gdata['num_dark_fields']) + else: + 
ndark = dxchange.reader._count_proj(dgroup, dgroup.name.split('/')[-1] + 'drk_0000.tif', numangles, inter_bright=-1) #for darks, don't want to divide out inter_bright for counting projections + ind_dark = list(range(0, ndark)) + group_dark = [numangles - 1] + + if 'num_bright_field' in keys: + nflat = int(gdata['num_bright_field']) + else: + nflat = dxchange.reader._count_proj(dgroup, dgroup.name.split('/')[-1] + 'bak_0000.tif', numangles, inter_bright=inter_bright) + ind_flat = list(range(0, nflat)) + + # figure out the angle list (a list of angles, one per projection image) + dtemp = datafile[list(datafile.keys())[0]] + fltemp = list(dtemp.keys()) + firstangle = float(dtemp[fltemp[0]].attrs.get('rot_angle', 0)) + if anglelist is None: + # the offset angle should offset from the angle of the first image, which is usually 0, but in the case of timbir data may not be. + # we add the 270 to be inte same orientation as previous software used at bl832 + angle_offset = 270 + angle_offset - firstangle + anglelist = tomopy.angles(numangles, angle_offset, angle_offset - angularrange) + elif anglelist == -1: + anglelist = np.zeros(shape=numangles) + for icount in range(0, numangles): + anglelist[icount] = np.pi / 180 * (270 + angle_offset - float(dtemp[fltemp[icount]].attrs['rot_angle'])) + if inter_bright > 0: + group_flat = list(range(0, numangles, inter_bright)) + if group_flat[-1] != numangles - 1: + group_flat.append(numangles - 1) + elif inter_bright == 0: + group_flat = [0, numangles - 1] + else: + group_flat = None + elif filetype == 'dxfile': + numangles = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/rotation/num_angles")[0]) + angularrange = dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/rotation/range")[0] + anglelist = dxchange.read_hdf5(os.path.join(inputPath, filename), '/exchange/theta', slc=None) + if anglelist is None: + try: + # See if the rotation start, step, num_angles are in the file + 
rotation_start = dxchange.read_hdf5(os.path.join(inputPath, filename), + '/process/acquisition/rotation/rotation_start')[0] + rotation_step = dxchange.read_hdf5(os.path.join(inputPath, filename), + '/process/acquisition/rotation/rotation_step')[0] + anglelist = rotation_start + rotation_step * range(numangles) + except: + anglelist = np.linspace(0. - angle_offset, angularrange, numangles) + anglelist = anglelist - angle_offset + anglelist = np.deg2rad(anglelist) + anglelist = -anglelist + numslices = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/detector/dimension_y")[0]) + numrays = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/detector/dimension_x")[0]) + pxsize = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/detector/pixel_size")[0] / 10.0 # /10 to convert units from mm to cm + inter_bright = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/flat_fields/i0cycle")[0]) + group_flat = [0, numangles - 1] + nflat = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/flat_fields/num_flat_fields")[0]) + ind_flat = list(range(0, nflat)) + ndark = int(dxchange.read_hdf5(os.path.join(inputPath, filename), "/process/acquisition/dark_fields/num_dark_fields")[0]) + ind_dark = list(range(0, ndark)) + propagation_dist = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/camera_motor_stack/setup/camera_distance")[0] + if (propagation_dist == 0): + propagation_dist = dxchange.read_hdf5(os.path.join(inputPath, filename), + "/measurement/instrument/camera_motor_stack/setup/camera_distance")[1] + kev = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/monochromator/energy")[0] / 1000 + if (kev == 0): + kev = dxchange.read_hdf5(os.path.join(inputPath, filename), "/measurement/instrument/monochromator/energy")[ + 1] / 1000 + + if (isinstance(kev, int) or 
isinstance(kev, float)): + if kev > 1000: + kev = 30.0 + else: + kev = 30.0 + elif filetype == 'sls': + datafile = h5py.File(os.path.join(inputPath, filename), 'r') + slsdata = datafile["exchange/data"] + numslices = slsdata.shape[1] + numrays = slsdata.shape[2] + pxsize = slspxsize + numangles = slsnumangles + _, _, _, anglelist = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles,(timepoint+1)*numangles,1), sino=(0,1,1)) #dtype=None, , ) + angularrange = np.abs(anglelist[-1]-anglelist[0]) + inter_bright = 0 + group_flat = [0, numangles - 1] + nflat = 1 #this variable is not used for sls data + ind_flat = list(range(0, nflat)) + else: + print("Not sure what file type, gotta break.") + return + + npad = int(np.ceil(numrays * np.sqrt(2)) - numrays) // 2 if npad is None else npad + if projused is not None and (projused[1] > numangles - 1 or projused[0] < 0): # allows program to deal with out of range projection values + if projused[1] > numangles: + print("End Projection value greater than number of angles. Value has been lowered to the number of angles " + str(numangles)) + projused = (projused[0], numangles, projused[2]) + if projused[0] < 0: + print("Start Projection value less than zero. Value raised to 0") + projused = (0, projused[1], projused[2]) + if projused is None: + projused = (0, numangles, 1) + else: + # if projused is different than default, need to change numangles and angularrange; dula attempting to do this with these two lines, we'll see if it works! 11/16/17 + angularrange = (angularrange / (numangles - 1)) * (projused[1] - projused[0]) + #dula updated to use anglelist to find angular rage, 11 june 2020, not sure if it will work?? 
+ angularrange = np.abs(anglelist[projused[1]] - anglelist[projused[0]]) + # want angular range to stay constant if we keep the end values consistent + numangles = len(range(projused[0], projused[1], projused[2])) + + ind_tomo = list(range(0, numangles)) + floc_independent = dxchange.reader._map_loc(ind_tomo, group_flat) + + # figure out how user can pass to do central x number of slices, or set of slices dispersed throughout (without knowing a priori the value of numslices) + if sinoused is None: + sinoused = (0, numslices, 1) + elif sinoused[0] < 0: + sinoused = (int(np.floor(numslices / 2.0) - np.ceil(sinoused[1] / 2.0)), int(np.floor(numslices / 2.0) + np.floor(sinoused[1] / 2.0)), 1) + + if verbose_printing: + print('There are ' + str(numslices) + ' sinograms, ' + str(numrays) + ' rays, and ' + str(numangles) + ' projections, with an angular range of ' +str(angularrange) + '.') + print('Looking at sinograms ' + str(sinoused[0]) + ' through ' + str(sinoused[1]-1) + ' (inclusive) in steps of ' + str(sinoused[2])) + + BeamHardeningCoefficients = (0, 1, 0, 0, 0, .004) if BeamHardeningCoefficients is None else BeamHardeningCoefficients + + if cor is None: + if verbose_printing: + print("Detecting center of rotation", end="") + + if angularrange > 300: + lastcor = int(np.floor(numangles / 2) - 1) + else: + lastcor = numangles - 1 + # I don't want to see the warnings about the reader using a deprecated variable in dxchange + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als') or (filetype == 'als1131'): + if corLoadMinimalBakDrk: + ind_dark = 0 + ind_flat = 0 + if filetype == 'als': + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath, filename), ind_tomo=(0, lastcor),ind_dark=ind_dark,ind_flat=ind_flat) + else: + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath, filename), + ind_tomo=(0, lastcor), ind_dark=ind_dark, + ind_flat=ind_flat) + elif filetype == 'dxfile': + # if 
corLoadMinimalBakDrk: + # ind_dark = 0 + # ind_flat = 0 + # tomo, flat, dark, coranglelist, _ = dxchange.exchange.read_dx(os.path.join(inputPath, filename), proj=(0,numangles-1),ind_dark=ind_dark,ind_flat=ind_flat) + # tomo, flat, dark, coranglelist, _ = dxchange.exchange.read_dx(os.path.join(inputPath, filename), proj=(0,lastcor,lastcor-1)) + tomo, flat, dark, coranglelist = dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), proj=(0,lastcor,lastcor-1)) + elif (filetype == 'sls'): + tomo, flat, dark, coranglelist = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=( + timepoint * numangles, (timepoint + 1) * numangles, numangles - 1)) # dtype=None, , ) + else: + return + if bffilename is not None and (filetype == 'als'): + tomobf, flatbf, darkbf, flocbf = dxchange.read_als_832h5(os.path.join(inputPath, bffilename)) + flat = tomobf + tomo = tomo.astype(np.float32) + if useNormalize_nf and ((filetype == 'als') or (filetype == 'als1131')): + tomopy.normalize_nf(tomo, flat, dark, floc, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + else: + tomopy.normalize(tomo, flat, dark, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + + if corFunction == 'vo': + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als'): + tomovo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath, filename), sino=(sinoused[0],sinoused[0]+1,1)) + elif (filetype == 'sls'): + tomovo, flat, dark, coranglelist = read_sls(os.path.join(inputPath, filename), exchange_rank=0, sino=(sinoused[0],sinoused[0]+1,1), proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2])) # dtype=None, , ) + else: + return + if bffilename is not None: + if (filetype == 'als'): + tomobf, flatbf, darkbf, flocbf = dxchange.read_als_832h5(os.path.join(inputPath, bffilename), sino=(sinoused[0],sinoused[0]+1,1)) + flat = tomobf + elif (filetype == 'als1131'): + tomobf, 
flatbf, darkbf, flocbf = read_als_1131h5(os.path.join(inputPath, bffilename), + sino=(sinoused[0], sinoused[0] + 1, 1)) + flat = tomobf + tomovo = tomovo.astype(np.float32) + + if useNormalize_nf and ((filetype == 'als') or (filetype == 'als1131')): + tomopy.normalize_nf(tomovo, flat, dark, floc, out=tomovo) + if bfexposureratio != 1: + tomovo = tomovo * bfexposureratio + else: + tomopy.normalize(tomovo, flat, dark, out=tomovo) + if bfexposureratio != 1: + tomovo = tomovo * bfexposureratio + + cor = tomopy.find_center_vo(tomovo, ind=voInd, smin=voSMin, smax=voSMax, srad=voSRad, step=voStep, + ratio=voRatio, drop=voDrop) + + + elif corFunction == 'nm': + cor = tomopy.find_center(tomo, tomopy.angles(numangles, angle_offset, angle_offset - angularrange), + ind=nmInd, init=nmInit, tol=nmTol, mask=nmMask, ratio=nmRatio, + sinogram_order=nmSinoOrder) + elif corFunction == 'pc': + if angularrange > 300: + lastcor = int(np.floor(numangles / 2) - 1) + else: + lastcor = numangles - 1 + # I don't want to see the warnings about the reader using a deprecated variable in dxchange + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als'): + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath, filename), ind_tomo=(0, lastcor)) + elif (filetype == 'als1131'): + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath, filename), ind_tomo=(0, lastcor)) + elif (filetype == 'dxfile'): + tomo, flat, dark, coranglelist = dxchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, proj=( + 0, lastcor, lastcor-1)) # dtype=None, , ) + elif (filetype == 'sls'): + tomo, flat, dark, coranglelist = read_sls(os.path.join(inputPath, filename), exchange_rank=0, proj=( + timepoint * numangles, (timepoint + 1) * numangles, numangles - 1)) # dtype=None, , ) + else: + return + if bffilename is not None: + if (filetype == 'als'): + tomobf, flatbf, darkbf, flocbf = dxchange.read_als_832h5(os.path.join(inputPath, 
bffilename)) + flat = tomobf + elif (filetype == 'als1131'): + tomobf, flatbf, darkbf, flocbf = read_als_1131h5(os.path.join(inputPath, bffilename)) + flat = tomobf + tomo = tomo.astype(np.float32) + if useNormalize_nf and ((filetype == 'als') or (filetype == 'als1131')): + tomopy.normalize_nf(tomo, flat, dark, floc, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + else: + tomopy.normalize(tomo, flat, dark, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + cor = tomopy.find_center_pc(tomo[0], tomo[-1], tol=0.25) + elif corFunction == 'skip': #use this to get back the tomo variable without running processing + cor = numrays/2 + else: + raise ValueError("\'corFunction\' must be one of: [ pc, vo, nm ].") + if verbose_printing: + print(", {}".format(cor)) + else: + tomo = 0 + if verbose_printing: + print("using user input center of {}".format(cor)) + + function_list = [] + + if writeraw: + function_list.append('write_raw') + if doOutliers1D: + function_list.append('remove_outlier1d') + if doOutliers2D: + function_list.append('remove_outlier2d') + if doNormalize: + if useNormalize_nf: + function_list.append('normalize_nf') + else: + function_list.append('normalize') + if dominuslog: + function_list.append('minus_log') + if doBeamHardening: + function_list.append('beam_hardening') + if doFWringremoval: + function_list.append('remove_stripe_fw') + if doTIringremoval: + function_list.append('remove_stripe_ti') + if doSFringremoval: + function_list.append('remove_stripe_sf') + if correcttilt: + function_list.append('correcttilt') + if dolensdistortion: + function_list.append('lensdistortion') + if use360to180: + function_list.append('do_360_to_180') + if doPhaseRetrieval: + function_list.append('phase_retrieval') + if dorecon: + function_list.append('recon_mask') + if doPolarRing: + if dorecon: + function_list.append('polar_ring') + if doPolarRing2: + if dorecon: + function_list.append('polar_ring2') + if castTo8bit: + if dorecon: 
+ function_list.append('castTo8bit') + if writereconstruction: + if dorecon: + function_list.append('write_reconstruction') + if writenormalized: + function_list.append('write_normalized') + + recon_dict = { + "inputPath": inputPath, #input file path + "filename": filename, #input file name + "filetype": filetype, #other options are als, als1131, sls + "timepoint": timepoint, + "fulloutputPath": fulloutputPath, + "outputFilename": outputFilename, + "bffilename": bffilename, #if there is a separate file with the bright fields + "doOutliers1D": doOutliers1D, # outlier removal in 1d (along sinogram columns) + "outlier_diff1D": outlier_diff1D, # difference between good data and outlier data (outlier removal) + "outlier_size1D": outlier_size1D, # radius around each pixel to look for outliers (outlier removal) + "doOutliers2D": doOutliers2D, # outlier removal, standard 2d on each projection + "outlier_diff2D": outlier_diff2D, # difference between good data and outlier data (outlier removal) + "outlier_size2D": outlier_size2D, # radius around each pixel to look for outliers (outlier removal) + "doFWringremoval": doFWringremoval, # Fourier-wavelet ring removal + "doTIringremoval": doTIringremoval, # Titarenko ring removal + "doSFringremoval": doSFringremoval, # Smoothing filter ring removal + "ringSigma": ringSigma, # damping parameter in Fourier space (Fourier-wavelet ring removal) + "ringLevel": ringLevel, # number of wavelet transform levels (Fourier-wavelet ring removal) + "ringWavelet": ringWavelet, # type of wavelet filter (Fourier-wavelet ring removal) + "ringNBlock": ringNBlock, # used in Titarenko ring removal (doTIringremoval) + "ringAlpha": ringAlpha, # used in Titarenko ring removal (doTIringremoval) + "ringSize": ringSize, # used in smoothing filter ring removal (doSFringremoval) + "doPhaseRetrieval": doPhaseRetrieval, # phase retrieval + "alphaReg": alphaReg, # smaller = smoother (used for phase retrieval) + "propagation_dist": propagation_dist, # 
sample-to-scintillator distance (phase retrieval) + "kev": kev, # energy level (phase retrieval) + "butterworth_cutoff": butterworth_cutoff, # 0.1 would be very smooth, 0.4 would be very grainy (reconstruction) + "butterworth_order": butterworth_order, # for reconstruction + "doTranslationCorrection": doTranslationCorrection, # correct for linear drift during scan + "xshift": xshift, # undesired dx transation correction (from 0 degree to 180 degree proj) + "yshift": yshift, # undesired dy transation correction (from 0 degree to 180 degree proj) + "doPolarRing": doPolarRing, # ring removal + "Rarc": Rarc, # min angle needed to be considered ring artifact (ring removal) + "Rmaxwidth": Rmaxwidth, # max width of rings to be filtered (ring removal) + "Rtmax": Rtmax, # max portion of image to filter (ring removal) + "Rthr": Rthr, # max value of offset due to ring artifact (ring removal) + "Rtmin": Rtmin, # min value of image to filter (ring removal) + "doPolarRing2": doPolarRing2, # ring removal + "Rarc2": Rarc2, # min angle needed to be considered ring artifact (ring removal) + "Rmaxwidth2": Rmaxwidth2, # max width of rings to be filtered (ring removal) + "Rtmax2": Rtmax2, # max portion of image to filter (ring removal) + "Rthr2": Rthr2, # max value of offset due to ring artifact (ring removal) + "Rtmin2": Rtmin2, # min value of image to filter (ring removal) + "cor": cor, # center of rotation (float). 
If not used then cor will be detected automatically + "corFunction": corFunction, # center of rotation function to use - can be 'pc', 'vo', or 'nm' + "voInd": voInd, # index of slice to use for cor search (vo) + "voSMin": voSMin, # min radius for searching in sinogram (vo) + "voSMax": voSMax, # max radius for searching in sinogram (vo) + "voSRad": voSRad, # search radius (vo) + "voStep": voStep, # search step (vo) + "voRatio": voRatio, # ratio of field-of-view and object size (vo) + "voDrop": voDrop, # drop lines around vertical center of mask (vo) + "nmInd": nmInd, # index of slice to use for cor search (nm) + "nmInit": nmInit, # initial guess for center (nm) + "nmTol": nmTol, # desired sub-pixel accuracy (nm) + "nmMask": nmMask, # if True, limits analysis to circular region (nm) + "nmRatio": nmRatio, # ratio of radius of circular mask to edge of reconstructed image (nm) + "nmSinoOrder": nmSinoOrder, # if True, analyzes in sinogram space. If False, analyzes in radiograph space + "use360to180": use360to180, # use 360 to 180 conversion + "castTo8bit": castTo8bit, # convert data to 8bit before writing + "cast8bit_min": cast8bit_min, # min value if converting to 8bit + "cast8bit_max": cast8bit_max, # max value if converting to 8bit + "useNormalize_nf": useNormalize_nf, # normalize based on background intensity (nf) + "chunk_proj": chunk_proj, # chunk size in projection direction + "chunk_sino": chunk_sino, # chunk size in sinogram direction + "npad": npad, # amount to pad data before reconstruction + "projused": projused, # should be slicing in projection dimension (start,end,step) Be sure to add one to the end as stop in python means the last value is omitted + "sinoused": sinoused, # should be sliceing in sinogram dimension (start,end,step). If first value is negative, it takes the number of slices from the second value in the middle of the stack. 
+ "correcttilt": correcttilt, # tilt dataset + "tiltcenter_slice": tiltcenter_slice, # tilt center (x direction) + "tiltcenter_det": tiltcenter_det, # tilt center (y direction) + "angle_offset": angle_offset, # this is the angle offset from our default (270) so that tomopy yields output in the same orientation as previous software (Octopus) + "anglelist": anglelist, # if not set, will assume evenly spaced angles which will be calculated by the angular range and number of angles found in the file. if set to -1, will read individual angles from each image. alternatively, a list of angles can be passed. + "doBeamHardening": doBeamHardening, # turn on beam hardening correction, based on "Correction for beam hardening in computed tomography", Gabor Herman, 1979 Phys. Med. Biol. 24 81 + "BeamHardeningCoefficients": BeamHardeningCoefficients, # 6 values, tomo = a0 + a1*tomo + a2*tomo^2 + a3*tomo^3 + a4*tomo^4 + a5*tomo^5 + "projIgnoreList": projIgnoreList, # projections to be ignored in the reconstruction (for simplicity in the code, they will not be removed and will be processed as all other projections but will be set to zero absorption right before reconstruction. 
+ "bfexposureratio": bfexposureratio, # ratio of exposure time of bf to exposure time of sample + "pxsize": pxsize, + "numslices": numslices, + "numangles": numangles, + "angularrange": angularrange, + "numrays": numrays, + "npad": npad, + "projused": projused, + "inter_bright": inter_bright, + "nflat": nflat, + "ind_flat": ind_flat, + "ndark": nflat, + "ind_dark": ind_flat, + "group_flat": group_flat, + "ind_tomo": ind_tomo, + "floc_independent": floc_independent, + "sinoused": sinoused, + "BeamHardeningCoefficients": BeamHardeningCoefficients, + "function_list": function_list, + "dorecon": dorecon, + "doNormalize": doNormalize, + "writeraw": writeraw, + "writenormalized": writenormalized, + "writereconstruction": writereconstruction, + "dominuslog": dominuslog, + "verbose_printing": verbose_printing, + "recon_algorithm": recon_algorithm, + "dolensdistortion": dolensdistortion, + "lensdistortioncenter": lensdistortioncenter, + "lensdistortionfactors": lensdistortionfactors, + "minimum_transmission": minimum_transmission, + } + + #return second variable tomo, (first and last normalized image), to use it for manual COR checking + return recon_dict, tomo + + + +# to profile memory, uncomment the following line +# and then run program from command line as +# python -m memory_profiler tomopy832.py +# (you have to have memory_profiler installed) +# @profile +def recon( + filename, + filetype = 'als', #other options are als, als1131, sls + timepoint = 0, + bffilename = None, + inputPath = './', #input path, location of the data set to reconstruct + outputFilename = None, #file name for output tif files (a number and .tiff will be added). 
default is based on input filename + fulloutputPath = None, # definte the full output path, no automatic sub-folder will be created + doOutliers1D = False, # outlier removal in 1d (along sinogram columns) + outlier_diff1D = 750, # difference between good data and outlier data (outlier removal) + outlier_size1D = 3, # radius around each pixel to look for outliers (outlier removal) + doOutliers2D = False, # outlier removal, standard 2d on each projection + outlier_diff2D = 750, # difference between good data and outlier data (outlier removal) + outlier_size2D = 3, # radius around each pixel to look for outliers (outlier removal) + doFWringremoval = True, # Fourier-wavelet ring removal + doTIringremoval = False, # Titarenko ring removal + doSFringremoval = False, # Smoothing filter ring removal + ringSigma = 3, # damping parameter in Fourier space (Fourier-wavelet ring removal) + ringLevel = 8, # number of wavelet transform levels (Fourier-wavelet ring removal) + ringWavelet = 'db5', # type of wavelet filter (Fourier-wavelet ring removal) + ringNBlock = 0, # used in Titarenko ring removal (doTIringremoval) + ringAlpha = 1.5, # used in Titarenko ring removal (doTIringremoval) + ringSize = 5, # used in smoothing filter ring removal (doSFringremoval) + doPhaseRetrieval = False, # phase retrieval + alphaReg = 0.00001, # smaller = smoother (used for phase retrieval) + propagation_dist = 75.0, # sample-to-scintillator distance (phase retrieval) + kev = 24.0, # energy level (phase retrieval) + butterworth_cutoff = 0.25, #0.1 would be very smooth, 0.4 would be very grainy (reconstruction) + butterworth_order = 2, # for reconstruction + doTranslationCorrection = False, # correct for linear drift during scan + xshift = 0, # undesired dx transation correction (from 0 degree to 180 degree proj) + yshift = 0, # undesired dy transation correction (from 0 degree to 180 degree proj) + doPolarRing = False, # ring removal + Rarc=30, # min angle needed to be considered ring artifact 
(ring removal) + Rmaxwidth=100, # max width of rings to be filtered (ring removal) + Rtmax=3000.0, # max portion of image to filter (ring removal) + Rthr=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin=-3000.0, # min value of image to filter (ring removal) + doPolarRing2 = False, # ring removal + Rarc2=30, # min angle needed to be considered ring artifact (ring removal) + Rmaxwidth2=100, # max width of rings to be filtered (ring removal) + Rtmax2=3000.0, # max portion of image to filter (ring removal) + Rthr2=3000.0, # max value of offset due to ring artifact (ring removal) + Rtmin2=-3000.0, # min value of image to filter (ring removal) + cor=None, # center of rotation (float). If not used then cor will be detected automatically + corFunction = 'pc', # center of rotation function to use - can be 'pc', 'vo', or 'nm' + voInd = None, # index of slice to use for cor search (vo) + voSMin = -40, # min radius for searching in sinogram (vo) + voSMax = 40, # max radius for searching in sinogram (vo) + voSRad = 10, # search radius (vo) + voStep = 0.5, # search step (vo) + voRatio = 2.0, # ratio of field-of-view and object size (vo) + voDrop = 20, # drop lines around vertical center of mask (vo) + nmInd = None, # index of slice to use for cor search (nm) + nmInit = None, # initial guess for center (nm) + nmTol = 0.5, # desired sub-pixel accuracy (nm) + nmMask = True, # if True, limits analysis to circular region (nm) + nmRatio = 1.0, # ratio of radius of circular mask to edge of reconstructed image (nm) + nmSinoOrder = False, # if True, analyzes in sinogram space. 
If False, analyzes in radiograph space + use360to180 = False, # use 360 to 180 conversion + castTo8bit = False, # convert data to 8bit before writing + cast8bit_min=-10, # min value if converting to 8bit + cast8bit_max=30, # max value if converting to 8bit + useNormalize_nf = False, # normalize based on background intensity (nf) + chunk_proj = 100, # chunk size in projection direction + chunk_sino = 100, # chunk size in sinogram direction + npad = None, # amount to pad data before reconstruction + projused = None, # should be slicing in projection dimension (start,end,step) Be sure to add one to the end as stop in python means the last value is omitted + sinoused = None, # should be sliceing in sinogram dimension (start,end,step). If first value is negative, it takes the number of slices from the second value in the middle of the stack. + correcttilt = 0, # tilt dataset + tiltcenter_slice = None, # tilt center (x direction) + tiltcenter_det = None, # tilt center (y direction) + angle_offset = 0, # this is the angle offset from our default (270) so that tomopy yields output in the same orientation as previous software (Octopus) + anglelist = None, # if not set, will assume evenly spaced angles which will be calculated by the angular range and number of angles found in the file. if set to -1, will read individual angles from each image. alternatively, a list of angles can be passed. + doBeamHardening = False, # turn on beam hardening correction, based on "Correction for beam hardening in computed tomography", Gabor Herman, 1979 Phys. Med. Biol. 24 81 + BeamHardeningCoefficients = (0, 1, 0, 0, 0, .1), # 6 values, tomo = a0 + a1*tomo + a2*tomo^2 + a3*tomo^3 + a4*tomo^4 + a5*tomo^5 + projIgnoreList = None, # projections to be ignored in the reconstruction (for simplicity in the code, they will not be removed and will be processed as all other projections but will be set to zero absorption right before reconstruction. 
+ bfexposureratio = 1, #ratio of exposure time of bf to exposure time of sample + pxsize = .001, + numslices= 100, + numangles= 3, + angularrange= 180, + numrays= 2560, + inter_bright= 0, + nflat= 15, + ind_flat=1, + group_flat= None, + ndrk=10, + ind_dark=1, + ind_tomo= [0,1,2], + floc_independent= 1, + function_list= ['normalize','minus_log','recon_mask','write_output'], + dorecon=True, + doNormalize=True, + writeraw=False, + writenormalized=False, + writereconstruction=True, + dominuslog=True, + verbose_printing=False, + recon_algorithm='gridrec', #choose from gridrec, fbp, and others in tomopy + dolensdistortion=False, + lensdistortioncenter = (1280,1080), + lensdistortionfactors = (1.00015076, 1.9289e-06, -2.4325e-08, 1.00439e-11, -3.99352e-15), + minimum_transmission = 0.01, + *args, **kwargs + ): + + start_time = time.time() + if verbose_printing: + print("Start {} at:".format(filename)+time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.localtime())) + + filenametowrite = os.path.join(fulloutputPath,outputFilename) + if verbose_printing: + print("Time point: {}".format(timepoint)) + + tempfilenames = [os.path.join(fulloutputPath,'tmp0.h5'),os.path.join(fulloutputPath,'tmp1.h5')] + if verbose_printing: + print("cleaning up previous temp files") #, end="") + for tmpfile in tempfilenames: + try: + os.remove(tmpfile) + except OSError: + pass + + numprojused = len(range(projused[0], projused[1],projused[2])) # number of total projections. We add 1 to include the last projection + numsinoused = len(range(sinoused[0], sinoused[1],sinoused[2])) # number of total sinograms. We add 1 to include the last projection + num_proj_per_chunk = np.minimum(chunk_proj,numprojused) # sets the chunk size to either all of the projections used or the chunk size + numprojchunks = (numprojused - 1) // num_proj_per_chunk + 1 # adding 1 fixes the case of the number of projections not being a factor of the chunk size. 
Subtracting 1 fixes the edge case where the number of projections is a multiple of the chunk size + num_sino_per_chunk = np.minimum(chunk_sino, numsinoused) # same as num_proj_per_chunk + numsinochunks = (numsinoused - 1) // num_sino_per_chunk + 1 # adding 1 fixes the case of the number of sinograms not being a factor of the chunk size. Subtracting 1 fixes the edge case where the number of sinograms is a multiple of the chunk size + + # Ensure the output directory exists + Path(fulloutputPath).mkdir(parents=True, exist_ok=True) + set_directory_permissions(fulloutputPath) # Set permissions for the output directory + + # Figure out first direction to slice + for func in function_list: + if slice_dir[func] != 'both': + axis = slice_dir[func] + break + else: + axis = 'sino' + + done = False + curfunc = 0 + curtemp = 0 + + if not dorecon: + rec = 0 + + while True: # Loop over reading data in certain chunking direction + if axis=='proj': + niter = numprojchunks + else: + niter = numsinochunks + for y in range(niter): # Loop over chunks + if verbose_printing: + print("{} chunk {} of {}".format(axis, y+1, niter)) + # The standard case. Unless the combinations below are in our function list, we read darks and flats normally, and on next chunck proceed to "else." 
+ if curfunc == 0 and not (('normalize_nf' in function_list and 'remove_outlier2d' in function_list) or ('remove_outlier1d' in function_list and 'remove_outlier2d' in function_list)): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if axis=='proj': + if (filetype=='als'): + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1],sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = dxchange.read_als_832h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1],sinoused[2])) #I don't think we need this for separate bf: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), + flat = tomobf + elif (filetype=='als1131'): + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1],sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = read_als_1131h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1],sinoused[2])) #I don't think we need this for separate bf: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), + flat = tomobf + elif (filetype == 'dxfile'): + tomo, flat, dark, _= dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, + proj=( y * projused[2] * num_proj_per_chunk + projused[0], + + np.minimum((y + 1) * projused[2] * num_proj_per_chunk + projused[0], projused[1]), projused[2]), + sino=sinoused) # dtype=None, , ) + elif (filetype=='sls'): + tomo, flat, dark, _ = 
read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+y*projused[2]*num_proj_per_chunk+projused[0],timepoint*numangles+np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), sino=sinoused) #dtype=None, , ) + else: + break + else: + if (filetype == 'als'): + tomo, flat, dark, floc = dxchange.read_als_832h5(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = dxchange.read_als_832h5(os.path.join(inputPath, bffilename),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) # I don't think we need this for separate bf: ind_tomo=range(projused[0],projused[1],projused[2]), + flat = tomobf + elif (filetype == 'als1131'): + tomo, flat, dark, floc = read_als_1131h5(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) + if bffilename is not None: + tomobf, _, _, _ = read_als_1131h5(os.path.join(inputPath, bffilename),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) # I don't think we need this for separate bf: ind_tomo=range(projused[0],projused[1],projused[2]), + flat = tomobf + elif (filetype == 'dxfile'): + tomo, flat, dark, _ = dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, + proj=( projused[0], + projused[1], projused[2]), + sino=(y * sinoused[2] * num_sino_per_chunk + sinoused[0], + np.minimum( + (y + 1) * sinoused[2] * num_sino_per_chunk + + sinoused[0], sinoused[1]), + sinoused[2])) # dtype=None, , 
) + elif (filetype=='sls'): + tomo, flat, dark, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2]), sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) #dtype=None, , ) + else: + break + # Handles the initial reading of scans. Flats and darks are not read in, because the chunking direction will swap before we normalize. We read in darks when we normalize. + elif curfunc == 0: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if axis=='proj': + if (filetype == 'als') or (filetype == 'als1131'): + tomo = read_als_h5_tomo_only(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1], sinoused[2]), bl=filetype) + elif (filetype == 'dxfile'): + tomo, _, _, _, _ = read_sls(os.path.join(inputPath, filename), exchange_rank=0, proj=( + y * projused[2] * num_proj_per_chunk + projused[0], + np.minimum((y + 1) * projused[2] * num_proj_per_chunk + projused[0], + projused[1]), projused[2]), + sino=sinoused) # dtype=None, , ) + elif (filetype=='sls'): + tomo, _, _, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+y*projused[2]*num_proj_per_chunk+projused[0],timepoint*numangles+np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), sino=sinoused) #dtype=None, , ) + else: + break + else: + if (filetype == 'als') or (filetype == 'als1131'): + tomo = read_als_h5_tomo_only(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2]),bl=filetype) + elif (filetype == 'dxfile'): + tomo, _, _, _ = 
dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, proj=( + projused[0], projused[1], projused[2]), + sino=(y * sinoused[2] * num_sino_per_chunk + sinoused[0], + np.minimum( + (y + 1) * sinoused[2] * num_sino_per_chunk + sinoused[0], + sinoused[1]), sinoused[2])) # dtype=None, , ) + elif (filetype=='sls'): + tomo, _, _, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2]), sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) #dtype=None, , ) + else: + break + # Handles the reading of darks and flats, once we know the chunking direction will not change before normalizing. + elif ('remove_outlier2d' == function_list[curfunc] and 'normalize' in function_list) or 'normalize_nf' == function_list[curfunc]: + if axis == 'proj': + start, end = y * num_proj_per_chunk, np.minimum((y + 1) * num_proj_per_chunk,numprojused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((start,end,1),(0,numslices,1),(0,numrays,1))) #read in intermediate file + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als') or (filetype == 'als1131'): + flat, dark, floc = read_als_h5_non_tomo(os.path.join(inputPath,filename),ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]),sino=(sinoused[0],sinoused[1], sinoused[2]),bl=filetype) + if bffilename is not None: + if filetype == 'als': + tomobf, _, _, _ = dxchange.read_als_832h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1], sinoused[2])) #I don't think we need this since it is full tomo in separate file: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]) + flat = 
tomobf + elif filetype == 'als1131': + tomobf, _, _, _ = read_als_1131h5(os.path.join(inputPath,bffilename),sino=(sinoused[0],sinoused[1], sinoused[2])) #I don't think we need this since it is full tomo in separate file: ind_tomo=range(y*projused[2]*num_proj_per_chunk+projused[0], np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]) + flat = tomobf + elif (filetype == 'dxfile'): + _, flat, dark, _ = dxchange.exchange.read_aps_tomoscan_hdf5(os.path.join(inputPath, filename), exchange_rank=0, proj=( + y * projused[2] * num_proj_per_chunk + projused[0], + np.minimum( + (y + 1) * projused[2] * num_proj_per_chunk + projused[0], projused[1]), + projused[2]), sino=sinoused) # dtype=None, , ) + elif (filetype=='sls'): + _, flat, dark, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+y*projused[2]*num_proj_per_chunk+projused[0],timepoint*numangles+np.minimum((y + 1)*projused[2]*num_proj_per_chunk+projused[0],projused[1]),projused[2]), sino=sinoused) #dtype=None, , ) + else: + break + else: + start, end = y * num_sino_per_chunk, np.minimum((y + 1) * num_sino_per_chunk,numsinoused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((0,numangles,1),(start,end,1),(0,numrays,1))) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if (filetype == 'als') or (filetype == 'als1131'): + flat, dark, floc = read_als_h5_non_tomo(os.path.join(inputPath,filename),ind_tomo=range(projused[0],projused[1],projused[2]),sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2]),bl=filetype) + elif (filetype == 'dxfile'): + _, flat, dark, _, _ = read_sls(os.path.join(inputPath, filename), exchange_rank=0, proj=( + projused[0], projused[1], projused[2]), + sino=(y * sinoused[2] * num_sino_per_chunk + sinoused[0], + np.minimum( + (y + 1) * sinoused[2] * num_sino_per_chunk + sinoused[ + 0], sinoused[1]), 
sinoused[2])) # dtype=None, , ) + elif (filetype=='sls'): + _, flat, dark, _ = read_sls(os.path.join(inputPath,filename), exchange_rank=0, proj=(timepoint*numangles+projused[0],timepoint*numangles+projused[1],projused[2]), sino=(y*sinoused[2]*num_sino_per_chunk+sinoused[0],np.minimum((y + 1)*sinoused[2]*num_sino_per_chunk+sinoused[0],sinoused[1]),sinoused[2])) #dtype=None, , ) + else: + break + # Anything after darks and flats have been read or the case in which remove_outlier2d is the current/2nd function and the previous case fails. + else: + if axis=='proj': + start, end = y * num_proj_per_chunk, np.minimum((y + 1) * num_proj_per_chunk,numprojused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((start,end,1),(0,numslices,1),(0,numrays,1))) #read in intermediate file + else: + start, end = y * num_sino_per_chunk, np.minimum((y + 1) * num_sino_per_chunk,numsinoused) + tomo = dxchange.reader.read_hdf5(tempfilenames[curtemp],'/tmp/tmp',slc=((0,numangles,1),(start,end,1),(0,numrays,1))) + dofunc = curfunc + keepvalues = None + while True: # Loop over operations to do in current chunking direction + func_name = function_list[dofunc] + newaxis = slice_dir[func_name] + if newaxis != 'both' and newaxis != axis: + # We have to switch axis, so flush to disk + if y==0: + try: + os.remove(tempfilenames[1-curtemp]) + except OSError: + pass + appendaxis = 1 if axis=='sino' else 0 + dxchange.writer.write_hdf5(tomo,fname=tempfilenames[1-curtemp],gname='tmp',dname='tmp',overwrite=False,appendaxis=appendaxis) #writing intermediate file... 
+ break + if verbose_printing: + print(func_name, end=" ") + curtime = time.time() + if func_name == 'write_raw': + dxchange.write_tiff_stack(tomo, fname=filenametowrite,start=y * num_proj_per_chunk + projused[0]) + if y == 0: + dxchange.write_tiff_stack(flat, fname=filenametowrite+'bak',start=0) + dxchange.write_tiff_stack(dark, fname=filenametowrite + 'drk', start=0) + elif func_name == 'remove_outlier1d': + tomo = tomo.astype(np.float32,copy=False) + remove_outlier1d(tomo, outlier_diff1D, size=outlier_size1D, out=tomo) + elif func_name == 'remove_outlier2d': + tomo = tomo.astype(np.float32,copy=False) + tomopy.remove_outlier(tomo, outlier_diff2D, size=outlier_size2D, axis=0, out=tomo) + elif func_name == 'normalize_nf': + tomo = tomo.astype(np.float32,copy=False) + tomopy.normalize_nf(tomo, flat, dark, floc_independent, out=tomo) #use floc_independent b/c when you read file in proj chunks, you don't get the correct floc returned right now to use here. + if bfexposureratio != 1: + if verbose_printing: + print("correcting bfexposureratio") + tomo = tomo * bfexposureratio + elif func_name == 'normalize': + tomo = tomo.astype(np.float32,copy=False) + tomopy.normalize(tomo, flat, dark, out=tomo) + if bfexposureratio != 1: + tomo = tomo * bfexposureratio + if verbose_printing: + print("correcting bfexposureratio") + elif func_name == 'minus_log': + mx = np.float32(minimum_transmission) #setting min %transmission to 1% helps avoid streaking from very high absorbing areas + ne.evaluate('where(tomo>mx, tomo, mx)', out=tomo) + tomopy.minus_log(tomo, out=tomo) + elif func_name == 'beam_hardening': + loc_dict = {'a{}'.format(i):np.float32(val) for i,val in enumerate(BeamHardeningCoefficients)} + loc_dict['tomo'] = tomo + tomo = ne.evaluate('a0 + a1*tomo + a2*tomo**2 + a3*tomo**3 + a4*tomo**4 + a5*tomo**5', local_dict=loc_dict, out=tomo) + elif func_name == 'remove_stripe_fw': + tomo = tomopy.remove_stripe_fw(tomo, sigma=ringSigma, level=ringLevel, pad=True, 
wname=ringWavelet) + elif func_name == 'remove_stripe_ti': + tomo = tomopy.remove_stripe_ti(tomo, nblock=ringNBlock, alpha=ringAlpha) + elif func_name == 'remove_stripe_sf': + tomo = tomopy.remove_stripe_sf(tomo, size=ringSize) + elif func_name == 'correcttilt': + if tiltcenter_slice is None: + tiltcenter_slice = numslices/2. + if tiltcenter_det is None: + tiltcenter_det = tomo.shape[2]/2 + new_center = tiltcenter_slice - 0.5 - sinoused[0] + center_det = tiltcenter_det - 0.5 + + # add padding of 10 pixels, to be unpadded right after tilt correction. + # This makes the tilted image not have zeros at certain edges, + # which matters in cases where sample is bigger than the field of view. + # For the small amounts we are generally tilting the images, 10 pixels is sufficient. + # tomo = tomopy.pad(tomo, 2, npad=10, mode='edge') + # center_det = center_det + 10 + + cntr = (center_det, new_center) + for b in range(tomo.shape[0]): + tomo[b] = st.rotate(tomo[b], correcttilt, center=cntr, preserve_range=True, order=1, mode='edge', clip=True) # center=None means image is rotated around its center; order=1 is default, order of spline interpolation +# tomo = tomo[:, :, 10:-10] + elif func_name == 'lensdistortion': + if verbose_printing: + print(lensdistortioncenter[0]) + print(lensdistortioncenter[1]) + print(lensdistortionfactors[0]) + print(lensdistortionfactors[1]) + print(lensdistortionfactors[2]) + print(lensdistortionfactors[3]) + print(lensdistortionfactors[4]) + print(type(lensdistortionfactors[0])) + print(type(lensdistortionfactors)) + tomo = tomopy.prep.alignment.distortion_correction_proj(tomo, lensdistortioncenter[0], lensdistortioncenter[1], lensdistortionfactors, ncore=None,nchunk=None) + elif func_name == 'do_360_to_180': + + # Keep values around for processing the next chunk in the list + keepvalues = [angularrange, numangles, projused, num_proj_per_chunk, numprojchunks, numprojused, numrays, anglelist] + + # why -.5 on one and not on the other? 
+ if tomo.shape[0]%2>0: + tomo = sino_360_to_180(tomo[0:-1,:,:], overlap=int(np.round((tomo.shape[2]-cor-.5))*2), rotation='right') + angularrange = angularrange/2 - angularrange/(tomo.shape[0]-1) + else: + tomo = sino_360_to_180(tomo[:,:,:], overlap=int(np.round((tomo.shape[2]-cor))*2), rotation='right') + angularrange = angularrange/2 + numangles = int(numangles/2) + projused = (0,numangles-1,1) + numprojused = len(range(projused[0],projused[1],projused[2])) + num_proj_per_chunk = np.minimum(chunk_proj,numprojused) + numprojchunks = (numprojused-1)//num_proj_per_chunk+1 + numrays = tomo.shape[2] + + anglelist = anglelist[:numangles] + + elif func_name == 'phase_retrieval': + tomo = tomopy.retrieve_phase(tomo, pixel_size=pxsize, dist=propagation_dist, energy=kev, alpha=alphaReg, pad=True) + + elif func_name == 'translation_correction': + tomo = linear_translation_correction(tomo,dx=xshift,dy=yshift,interpolation=False) + + elif func_name == 'recon_mask': + tomo = tomopy.pad(tomo, 2, npad=npad, mode='edge') + + if projIgnoreList is not None: + for badproj in projIgnoreList: + tomo[badproj] = 0 + rec = tomopy.recon(tomo, anglelist, center=cor+npad, algorithm=recon_algorithm, filter_name='butterworth', filter_par=[butterworth_cutoff, butterworth_order], ncore=64) + rec = rec[:, npad:-npad, npad:-npad] + rec /= pxsize # convert reconstructed voxel values from 1/pixel to 1/cm + rec = tomopy.circ_mask(rec, 0) + tomo = tomo[:, :, npad:-npad] + elif func_name == 'polar_ring': + rec = np.ascontiguousarray(rec, dtype=np.float32) + rec = tomopy.remove_ring(rec, theta_min=Rarc, rwidth=Rmaxwidth, thresh_max=Rtmax, thresh=Rthr, thresh_min=Rtmin,out=rec) + elif func_name == 'polar_ring2': + rec = np.ascontiguousarray(rec, dtype=np.float32) + rec = tomopy.remove_ring(rec, theta_min=Rarc2, rwidth=Rmaxwidth2, thresh_max=Rtmax2, thresh=Rthr2, thresh_min=Rtmin2,out=rec) + elif func_name == 'castTo8bit': + rec = convert8bit(rec, cast8bit_min, cast8bit_max) + elif func_name == 
def sino_360_to_180(data, overlap=0, rotation='left'):
    """
    Convert a 0-360 degree sinogram stack into a 0-180 degree one.

    The two halves of the rotation are stitched side by side; within the
    overlapping band the halves are blended with linearly ramped weights.

    Parameters
    ----------
    data : ndarray
        Input 3D data (projections, slices, rays).
    overlap : scalar, optional
        Number of overlapping pixels between the two halves.
    rotation : {'left', 'right'}, optional
        'left' if the rotation center is close to the left of the
        field of view, 'right' otherwise.

    Returns
    -------
    ndarray
        Output 3D data with half the projections and widened rays.
    """
    n_proj, n_slice, n_ray = data.shape
    half = n_proj // 2
    first = data[:half]
    second = data[half:2 * half]
    out = np.zeros((half, n_slice, 2 * n_ray - overlap), dtype=data.dtype)

    if rotation == 'left':
        # Ramp rising across the overlap band.
        ramp = (np.arange(overlap) + 0.5) / overlap
        out[:, :, -n_ray + overlap:] = first[:, :, overlap:]
        out[:, :, :n_ray - overlap] = second[:, :, overlap:][:, :, ::-1]
        blend = ramp * first[:, :, :overlap] \
            + (ramp * second[:, :, :overlap])[:, :, ::-1]
        out[:, :, n_ray - overlap:n_ray] = blend
    elif rotation == 'right':
        # Ramp falling across the overlap band.
        ramp = (np.arange(overlap)[::-1] + 0.5) / overlap
        out[:, :, :n_ray - overlap] = first[:, :, :-overlap]
        out[:, :, -n_ray + overlap:] = second[:, :, :-overlap][:, :, ::-1]
        blend = ramp * first[:, :, -overlap:] \
            + (ramp * second[:, :, -overlap:])[:, :, ::-1]
        out[:, :, n_ray - overlap:n_ray] = blend
    return out
+ """ + arr = arr.astype(np.float32,copy=False) + dif = np.float32(dif) + + tmp = np.empty_like(arr) + + other_axes = [i for i in range(arr.ndim) if i != axis] + largest = np.argmax([arr.shape[i] for i in other_axes]) + lar_axis = other_axes[largest] + ncore, chnk_slices = mproc.get_ncore_slices(arr.shape[lar_axis],ncore=ncore) + filt_size = [1]*arr.ndim + filt_size[axis] = size + + with cf.ThreadPoolExecutor(ncore) as e: + slc = [slice(None)]*arr.ndim + for i in range(ncore): + slc[lar_axis] = chnk_slices[i] + e.submit(snf.median_filter, arr[slc], size=filt_size,output=tmp[slc], mode='mirror') + + with mproc.set_numexpr_threads(ncore): + out = ne.evaluate('where(abs(arr-tmp)>=dif,tmp,arr)', out=out) + + return out + + +def translate(data,dx=0,dy=0,interpolation=True): + """ + Shifts all projections in an image stack by dx (horizontal) and dy (vertical) pixels. Translation with subpixel resolution is possible with interpolation==True + + Parameters + ---------- + data: ndarray + Input array, stack of 2D (x,y) images, angle in z + + dx: int or float + desored horizontal pixel shift + + dy: int or float + desired vertical pixel shift + + interpolation: boolean + True calls funtion from sckimage to interpolate image when subpixel shifts are applied + + Returns + ------- + ndarray + Corrected array. 
+ """ + + Nproj, Nrow, Ncol = data.shape + dataOut = np.zeros(data.shape) + + if interpolation == True: + #translateFunction = st.SimilarityTransform(translation=(-dx,dy)) + M=np.matrix([[1,0,-dx],[0,1,dy],[0,0,1]]) + translateFunction = st.SimilarityTransform(matrix=M) + for n in range(Nproj): + dataOut[n,:,:] = st.warp(data[n,:,:], translateFunction) + + if interpolation == False: + Npad = max(dx,dy) + drow = int(-dy) # negative matrix row increments = dy + dcol = int(dx) # matrix column increments = dx + for n in range(Nproj): + PaddedImage = np.pad(data[n,:,:],Npad,'constant') + dataOut[n,:,:] = PaddedImage[Npad-drow:Npad+Nrow-drow,Npad-dcol:Npad+Ncol-dcol] # shift image by dx and dy, replace original without padding + + return dataOut + + +def linear_translation_correction(data,dx=100.5,dy=700.1,interpolation=True): + + """ + Corrects for a linear drift in field of view (horizontal dx, vertical dy) over time. The first index indicaties time data[time,:,:] in the time series of projections. dx and dy are the final shifts in FOV position. + + Parameters + ---------- + data: ndarray + Input array, stack of 2D (x,y) images, angle in z + + dx: int or float + total horizontal pixel offset from first (0 deg) to last (180 deg) projection + + dy: int or float + total horizontal pixel offset from first (0 deg) to last (180 deg) projection + + interpolation: boolean + True calls funtion from sckimage to interpolate image when subpixel shifts are applied + + Returns + ------- + ndarray + Corrected array. + """ + + Nproj, Nrow, Ncol = data.shape + Nproj=10 + + dataOut = np.zeros(data0.shape) + + dx_n = np.linspace(0,dx,Nproj) # generate array dx[n] of pixel shift for projection n = 0, 1, ... Nproj + + dy_n = np.linspace(0,dy,Nproj) # generate array dy[n] of pixel shift for projection n = 0, 1, ... 
Nproj + + if interpolation==True: + for n in range(Nproj): + #translateFunction = st.SimilarityTransform(translation=(-dx_n[n],dy_n[n])) # Generate Translation Function based on dy[n] and dx[n] + M=np.matrix([[1,0,-dx_n[n]],[0,1,dy_n[n]],[0,0,1]]) + translateFunction = st.SimilarityTransform(matrix=M) + image_n = data[n,:,:] + dataOut[n,:,:] = st.warp(image_n, translateFunction) # Apply translation with interpolation to projection[n] + #print(n) + + if interpolation==False: + Npad = max(dx,dy) + for n in range(Nproj): + PaddedImage = np.pad(data[n,:,:],Npad,'constant') # copy single projection and pad with maximum of dx,dy + drow = int(round(-dy_n[n])) # round shift to nearest pixel step, negative matrix row increments = dy + dcol = int(round(dx_n[n])) # round shift to nearest pixel step, matrix column increments = dx + dataOut[n,:,:] = PaddedImage[Npad-drow:Npad+Nrow-drow,Npad-dcol:Npad+Ncol-dcol] # shift image by dx and dy, replace original without padding + #print(n) + + return dataOut + + + """ + Parameters + ---------- + data: ndarray + Input array, stack of 2D (x,y) images, angle in z + pixelshift: float + total pixel offset from first (0 deg) to last (180 deg) projection + + Returns + ------- + ndarray + Corrected array. 
+ """ + + +"""Hi Dula, +This is roughly what I am doing in the script to 'unspiral' the superweave data: +spd = float(int(sys.argv[2])/2048) +x = np.zeros((2049,200,2560), dtype=np.float32) +blks = np.round(np.linspace(0,2049,21)).astype(np.int) +for i in range(0,20): + dat = dxchange.read_als_832h5(fn, ind_tomo=range(blks[i],blks[i+1])) + prj = tomopy.normalize_nf(dat[0],dat[1],dat[2],dat[3]) + for ik,j in enumerate(range(blks[i],blks[i+1])): + l = prj.shape[1]//2-j*spd + li = int(l) + ri = li+200 + fc = l-li + x[j] = (1-fc)*prj[ik,li:ri] + x[j] += fc*prj[ik,li+1:ri+1] +dxchange.writer.write_hdf5(x, fname=fn[:-3]+'_unspiral.h5', overwrite=True, gname='tmp', dname='tmp', appendaxis=0) + +This processes the (roughly) central 200 slices, and saves it to a new file. The vertical speed is one of the input arguments, and I simply estimate it manually by looking at the first and last projection, shifting them by 'np.roll'. The input argument is the total number of pixels that are shifted over the whole scan (which is then converted to pixels per projection by dividing by the number of projections-1). +I don't really remember why I wrote my own code, but maybe I was running into problems using scikit-image as well. The current code uses linear interpolation, and gives pretty good results for the data I tested. 
+ +Best, + +Daniel""" + + +def convertthetype(val): + constructors = [int, float, str] + for c in constructors: + try: + return c(val) + except ValueError: + pass + +############################################################################################### +# New Readers, so we don't have to read in darks and flats until they're needed +############################################################################################### +# Tomo +############################################################################################### + +def read_als_h5_tomo_only(fname, ind_tomo=None, ind_flat=None, ind_dark=None, + proj=None, sino=None): + """ + Read ALS 8.3.2 hdf5 file with stacked datasets. + + Parameters + ---------- + See docs for read_als_832h5 + """ + + with dxchange.reader.find_dataset_group(fname) as dgroup: + dname = dgroup.name.split('/')[-1] + + tomo_name = dname + '_0000_0000.tif' + + # Read metadata from dataset group attributes + keys = list(dgroup.attrs.keys()) + if 'nangles' in keys: + nproj = int(dgroup.attrs['nangles']) + + # Create arrays of indices to read projections + if ind_tomo is None: + ind_tomo = list(range(0, nproj)) + if proj is not None: + ind_tomo = ind_tomo[slice(*proj)] + + tomo = dxchange.reader.read_hdf5_stack( + dgroup, tomo_name, ind_tomo, slc=(None, sino)) + + return tomo + + +##################################################################################### +# Non tomo +##################################################################################### + +def read_als_h5_non_tomo(fname, ind_tomo=None, ind_flat=None, ind_dark=None, + proj=None, sino=None, whichbeamline='als'): + """ + Read ALS 8.3.2 hdf5 file with stacked datasets. 
+ + Parameters + ---------- + See docs for read_als_832h5 + """ + + with dxchange.reader.find_dataset_group(fname) as dgroup: + dname = dgroup.name.split('/')[-1] + + flat_name = dname + 'bak_0000.tif' + dark_name = dname + 'drk_0000.tif' + + # Read metadata from dataset group attributes + keys = list(dgroup.attrs.keys()) + if 'nangles' in keys: + nproj = int(dgroup.attrs['nangles']) + if 'i0cycle' in keys: + inter_bright = int(dgroup.attrs['i0cycle']) + if 'num_bright_field' in keys: + nflat = int(dgroup.attrs['num_bright_field']) + else: + nflat = dxchange.reader._count_proj(dgroup, flat_name, nproj, + inter_bright=inter_bright) + if 'num_dark_fields' in keys: + ndark = int(dgroup.attrs['num_dark_fields']) + else: + ndark = dxchange.reader._count_proj(dgroup, dark_name, nproj) + + # Create arrays of indices to read projections, flats and darks + if ind_tomo is None: + ind_tomo = list(range(0, nproj)) + if proj is not None: + ind_tomo = ind_tomo[slice(*proj)] + ind_dark = list(range(0, ndark)) + group_dark = [nproj - 1] + ind_flat = list(range(0, nflat)) + + if inter_bright > 0: + group_flat = list(range(0, nproj, inter_bright)) + if group_flat[-1] != nproj - 1: + group_flat.append(nproj - 1) + elif inter_bright == 0: + if whichbeamline == 'als1131': + group_flat = [nproj - 1] + else: + group_flat = [0, nproj - 1] + else: + group_flat = None + + flat = dxchange.reader.read_hdf5_stack( + dgroup, flat_name, ind_flat, slc=(None, sino), out_ind=group_flat) + + dark = dxchange.reader.read_hdf5_stack( + dgroup, dark_name, ind_dark, slc=(None, sino), out_ind=group_dark) + + return flat, dark, dxchange.reader._map_loc(ind_tomo, group_flat) + + + +###################################################################################################### + +def read_als_1131h5(fname, ind_tomo=None, ind_flat=None, ind_dark=None, + proj=None, sino=None): + """ + Read ALS 11.3.1 hdf5 file with stacked datasets. + + Parameters + ---------- + + fname : str + Path to hdf5 file. 
+ + ind_tomo : list of int, optional + Indices of the projection files to read. + + ind_flat : list of int, optional + Indices of the flat field files to read. + + ind_dark : list of int, optional + Indices of the dark field files to read. + + proj : {sequence, int}, optional + Specify projections to read. (start, end, step) + + sino : {sequence, int}, optional + Specify sinograms to read. (start, end, step) + + Returns + ------- + ndarray + 3D tomographic data. + + ndarray + 3D flat field data. + + ndarray + 3D dark field data. + + list of int + Indices of flat field data within tomography projection list + """ + + with dxchange.reader.find_dataset_group(fname) as dgroup: + dname = dgroup.name.split('/')[-1] + + tomo_name = dname + '_0000_0000.tif' + flat_name = dname + 'bak_0000.tif' + dark_name = dname + 'drk_0000.tif' + + # Read metadata from dataset group attributes + keys = list(dgroup.attrs.keys()) + if 'nangles' in keys: + nproj = int(dgroup.attrs['nangles']) + if 'i0cycle' in keys: + inter_bright = int(dgroup.attrs['i0cycle']) + if 'num_bright_field' in keys: + nflat = int(dgroup.attrs['num_bright_field']) + else: + nflat = dxchange.reader._count_proj(dgroup, flat_name, nproj, + inter_bright=inter_bright) + if 'num_dark_fields' in keys: + ndark = int(dgroup.attrs['num_dark_fields']) + else: + ndark = dxchange.reader._count_proj(dgroup, dark_name, nproj) + + # Create arrays of indices to read projections, flats and darks + if ind_tomo is None: + ind_tomo = list(range(0, nproj)) + if proj is not None: + ind_tomo = ind_tomo[slice(*proj)] + ind_dark = list(range(0, ndark)) + group_dark = [nproj - 1] + ind_flat = list(range(0, nflat)) + + if inter_bright > 0: + group_flat = list(range(0, nproj, inter_bright)) + if group_flat[-1] != nproj - 1: + group_flat.append(nproj - 1) + elif inter_bright == 0: + #group_flat = [0, nproj - 1] + group_flat = [nproj - 1] + else: + group_flat = None + + tomo = dxchange.reader.read_hdf5_stack( + dgroup, tomo_name, ind_tomo, 
slc=(None, sino)) + + flat = dxchange.reader.read_hdf5_stack( + dgroup, flat_name, ind_flat, slc=(None, sino), out_ind=group_flat) + + dark = dxchange.reader.read_hdf5_stack( + dgroup, dark_name, ind_dark, slc=(None, sino), out_ind=group_dark) + +# return tomo, flat, dark, dxchange.reader._map_loc(ind_tomo, group_flat) + return tomo, flat, dark, 0 + +###################################################################################################### + + +def read_sls(fname, exchange_rank=0, proj=None, sino=None, dtype=None): + """ + Read sls time resolved data format. + + Parameters + ---------- + fname : str + Path to hdf5 file. + + exchange_rank : int, optional + exchange_rank is added to "exchange" to point tomopy to the data + to reconstruct. if rank is not set then the data are raw from the + detector and are located under exchange = "exchange/...", to process + data that are the result of some intemedite processing step then + exchange_rank = 1, 2, ... will direct tomopy to process + "exchange1/...", + + proj : {sequence, int}, optional + Specify projections to read. (start, end, step) + + sino : {sequence, int}, optional + Specify sinograms to read. (start, end, step) + + dtype : numpy datatype, optional + Convert data to this datatype on read if specified. + + ind_tomo : list of int, optional + Indices of the projection files to read. + + Returns + ------- + ndarray + 3D tomographic data. + + ndarray + 3D flat field data. + + ndarray + 3D dark field data. + + ndarray + 1D theta in radian. 
+ """ + if exchange_rank > 0: + exchange_base = 'exchange{:d}'.format(int(exchange_rank)) + else: + exchange_base = "exchange" + + tomo_grp = '/'.join([exchange_base, 'data']) + flat_grp = '/'.join([exchange_base, 'data_white']) + dark_grp = '/'.join([exchange_base, 'data_dark']) + theta_grp = '/'.join([exchange_base, 'theta']) + + tomo = dxchange.read_hdf5(fname, tomo_grp, slc=(proj, sino), dtype=dtype) + flat = dxchange.read_hdf5(fname, flat_grp, slc=(None, sino), dtype=dtype) + dark = dxchange.read_hdf5(fname, dark_grp, slc=(None, sino), dtype=dtype) + theta = dxchange.read_hdf5(fname, theta_grp) + + if (theta is None): + theta_grp = '/'.join([exchange_base, 'theta_aborted']) + theta = dxchange.read_hdf5(fname, theta_grp) + if (theta is None): + if verbose_printing: + print('could not find thetas, generating them based on 180 degree rotation') + theta_size = dxchange.read_dx_dims(fname, 'data')[0] + logger.warn('Generating "%s" [0-180] deg angles for missing "exchange/theta" dataset' % (str(theta_size))) + theta = np.linspace(0., 180., theta_size) + + theta = theta * np.pi / 180. + + if proj is not None: + theta = theta[proj[0]:proj[1]:proj[2]] + + return tomo, flat, dark, theta + +#Converts spreadsheet.xlsx file with headers into dictionaries +# def read_spreadsheet(filepath): +# workbook=xlrd.open_workbook(filepath) +# worksheet = workbook.sheet_by_index(0) +# +# # imports first row and converts to a list of header strings +# headerList = [] +# for col_index in range(worksheet.ncols): +# headerList.append(str(worksheet.cell_value(0,col_index))) +# +# dataList = [] +# # For each row, create a dictionary and like header name to data +# # converts each row to following format rowDictionary1 ={'header1':colvalue1,'header2':colvalue2,... } +# # compiles rowDictinaries into a list: dataList = [rowDictionary1, rowDictionary2,...] 
+# for row_index in range(1,worksheet.nrows): +# rowDictionary = {} +# for col_index in range(worksheet.ncols): +# cellValue = worksheet.cell_value(row_index,col_index) +# +# if type(cellValue)==unicode: +# cellValue = str(cellValue) +# +# # if cell contains string that looks like a tuple, convert to tuple +# if '(' in str(cellValue): +# cellValue = literal_eval(cellValue) +# +# # if cell contains string or int that looks like 'True', convert to boolean True +# if str(cellValue).lower() =='true' or (type(cellValue)==int and cellValue==1): +# cellValue = True +# +# # if cell contains string or int that looks like 'False', convert to boolean False +# if str(cellValue).lower() =='false' or (type(cellValue)==int and cellValue==0): +# cellValue = False +# +# if cellValue != '': # create dictionary element if cell value is not empty +# rowDictionary[headerList[col_index]] = cellValue +# dataList.append(rowDictionary) +# +# return(dataList) + + +# D.Y.Parkinson's interpreter for text input files +def main(parametersfile): + + if parametersfile.split('.')[-1] == 'txt': + with open(parametersfile,'r') as theinputfile: + theinput = theinputfile.read() + inputlist = theinput.splitlines() + for reconcounter in range(0,len(inputlist)): + inputlisttabsplit = inputlist[reconcounter].split() + if inputlisttabsplit: + functioninput = {'filename': inputlisttabsplit[0]} + for inputcounter in range(0,(len(inputlisttabsplit)-1)//2): + inputlisttabsplit[inputcounter*2+2] = inputlisttabsplit[inputcounter*2+2].replace('\"','') + inputcommasplit = inputlisttabsplit[inputcounter*2+2].split(',') + if len(inputcommasplit)>1: + inputcommasplitconverted = [] + for jk in range(0,len(inputcommasplit)): + inputcommasplitconverted.append(convertthetype(inputcommasplit[jk])) + else: + inputcommasplitconverted = convertthetype(inputlisttabsplit[inputcounter*2+2]) + functioninput[inputlisttabsplit[inputcounter*2+1]] = inputcommasplitconverted + else: + print("Ending at blank line in input.") + break + 
print("Read user input:") + print(functioninput) + recon_dictionary, _ = recon_setup(**functioninput) +# recon(**functioninput) + recon(**recon_dictionary) + +# H.S.Barnard Spreadsheet interpreter +# if parametersfile.split('.')[-1]=='xlsx': +# functioninput = read_spreadsheet(parametersfile) +# for i in range(len(functioninput)): +# recon(**functioninput[i]) + +# if __name__ == '__main__': +# parametersfile = 'input832.txt' if (len(sys.argv)<2) else sys.argv[1] +# main(parametersfile) + +def set_directory_permissions(path): + os.chmod(path, stat.S_IRWXU | stat.S_IRWXG | stat.S_IROTH | stat.S_IXOTH | stat.S_ISGID) + +def set_file_permissions(path): + os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IWGRP | stat.S_IROTH) + +if __name__ == '__main__': + if len(sys.argv) < 3: + print("Usage: python reconstruction.py [sino_start] [sino_end]") + sys.exit(1) + + file_name = sys.argv[1] + folder_path = sys.argv[2] + '/' + + # Optional slice range for multi-node parallelism + sino_start = int(sys.argv[3]) if len(sys.argv) > 3 else None + sino_end = int(sys.argv[4]) if len(sys.argv) > 4 else None + + sinoused = None + if sino_start is not None and sino_end is not None: + sinoused = (sino_start, sino_end, 1) + + recon_dictionary, _ = recon_setup(filename=file_name, + inputPath=folder_path, + outputPath="../scratch/"+folder_path, + sinoused=sinoused) + + recon(**recon_dictionary) diff --git a/scripts/perlmutter/tiff_to_zarr.py b/scripts/perlmutter/tiff_to_zarr.py new file mode 100644 index 00000000..e07c6e0e --- /dev/null +++ b/scripts/perlmutter/tiff_to_zarr.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python + +import argparse +import os +import sys + +import dxchange +from ngff_zarr import ( + detect_cli_io_backend, + cli_input_to_ngff_image, + to_multiscales, + to_ngff_zarr, + Methods +) + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Convert TIFF files to NGFF Zarr format.') + parser.add_argument('tiff_directory', type=str, 
help='Directory containing the TIFF files.') + parser.add_argument('--zarr_directory', + type=str, + default=None, + help='Directory to store Zarr output. Default is new folder in input directory.') + parser.add_argument('--raw_file', + type=str, + default=None, + help='Path to the raw hdf5 file (for reading pixelsize metadata).') + + return parser.parse_args() + + +def debug_log(message): + """Writes a debug message to stdout, ensuring it appears in SLURM logs.""" + sys.stdout.write(f"{message}\n") + sys.stdout.flush() + + +def set_permissions_recursive(path, permissions=0o2775): + for root, dirs, files in os.walk(path): + for dir in dirs: + os.chmod(os.path.join(root, dir), permissions) + for file in files: + os.chmod(os.path.join(root, file), permissions) + os.chmod(path, permissions) # Also set permissions for the top-level directory + + +def read_pixelsize_from_hdf5(raw_file: str) -> dict: + pxsize = dxchange.read_hdf5(raw_file, + "/measurement/instrument/detector/pixel_size")[0] # Expect mm + pxsize = pxsize * 1000 # Convert to micrometer + return {'x': pxsize, 'y': pxsize, 'z': pxsize} + + +def main(): + args = parse_arguments() + + tiff_dir = args.tiff_directory + zarr_dir = args.zarr_directory + + # Debugging: Print the TIFF directory and current working directory + debug_log(f"Specified TIFF directory: {tiff_dir}") + debug_log(f"Current working directory: {os.getcwd()}") + + # Debugging: List contents of the current working directory + debug_log("Contents of the current working directory:") + debug_log(", ".join(os.listdir(os.getcwd()))) + + if not os.path.isdir(tiff_dir): + raise TypeError("The specified TIFF directory is not a valid directory") + + file_names = os.listdir(tiff_dir) + file_paths = [os.path.join(tiff_dir, file_name) for file_name in file_names if not file_name.startswith('.')] + file_paths.sort() + + if not file_paths: + raise ValueError("No TIFF files found in the specified directory") + if not zarr_dir: + last_part = 
os.path.basename(os.path.normpath(tiff_dir)) + zarr_dir = os.path.abspath(os.path.join(tiff_dir, '..', last_part + '.zarr')) + if not os.path.exists(zarr_dir): + os.makedirs(zarr_dir, mode=0o2775, exist_ok=True) + + print('Output directory: ' + zarr_dir) + + # Build NGFF Zarr directory + backend = detect_cli_io_backend(file_paths) + image = cli_input_to_ngff_image(backend, file_paths) + # The scale and axis units are the same as the one printed in the reconstruction script + image.scale = read_pixelsize_from_hdf5(args.raw_file) + image.axes_units = {'x': 'micrometer', 'y': 'micrometer', 'z': 'micrometer'} + multiscales = to_multiscales(image, method=Methods.DASK_IMAGE_GAUSSIAN, cache=False) + to_ngff_zarr(zarr_dir, multiscales) + print('NGFF Zarr created') + + # Set permissions for the output directory and its contents + set_permissions_recursive(zarr_dir) + + # Extract and set permissions for the parent directory (folder_name) + parent_dir = os.path.abspath(os.path.join(tiff_dir, '../')) # Extract parent directory + set_permissions_recursive(parent_dir) # Set permissions for parent directory + + +if __name__ == "__main__": + main() From 21244e56be84d071cac5699304f03d968a71c5f1 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 6 Jan 2026 11:33:21 -0800 Subject: [PATCH 02/72] adding logic for determining qos based on number of nodes requested --- orchestration/flows/bl832/nersc.py | 306 +++++++++++++++++++++-------- 1 file changed, 229 insertions(+), 77 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index e1d00522..6feee868 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -242,9 +242,16 @@ def reconstruct_multinode( logger.info(f"Folder name: {folder_name}") logger.info(f"Number of nodes: {num_nodes}") + if num_nodes == 8: + qos = "debug" + if num_nodes < 8: + qos = "realtime" + if num_nodes > 8: + qos = "preempt" + # IMPORTANT: job script must be deindented to the 
leftmost column or it will fail immediately job_script = f"""#!/bin/bash -#SBATCH -q realtime +#SBATCH -q {qos} #SBATCH -A als #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} @@ -253,42 +260,45 @@ def reconstruct_multinode( #SBATCH -N {num_nodes} #SBATCH --ntasks={num_nodes} #SBATCH --cpus-per-task=128 -#SBATCH --time=0:15:00 +#SBATCH --time=0:30:00 #SBATCH --exclusive -date +# Timing file for this job +TIMING_FILE="{pscratch_path}/tomo_recon_logs/timing_$SLURM_JOB_ID.txt" + +echo "JOB_START=$(date +%s)" > $TIMING_FILE echo "Running reconstruction with {num_nodes} nodes" -echo "Pre-pulling container image..." +echo "PREPULL_START=$(date +%s)" >> $TIMING_FILE podman-hpc pull {recon_image} +echo "PREPULL_END=$(date +%s)" >> $TIMING_FILE -echo "Creating directory {pscratch_path}/8.3.2/raw/{folder_name}" mkdir -p {pscratch_path}/8.3.2/raw/{folder_name} mkdir -p {pscratch_path}/8.3.2/scratch/{folder_name} -echo "Copying file {raw_path}/{folder_name}/{file_name} to {pscratch_path}/8.3.2/raw/{folder_name}/" -cp {raw_path}/{folder_name}/{file_name} {pscratch_path}/8.3.2/raw/{folder_name} -if [ $? -ne 0 ]; then - echo "Failed to copy data to pscratch." - exit 1 +echo "COPY_START=$(date +%s)" >> $TIMING_FILE +if [ ! -f "{pscratch_path}/8.3.2/raw/{folder_name}/{file_name}" ]; then + cp {raw_path}/{folder_name}/{file_name} {pscratch_path}/8.3.2/raw/{folder_name} + if [ $? -ne 0 ]; then + echo "Failed to copy data to pscratch." + exit 1 + fi + echo "COPY_SKIPPED=false" >> $TIMING_FILE +else + echo "COPY_SKIPPED=true" >> $TIMING_FILE fi +echo "COPY_END=$(date +%s)" >> $TIMING_FILE chmod 2775 {pscratch_path}/8.3.2/raw/{folder_name} chmod 2775 {pscratch_path}/8.3.2/scratch/{folder_name} chmod 664 {pscratch_path}/8.3.2/raw/{folder_name}/{file_name} -echo "Verifying copied files..." 
-ls -l {pscratch_path}/8.3.2/raw/{folder_name}/ - NNODES={num_nodes} -RAW_FILE="{pscratch_path}/8.3.2/raw/{folder_name}/{file_name}" -# Get the number of slices from the HDF5 file using the container -echo "Reading slice count from HDF5 file..." - -NUM_SLICES=$(podman-hpc run --rm \\ - --volume {pscratch_path}/8.3.2:/alsdata \\ - {recon_image} \\ +echo "METADATA_START=$(date +%s)" >> $TIMING_FILE +NUM_SLICES=$(podman-hpc run --rm \ + --volume {pscratch_path}/8.3.2:/alsdata \ + {recon_image} \ python -c " import h5py with h5py.File('/alsdata/raw/{folder_name}/{file_name}', 'r') as f: @@ -301,8 +311,9 @@ def reconstruct_multinode( print(int(grp.attrs['nslices'])) break " 2>&1 | grep -E '^[0-9]+$' | head -1) +echo "METADATA_END=$(date +%s)" >> $TIMING_FILE -echo "Detected NUM_SLICES: $NUM_SLICES" +echo "NUM_SLICES=$NUM_SLICES" >> $TIMING_FILE if [ -z "$NUM_SLICES" ]; then echo "Failed to read number of slices from HDF5 file" @@ -314,25 +325,19 @@ def reconstruct_multinode( exit 1 fi -echo "Total slices: $NUM_SLICES" -echo "Distributing across $NNODES nodes" - SLICES_PER_NODE=$((NUM_SLICES / NNODES)) -echo "Slices per node: ~$SLICES_PER_NODE" -# Launch reconstruction on each node +echo "RECON_START=$(date +%s)" >> $TIMING_FILE + for i in $(seq 0 $((NNODES - 1))); do SINO_START=$((i * SLICES_PER_NODE)) - # Last node takes any remainder slices if [ $i -eq $((NNODES - 1)) ]; then SINO_END=$NUM_SLICES else SINO_END=$(((i + 1) * SLICES_PER_NODE)) fi - echo "Launching node $i: slices $SINO_START to $SINO_END" - srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \ --env NUMEXPR_MAX_THREADS=128 \ --env NUMEXPR_NUM_THREADS=128 \ @@ -345,17 +350,18 @@ def reconstruct_multinode( bash -c "cd /alsuser && python sfapi_reconstruction_multinode.py {file_name} {folder_name} $SINO_START $SINO_END" & done -echo "Waiting for all $NNODES nodes to complete..." wait WAIT_STATUS=$? 
+echo "RECON_END=$(date +%s)" >> $TIMING_FILE if [ $WAIT_STATUS -ne 0 ]; then echo "One or more reconstruction tasks failed" + echo "JOB_STATUS=FAILED" >> $TIMING_FILE exit 1 fi -echo "All nodes completed successfully" -date +echo "JOB_STATUS=SUCCESS" >> $TIMING_FILE +echo "JOB_END=$(date +%s)" >> $TIMING_FILE """ try: logger.info("Submitting reconstruction job script to Perlmutter.") @@ -373,7 +379,14 @@ def reconstruct_multinode( job.complete() # Wait until the job completes logger.info("Reconstruction job completed successfully.") - return True + # Fetch timing data + timing = self._fetch_timing_data(perlmutter, pscratch_path, job.jobid) + + return { + "success": True, + "job_id": job.jobid, + "timing": timing + } except Exception as e: logger.info(f"Error during job submission or completion: {e}") @@ -394,6 +407,65 @@ def reconstruct_multinode( else: return False + def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: + """Fetch and parse timing data from the SLURM job.""" + timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" + + try: + # Use SFAPI to read the timing file + result = perlmutter.run(f"cat {timing_file}") + + # result might be a string directly, or an object with .output + if isinstance(result, str): + output = result + elif hasattr(result, 'output'): + output = result.output + elif hasattr(result, 'stdout'): + output = result.stdout + else: + output = str(result) + + logger.info(f"Timing file contents:\n{output}") + + # Parse timing data + timing = {} + for line in output.strip().split('\n'): + if '=' in line: + key, value = line.split('=', 1) + timing[key] = value.strip() + + # Calculate durations + breakdown = {} + + if 'JOB_START' in timing and 'JOB_END' in timing: + breakdown['total'] = int(timing['JOB_END']) - int(timing['JOB_START']) + + if 'PREPULL_START' in timing and 'PREPULL_END' in timing: + breakdown['container_pull'] = int(timing['PREPULL_END']) - int(timing['PREPULL_START']) + + if 
'COPY_START' in timing and 'COPY_END' in timing: + breakdown['file_copy'] = int(timing['COPY_END']) - int(timing['COPY_START']) + breakdown['copy_skipped'] = timing.get('COPY_SKIPPED', 'false') == 'true' + + if 'METADATA_START' in timing and 'METADATA_END' in timing: + breakdown['metadata'] = int(timing['METADATA_END']) - int(timing['METADATA_START']) + + if 'RECON_START' in timing and 'RECON_END' in timing: + breakdown['reconstruction'] = int(timing['RECON_END']) - int(timing['RECON_START']) + + if 'NUM_SLICES' in timing: + breakdown['num_slices'] = int(timing['NUM_SLICES']) + + breakdown['job_status'] = timing.get('JOB_STATUS', 'UNKNOWN') + + return breakdown + + except Exception as e: + logger.warning(f"Error fetching timing data: {e}") + import traceback + logger.warning(traceback.format_exc()) + return None + def build_multi_resolution( self, file_path: str = "", @@ -652,19 +724,37 @@ def nersc_recon_flow( ) logger.info("NERSC reconstruction controller initialized") - if num_nodes > 1: - nersc_reconstruction_success = controller.reconstruct_multinode( - file_path=file_path, - num_nodes=num_nodes - ) - elif num_nodes == 1: - nersc_reconstruction_success = controller.reconstruct( - file_path=file_path, - ) + nersc_reconstruction_success = controller.reconstruct_multinode( + file_path=file_path, + num_nodes=num_nodes + ) + + if isinstance(nersc_reconstruction_success, dict): + success = nersc_reconstruction_success.get('success', False) + timing = nersc_reconstruction_success.get('timing') + + if timing: + logger.info("=" * 50) + logger.info("TIMING BREAKDOWN") + logger.info("=" * 50) + logger.info(f" Total job time: {timing.get('total', 'N/A')}s") + logger.info(f" Container pull: {timing.get('container_pull', 'N/A')}s") + logger.info(f" File copy: {timing.get('file_copy', 'N/A')}s (skipped: {timing.get('copy_skipped', 'N/A')})") + logger.info(f" Metadata detection: {timing.get('metadata', 'N/A')}s") + logger.info(f" RECONSTRUCTION: {timing.get('reconstruction', 
'N/A')}s <-- actual recon time") + logger.info(f" Num slices: {timing.get('num_slices', 'N/A')}") + logger.info("=" * 50) + + # Calculate overhead + if all(k in timing for k in ['total', 'reconstruction']): + overhead = timing['total'] - timing['reconstruction'] + logger.info(f" Overhead: {overhead}s") + logger.info(f" Reconstruction %: {100 * timing['reconstruction'] / timing['total']:.1f}%") + logger.info("=" * 50) else: - raise ValueError("num_nodes must be at least 1") + success = nersc_reconstruction_success - logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + logger.info(f"NERSC reconstruction success: {success}") # Commented out for testing purposes -- should be re-enabled for production @@ -761,6 +851,8 @@ def nersc_streaming_flow( config = Config832() + # Fibers ------------------------------------------ + start = time.time() nersc_recon_flow( file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", @@ -771,6 +863,38 @@ def nersc_streaming_flow( logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + num_nodes=4, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + num_nodes=2, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + num_nodes=1, + config=config 
+ ) + end = time.time() + logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + print(f"Total reconstruction time with 1 node: {end - start} seconds") + + # # Fungi ------------------------------------------ + start = time.time() nersc_recon_flow( file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", @@ -781,6 +905,38 @@ def nersc_streaming_flow( logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + num_nodes=4, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + num_nodes=2, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + num_nodes=1, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + print(f"Total reconstruction time with 1 node: {end - start} seconds") + + # # Silk ------------------------------------------ + start = time.time() nersc_recon_flow( file_path="dabramov/20251218_111600_silkraw.h5", @@ -791,36 +947,32 @@ def nersc_streaming_flow( logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - # start = time.time() - # nersc_recon_flow( - # 
file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=4, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=2, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=1, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - # print(f"Total reconstruction time with 1 node: {end - start} seconds") - # nersc_streaming_flow( - # config=config, - # walltime=datetime.timedelta(minutes=5) - # )' + start = time.time() + nersc_recon_flow( + file_path="dabramov/20251218_111600_silkraw.h5", + num_nodes=4, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20251218_111600_silkraw.h5", + num_nodes=2, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + start = time.time() + nersc_recon_flow( + file_path="dabramov/20251218_111600_silkraw.h5", + num_nodes=1, + config=config + ) + end = time.time() + logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + print(f"Total reconstruction time with 1 node: {end - 
start} seconds") From 77d15ccc48436fb843f76c626c58ab0f01514a99 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 6 Jan 2026 14:34:05 -0800 Subject: [PATCH 03/72] Adding specific tag for microct image (for more efficient caching on NERSC). Also switching from podman to shifter for better overhead performance --- config.yml | 4 +- orchestration/flows/bl832/nersc.py | 201 ++++++++++++++++++++++++++--- 2 files changed, 186 insertions(+), 19 deletions(-) diff --git a/config.yml b/config.yml index 3f26a4f0..85393502 100644 --- a/config.yml +++ b/config.yml @@ -148,8 +148,8 @@ harbor_images832: multires_image: tomorecon_nersc_mpi_hdf5@sha256:cc098a2cfb6b1632ea872a202c66cb7566908da066fd8f8c123b92fa95c2a43c ghcr_images832: - recon_image: ghcr.io/als-computing/microct:master - multires_image: ghcr.io/als-computing/microct:master + recon_image: ghcr.io/als-computing/microct@sha256:1fdfb786726ee03301d624319e3d16702045072f38e2b0cca9d6237e5ab3f5ff + multires_image: ghcr.io/als-computing/microct@sha256:1fdfb786726ee03301d624319e3d16702045072f38e2b0cca9d6237e5ab3f5ff prefect: deployments: diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 6feee868..3ff6bd83 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -262,6 +262,7 @@ def reconstruct_multinode( #SBATCH --cpus-per-task=128 #SBATCH --time=0:30:00 #SBATCH --exclusive +#SBATCH --image={recon_image} # Timing file for this job TIMING_FILE="{pscratch_path}/tomo_recon_logs/timing_$SLURM_JOB_ID.txt" @@ -269,8 +270,8 @@ def reconstruct_multinode( echo "JOB_START=$(date +%s)" > $TIMING_FILE echo "Running reconstruction with {num_nodes} nodes" +# No container pull needed with Shifter - image is pre-staged via --image echo "PREPULL_START=$(date +%s)" >> $TIMING_FILE -podman-hpc pull {recon_image} echo "PREPULL_END=$(date +%s)" >> $TIMING_FILE mkdir -p {pscratch_path}/8.3.2/raw/{folder_name} @@ -296,9 +297,8 @@ def reconstruct_multinode( 
NNODES={num_nodes} echo "METADATA_START=$(date +%s)" >> $TIMING_FILE -NUM_SLICES=$(podman-hpc run --rm \ - --volume {pscratch_path}/8.3.2:/alsdata \ - {recon_image} \ +NUM_SLICES=$(shifter \ + --volume={pscratch_path}/8.3.2:/alsdata \ python -c " import h5py with h5py.File('/alsdata/raw/{folder_name}/{file_name}', 'r') as f: @@ -329,6 +329,9 @@ def reconstruct_multinode( echo "RECON_START=$(date +%s)" >> $TIMING_FILE +# Create symlink so folder_name resolves correctly (like podman mount did) +ln -sfn {pscratch_path}/8.3.2/raw/{folder_name} {pscratch_path}/8.3.2/{folder_name} + for i in $(seq 0 $((NNODES - 1))); do SINO_START=$((i * SLICES_PER_NODE)) @@ -338,16 +341,15 @@ def reconstruct_multinode( SINO_END=$(((i + 1) * SLICES_PER_NODE)) fi - srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \ - --env NUMEXPR_MAX_THREADS=128 \ - --env NUMEXPR_NUM_THREADS=128 \ - --env OMP_NUM_THREADS=128 \ - --env MKL_NUM_THREADS=128 \ - --volume {recon_scripts_dir}/sfapi_reconstruction_multinode.py:/alsuser/sfapi_reconstruction_multinode.py \ - --volume {pscratch_path}/8.3.2/raw/{folder_name}:/alsuser/{folder_name} \ - --volume {pscratch_path}/8.3.2/scratch:/scratch \ - {recon_image} \ - bash -c "cd /alsuser && python sfapi_reconstruction_multinode.py {file_name} {folder_name} $SINO_START $SINO_END" & + srun --nodes=1 --ntasks=1 --exclusive shifter \ + --env=NUMEXPR_MAX_THREADS=128 \ + --env=NUMEXPR_NUM_THREADS=128 \ + --env=OMP_NUM_THREADS=128 \ + --env=MKL_NUM_THREADS=128 \ + --volume={pscratch_path}/8.3.2:/alsuser \ + --volume={pscratch_path}/8.3.2/scratch:/scratch \ + --volume={recon_scripts_dir}:/opt/scripts \ + /bin/bash -c "cd /alsuser && python /opt/scripts/sfapi_reconstruction_multinode.py {file_name} {folder_name} $SINO_START $SINO_END" & done wait @@ -466,6 +468,134 @@ def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dic logger.warning(traceback.format_exc()) return None + def pull_shifter_image( + self, + image: str = None, + wait: bool = 
True, + ) -> bool: + """ + Pull a container image into NERSC's Shifter cache. + + This should be run once when the image is updated, not before every reconstruction. + After the image is cached, jobs using --image= will start much faster. + + :param image: Container image to pull (defaults to recon_image from config) + :param wait: Whether to wait for the pull to complete + :return: True if successful, False otherwise + """ + logger.info("Starting Shifter image pull.") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + + if image is None: + image = self.config.ghcr_images832["recon_image"] + + logger.info(f"Pulling image: {image}") + + job_script = f"""#!/bin/bash +#SBATCH -q debug +#SBATCH -A als +#SBATCH -C cpu +#SBATCH --job-name=shifter_pull +#SBATCH --output={pscratch_path}/tomo_recon_logs/shifter_pull_%j.out +#SBATCH --error={pscratch_path}/tomo_recon_logs/shifter_pull_%j.err +#SBATCH -N 1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=1 +#SBATCH --time=0:15:00 + +echo "Starting Shifter image pull at $(date)" +echo "Image: {image}" + +# Check if image already exists +echo "Checking existing images..." +shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" || true + +# Pull the image +echo "Pulling image..." +shifterimg -v pull {image} +PULL_STATUS=$? + +if [ $PULL_STATUS -eq 0 ]; then + echo "Image pull successful" +else + echo "Image pull failed with status $PULL_STATUS" + exit 1 +fi + +# Verify the image is now available +echo "Verifying image..." 
+shifterimg images | grep -E "$(echo {image} | sed 's/:/.*/')" + +echo "Completed at $(date)" +""" + + try: + logger.info("Submitting Shifter image pull job to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + if wait: + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(30) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() + logger.info("Shifter image pull completed successfully.") + return True + else: + logger.info(f"Job submitted. Check status with job ID: {job.jobid}") + return True + + except Exception as e: + logger.error(f"Error during Shifter image pull: {e}") + return False + + def check_shifter_image( + self, + image: str = None, + ) -> bool: + """ + Check if a container image is already in NERSC's Shifter cache. + + :param image: Container image to check (defaults to recon_image from config) + :return: True if image exists in cache, False otherwise + """ + logger.info("Checking Shifter image cache.") + + if image is None: + image = self.config.ghcr_images832["recon_image"] + + try: + perlmutter = self.client.compute(Machine.perlmutter) + + # Run shifterimg images command + result = perlmutter.run(f"shifterimg images | grep -E \"$(echo {image} | sed 's/:/.*/g')\"") + + if isinstance(result, str): + output = result + elif hasattr(result, 'output'): + output = result.output + else: + output = str(result) + + if output.strip(): + logger.info(f"Image found in Shifter cache: {output.strip()}") + return True + else: + logger.info(f"Image not found in Shifter cache: {image}") + return False + + except Exception as e: + logger.warning(f"Error checking Shifter cache: {e}") + return False + def build_multi_resolution( self, file_path: str = "", @@ -847,11 +977,48 @@ def nersc_streaming_flow( return success 
+@flow(name="pull_shifter_image_flow", flow_run_name="pull_shifter_image") +def pull_shifter_image_flow( + image: Optional[str] = None, + config: Optional[Config832] = None, +) -> bool: + """ + Pull a container image into NERSC's Shifter cache. + + Run this once when the container image is updated. + """ + logger = get_run_logger() + + if config is None: + config = Config832() + + if image is None: + image = config.ghcr_images832["recon_image"] + + logger.info(f"Pulling Shifter image: {image}") + + controller = get_controller( + hpc_type=HPC.NERSC, + config=config + ) + + # Check if already cached + if controller.check_shifter_image(image): + logger.info("Image already in cache, pulling anyway to update...") + + success = controller.pull_shifter_image(image) + logger.info(f"Shifter image pull success: {success}") + + return success + + if __name__ == "__main__": config = Config832() - # Fibers ------------------------------------------ + # pull_shifter_image_flow(config=config) + + # # Fibers ------------------------------------------ start = time.time() nersc_recon_flow( @@ -893,7 +1060,7 @@ def nersc_streaming_flow( logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") print(f"Total reconstruction time with 1 node: {end - start} seconds") - # # Fungi ------------------------------------------ + # # # Fungi ------------------------------------------ start = time.time() nersc_recon_flow( @@ -935,7 +1102,7 @@ def nersc_streaming_flow( logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") print(f"Total reconstruction time with 1 node: {end - start} seconds") - # # Silk ------------------------------------------ + # # # Silk ------------------------------------------ start = time.time() nersc_recon_flow( From 9e90a2654385250574a6c42421a632e22d23e0f2 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 14 Jan 2026 17:13:26 -0800 Subject: [PATCH 04/72] Adding reconstruct_multinode() method --- 
orchestration/flows/bl832/nersc.py | 581 +++++++++++++++-------------- 1 file changed, 310 insertions(+), 271 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 3ff6bd83..794dc9a4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -80,6 +80,9 @@ def reconstruct( ) -> bool: """ Use NERSC for tomography reconstruction + + :param file_path: Path to the file to reconstruct + :return: True if successful, False otherwise """ logger.info("Starting NERSC reconstruction process.") @@ -211,6 +214,7 @@ def reconstruct_multinode( :param file_path: Path to the file to reconstruct :param num_nodes: Number of nodes to use for parallel reconstruction + :return: True if successful, False otherwise """ logger.info("Starting NERSC reconstruction process.") @@ -247,7 +251,7 @@ def reconstruct_multinode( if num_nodes < 8: qos = "realtime" if num_nodes > 8: - qos = "preempt" + qos = "premium" # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash @@ -349,7 +353,9 @@ def reconstruct_multinode( --volume={pscratch_path}/8.3.2:/alsuser \ --volume={pscratch_path}/8.3.2/scratch:/scratch \ --volume={recon_scripts_dir}:/opt/scripts \ - /bin/bash -c "cd /alsuser && python /opt/scripts/sfapi_reconstruction_multinode.py {file_name} {folder_name} $SINO_START $SINO_END" & + /bin/bash -c "cd /alsuser && python /opt/scripts/sfapi_reconstruction_multinode.py \ +{file_name} {folder_name} $SINO_START $SINO_END" & + done wait @@ -410,7 +416,14 @@ def reconstruct_multinode( return False def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: - """Fetch and parse timing data from the SLURM job.""" + """ + Fetch and parse timing data from the SLURM job. 
+ + :param perlmutter: SFAPI compute object for Perlmutter + :param pscratch_path: Path to the user's pscratch directory + :param job_id: SLURM job ID + :return: Dictionary with timing breakdown + """ timing_file = f"{pscratch_path}/tomo_recon_logs/timing_{job_id}.txt" try: @@ -468,6 +481,117 @@ def _fetch_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dic logger.warning(traceback.format_exc()) return None + def build_multi_resolution( + self, + file_path: str = "", + ) -> bool: + """ + Use NERSC to make multiresolution version of tomography results. + + :param file_path: Path to the file to process + :return: True if successful, False otherwise + """ + + logger.info("Starting NERSC multiresolution process.") + + user = self.client.user() + + multires_image = self.config.ghcr_images832["multires_image"] + logger.info(f"{multires_image=}") + + recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path + logger.info(f"{recon_scripts_dir=}") + + scratch_path = self.config.nersc832_alsdev_scratch.root_path + logger.info(f"{scratch_path=}") + + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + logger.info(f"{pscratch_path=}") + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + recon_path = f"scratch/{folder_name}/rec{file_name}/" + logger.info(f"{recon_path=}") + + raw_path = f"raw/{folder_name}/{file_name}.h5" + logger.info(f"{raw_path=}") + + # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately + job_script = f"""#!/bin/bash +#SBATCH -q realtime +#SBATCH -A als +#SBATCH -C cpu +#SBATCH --job-name=tomo_multires_{folder_name}_{file_name} +#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err +#SBATCH -N 1 +#SBATCH --ntasks-per-node 1 +#SBATCH --cpus-per-task 128 +#SBATCH --time=0:15:00 +#SBATCH --exclusive + +date + +echo "Running multires container..." 
+srun podman-hpc run \ +--volume {recon_scripts_dir}/tiff_to_zarr.py:/alsuser/tiff_to_zarr.py \ +--volume {pscratch_path}/8.3.2:/alsdata \ +--volume {pscratch_path}/8.3.2:/alsuser/ \ +{multires_image} \ +bash -c "python tiff_to_zarr.py {recon_path} --raw_file {raw_path}" + +date +""" + try: + logger.info("Submitting Tiff to Zarr job script to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() # Wait until the job completes + logger.info("Reconstruction job completed successfully.") + + return True + + except Exception as e: + logger.warning(f"Error during job submission or completion: {e}") + match = re.search(r"Job not found:\s*(\d+)", str(e)) + + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.perlmutter.job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("Reconstruction job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + return False + else: + return False + + def start_streaming_service( + self, + walltime: datetime.timedelta = datetime.timedelta(minutes=30), + ) -> str: + return NerscStreamingMixin.start_streaming_service( + self, + client=self.client, + walltime=walltime + ) + def pull_shifter_image( self, image: str = None, @@ -596,112 +720,6 @@ def check_shifter_image( logger.warning(f"Error checking Shifter cache: {e}") return False - def build_multi_resolution( - self, - file_path: str = "", - ) -> bool: - """Use NERSC to make multiresolution version of tomography results.""" - - logger.info("Starting NERSC multiresolution process.") - - 
user = self.client.user() - - multires_image = self.config.ghcr_images832["multires_image"] - logger.info(f"{multires_image=}") - - recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path - logger.info(f"{recon_scripts_dir=}") - - scratch_path = self.config.nersc832_alsdev_scratch.root_path - logger.info(f"{scratch_path=}") - - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - logger.info(f"{pscratch_path=}") - - path = Path(file_path) - folder_name = path.parent.name - file_name = path.stem - - recon_path = f"scratch/{folder_name}/rec{file_name}/" - logger.info(f"{recon_path=}") - - raw_path = f"raw/{folder_name}/{file_name}.h5" - logger.info(f"{raw_path=}") - - # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately - job_script = f"""#!/bin/bash -#SBATCH -q realtime -#SBATCH -A als -#SBATCH -C cpu -#SBATCH --job-name=tomo_multires_{folder_name}_{file_name} -#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out -#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err -#SBATCH -N 1 -#SBATCH --ntasks-per-node 1 -#SBATCH --cpus-per-task 64 -#SBATCH --time=0:15:00 -#SBATCH --exclusive - -date - -echo "Running multires container..." 
-srun podman-hpc run \ ---volume {recon_scripts_dir}/tiff_to_zarr.py:/alsuser/tiff_to_zarr.py \ ---volume {pscratch_path}/8.3.2:/alsdata \ ---volume {pscratch_path}/8.3.2:/alsuser/ \ -{multires_image} \ -bash -c "python tiff_to_zarr.py {recon_path} --raw_file {raw_path}" - -date -""" - try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - - return True - - except Exception as e: - logger.warning(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False - - def start_streaming_service( - self, - walltime: datetime.timedelta = datetime.timedelta(minutes=30), - ) -> str: - return NerscStreamingMixin.start_streaming_service( - self, - client=self.client, - walltime=walltime - ) - def schedule_pruning( config: Config832, @@ -840,6 +858,9 @@ def nersc_recon_flow( Perform tomography reconstruction on NERSC. :param file_path: Path to the file to reconstruct. + :param num_nodes: Number of nodes to use for parallel reconstruction. + :param config: Configuration object (if None, a default Config832 will be created). 
+ :return: True if successful, False otherwise. """ logger = get_run_logger() @@ -854,10 +875,17 @@ def nersc_recon_flow( ) logger.info("NERSC reconstruction controller initialized") - nersc_reconstruction_success = controller.reconstruct_multinode( - file_path=file_path, - num_nodes=num_nodes - ) + if num_nodes == 1: + logger.info("Using single-node reconstruction") + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + else: + logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") + nersc_reconstruction_success = controller.reconstruct_multinode( + file_path=file_path, + num_nodes=num_nodes + ) if isinstance(nersc_reconstruction_success, dict): success = nersc_reconstruction_success.get('success', False) @@ -869,7 +897,10 @@ def nersc_recon_flow( logger.info("=" * 50) logger.info(f" Total job time: {timing.get('total', 'N/A')}s") logger.info(f" Container pull: {timing.get('container_pull', 'N/A')}s") - logger.info(f" File copy: {timing.get('file_copy', 'N/A')}s (skipped: {timing.get('copy_skipped', 'N/A')})") + logger.info( + f" File copy: {timing.get('file_copy', 'N/A')}s " + f"(skipped: {timing.get('copy_skipped', 'N/A')})" + ) logger.info(f" Metadata detection: {timing.get('metadata', 'N/A')}s") logger.info(f" RECONSTRUCTION: {timing.get('reconstruction', 'N/A')}s <-- actual recon time") logger.info(f" Num slices: {timing.get('num_slices', 'N/A')}") @@ -886,63 +917,61 @@ def nersc_recon_flow( logger.info(f"NERSC reconstruction success: {success}") - # Commented out for testing purposes -- should be re-enabled for production - - # nersc_multi_res_success = controller.build_multi_resolution( - # file_path=file_path, - # ) - # logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, + ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - # path = Path(file_path) - # folder_name = 
path.parent.name - # file_name = path.stem + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem - # tiff_file_path = f"{folder_name}/rec{file_name}" - # zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" - # logger.info(f"{tiff_file_path=}") - # logger.info(f"{zarr_file_path=}") + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") # Transfer reconstructed data - # logger.info("Preparing transfer.") - # transfer_controller = get_transfer_controller( - # transfer_type=CopyMethod.GLOBUS, - # config=config - # ) + logger.info("Preparing transfer.") + transfer_controller = get_transfer_controller( + transfer_type=CopyMethod.GLOBUS, + config=config + ) - # logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") - # transfer_controller.copy( - # file_path=tiff_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.nersc832_alsdev_scratch - # ) + logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") + transfer_controller.copy( + file_path=tiff_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch + ) - # transfer_controller.copy( - # file_path=zarr_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.nersc832_alsdev_scratch - # ) + transfer_controller.copy( + file_path=zarr_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch + ) - # logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") - # transfer_controller.copy( - # file_path=tiff_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) + logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to 
data832") + transfer_controller.copy( + file_path=tiff_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) - # transfer_controller.copy( - # file_path=zarr_file_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) + transfer_controller.copy( + file_path=zarr_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) - # logger.info("Scheduling pruning tasks.") - # schedule_pruning( - # config=config, - # raw_file_path=file_path, - # tiff_file_path=tiff_file_path, - # zarr_file_path=zarr_file_path - # ) + logger.info("Scheduling pruning tasks.") + schedule_pruning( + config=config, + raw_file_path=file_path, + tiff_file_path=tiff_file_path, + zarr_file_path=zarr_file_path + ) # TODO: Ingest into SciCat if nersc_reconstruction_success: @@ -1022,59 +1051,7 @@ def pull_shifter_image_flow( start = time.time() nersc_recon_flow( - file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - num_nodes=8, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") - print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - - start = time.time() - nersc_recon_flow( - file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - num_nodes=4, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - - start = time.time() - nersc_recon_flow( - file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - num_nodes=2, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - print(f"Total reconstruction time with 2 nodes: {end - start} seconds") - - start = time.time() - nersc_recon_flow( - 
file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - num_nodes=1, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - print(f"Total reconstruction time with 1 node: {end - start} seconds") - - # # # Fungi ------------------------------------------ - - start = time.time() - nersc_recon_flow( - file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - num_nodes=8, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") - print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - - start = time.time() - nersc_recon_flow( - file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + file_path="dabramov/20251218_111600_silkraw.h5", num_nodes=4, config=config ) @@ -1082,64 +1059,126 @@ def pull_shifter_image_flow( logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - start = time.time() - nersc_recon_flow( - file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - num_nodes=2, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - print(f"Total reconstruction time with 2 nodes: {end - start} seconds") - - start = time.time() - nersc_recon_flow( - file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - num_nodes=1, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - print(f"Total reconstruction time with 1 node: {end - start} seconds") + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=8, + # config=config + # ) + # end = time.time() + # logger.info(f"Total 
reconstruction time with 8 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=4, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=2, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", + # num_nodes=1, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + # print(f"Total reconstruction time with 1 node: {end - start} seconds") - # # # Silk ------------------------------------------ + # # # # Fungi ------------------------------------------ - start = time.time() - nersc_recon_flow( - file_path="dabramov/20251218_111600_silkraw.h5", - num_nodes=8, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") - print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - - start = time.time() - nersc_recon_flow( - file_path="dabramov/20251218_111600_silkraw.h5", - num_nodes=4, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # start = time.time() + # nersc_recon_flow( + # 
file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + # num_nodes=8, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + # num_nodes=4, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + # num_nodes=2, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", + # num_nodes=1, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + # print(f"Total reconstruction time with 1 node: {end - start} seconds") - start = time.time() - nersc_recon_flow( - file_path="dabramov/20251218_111600_silkraw.h5", - num_nodes=2, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + # # # # Silk ------------------------------------------ - start = time.time() - nersc_recon_flow( - file_path="dabramov/20251218_111600_silkraw.h5", - num_nodes=1, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - print(f"Total reconstruction 
time with 1 node: {end - start} seconds") + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20251218_111600_silkraw.h5", + # num_nodes=8, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 8 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20251218_111600_silkraw.h5", + # num_nodes=4, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20251218_111600_silkraw.h5", + # num_nodes=2, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") + + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20251218_111600_silkraw.h5", + # num_nodes=1, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") + # print(f"Total reconstruction time with 1 node: {end - start} seconds") From 31ed453fe2483a16c3bb59646c31b4331f39d0e7 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 14 Jan 2026 17:13:46 -0800 Subject: [PATCH 05/72] Making the cancel_sfapi_job.py script more useful --- scripts/cancel_sfapi_job.py | 87 ++++++++++++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 6 deletions(-) diff --git a/scripts/cancel_sfapi_job.py b/scripts/cancel_sfapi_job.py index 53dec051..096881eb 100644 --- a/scripts/cancel_sfapi_job.py +++ b/scripts/cancel_sfapi_job.py @@ -1,4 +1,22 @@ +""" +Script to manage NERSC SLURM jobs via SFAPI. 
+ +Usage: + +# List all jobs +python sfapi_jobs.py list +python sfapi_jobs.py -u dabramov list + +# Cancel a specific job +python sfapi_jobs.py cancel 47470003 + +# Cancel all jobs for a user +python sfapi_jobs.py cancel-all +python sfapi_jobs.py -u dabramov cancel-all +""" + from dotenv import load_dotenv +import argparse import json import logging import os @@ -9,6 +27,7 @@ load_dotenv() +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) client_id_path = os.getenv("PATH_NERSC_CLIENT_ID") @@ -29,10 +48,66 @@ with open(client_secret_path, "r") as f: client_secret = JsonWebKey.import_key(json.loads(f.read())) -with Client(client_id, client_secret) as client: - perlmutter = client.compute(Machine.perlmutter) - # job = perlmutter.submit_job(job_path) - jobs = perlmutter.jobs(user="dabramov") - for job in jobs: + +def list_jobs(user: str = "alsdev"): + """List all jobs for a user.""" + with Client(client_id, client_secret) as client: + perlmutter = client.compute(Machine.perlmutter) + jobs = perlmutter.jobs(user=user) + if not jobs: + logger.info(f"No jobs found for user: {user}") + return + for job in jobs: + logger.info(f"Job {job.jobid}: {job.name} - {job.state}") + + +def cancel_job(jobid: str): + """Cancel a specific job by ID.""" + with Client(client_id, client_secret) as client: + perlmutter = client.compute(Machine.perlmutter) + job = perlmutter.job(jobid=jobid) logger.info(f"Cancelling job: {job.jobid}") - job.cancel() + job.cancel(wait=True) + logger.info(f"Job {job.jobid} cancelled, state: {job.state}") + + +def cancel_all_jobs(user: str = "alsdev"): + """Cancel all jobs for a user.""" + with Client(client_id, client_secret) as client: + perlmutter = client.compute(Machine.perlmutter) + jobs = perlmutter.jobs(user=user) + if not jobs: + logger.info(f"No jobs found for user: {user}") + return + for job in jobs: + logger.info(f"Cancelling job: {job.jobid} ({job.name})") + job.cancel() + logger.info(f"Cancelled {len(jobs)} jobs") + + 
+if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Manage NERSC SLURM jobs via SFAPI") + parser.add_argument("--user", "-u", default="alsdev", help="Username for job queries") + + subparsers = parser.add_subparsers(dest="command", help="Commands") + + # List jobs + list_parser = subparsers.add_parser("list", help="List all jobs for a user") + + # Cancel specific job + cancel_parser = subparsers.add_parser("cancel", help="Cancel a specific job") + cancel_parser.add_argument("jobid", help="Job ID to cancel") + + # Cancel all jobs + cancel_all_parser = subparsers.add_parser("cancel-all", help="Cancel all jobs for a user") + + args = parser.parse_args() + + if args.command == "list": + list_jobs(user=args.user) + elif args.command == "cancel": + cancel_job(jobid=args.jobid) + elif args.command == "cancel-all": + cancel_all_jobs(user=args.user) + else: + parser.print_help() From a88fd4537b88911ca751de86cb7cdf071d0d7ea9 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 15 Jan 2026 11:49:58 -0800 Subject: [PATCH 06/72] in setup.cfg, adding a new section for flake8 to ignore the reconstruction codes in scripts/perlmutter/* and scripts/polaris/*, since they are a linting nightmare but work --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index 8ffb8adf..37f4f9f8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1 +1,6 @@ [tool:pytest] + +[flake8] +exclude = + scripts/perlmutter/* + scripts/polaris/* From 369160108663d18284878c66e00c5a209514a520 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 15 Jan 2026 13:09:48 -0800 Subject: [PATCH 07/72] Adding nersc_recon_num_nodes = 4 to Config832, which is used in bl832/nersc.py to set the number of nodes to use for reconstruction --- orchestration/flows/bl832/config.py | 1 + orchestration/flows/bl832/nersc.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 
788eef4a..26727279 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -28,3 +28,4 @@ def _beam_specific_config(self) -> None: self.alcf832_scratch = self.endpoints["alcf832_scratch"] self.scicat = self.config["scicat"] self.ghcr_images832 = self.config["ghcr_images832"] + self.nersc_recon_num_nodes = 4 diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 794dc9a4..fe513801 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -851,14 +851,12 @@ def schedule_pruning( @flow(name="nersc_recon_flow", flow_run_name="nersc_recon-{file_path}") def nersc_recon_flow( file_path: str, - num_nodes: int = 1, config: Optional[Config832] = None, ) -> bool: """ Perform tomography reconstruction on NERSC. :param file_path: Path to the file to reconstruct. - :param num_nodes: Number of nodes to use for parallel reconstruction. :param config: Configuration object (if None, a default Config832 will be created). :return: True if successful, False otherwise. 
""" @@ -874,7 +872,8 @@ def nersc_recon_flow( config=config ) logger.info("NERSC reconstruction controller initialized") - + num_nodes = config.nersc_recon_num_nodes + logger.info(f"Configured to use {num_nodes} nodes for reconstruction") if num_nodes == 1: logger.info("Using single-node reconstruction") nersc_reconstruction_success = controller.reconstruct( From 1f487431b6696adad0cc10a39e5cc8c0f997b78b Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 28 Jan 2026 12:58:22 -0800 Subject: [PATCH 08/72] separating single node (production) nersc reconstruction flow from the multinode reconstruction flow --- orchestration/flows/bl832/nersc.py | 135 ++++++++++++++++++++++++----- 1 file changed, 112 insertions(+), 23 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index fe513801..9345ecb6 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -857,6 +857,98 @@ def nersc_recon_flow( Perform tomography reconstruction on NERSC. :param file_path: Path to the file to reconstruct. 
+ """ + logger = get_run_logger() + + if config is None: + logger.info("Initializing Config") + config = Config832() + + logger.info(f"Starting NERSC reconstruction flow for {file_path=}") + controller = get_controller( + hpc_type=HPC.NERSC, + config=config + ) + logger.info("NERSC reconstruction controller initialized") + + nersc_reconstruction_success = controller.reconstruct( + file_path=file_path, + ) + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + nersc_multi_res_success = controller.build_multi_resolution( + file_path=file_path, + ) + logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + tiff_file_path = f"{folder_name}/rec{file_name}" + zarr_file_path = f"{folder_name}/rec{file_name}.zarr" + + logger.info(f"{tiff_file_path=}") + logger.info(f"{zarr_file_path=}") + + # Transfer reconstructed data + logger.info("Preparing transfer.") + transfer_controller = get_transfer_controller( + transfer_type=CopyMethod.GLOBUS, + config=config + ) + + logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") + transfer_controller.copy( + file_path=tiff_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch + ) + + transfer_controller.copy( + file_path=zarr_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch + ) + + logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") + transfer_controller.copy( + file_path=tiff_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + + transfer_controller.copy( + file_path=zarr_file_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + + logger.info("Scheduling pruning tasks.") + schedule_pruning( + config=config, + 
raw_file_path=file_path, + tiff_file_path=tiff_file_path, + zarr_file_path=zarr_file_path + ) + + # TODO: Ingest into SciCat + if nersc_reconstruction_success and nersc_multi_res_success: + return True + else: + return False + + +@flow(name="nersc_recon_multinode_flow", flow_run_name="nersc_recon_multinode-{file_path}") +def nersc_recon_multinode_flow( + file_path: str, + num_nodes: Optional[int] = 4, + config: Optional[Config832] = None, +) -> bool: + """ + Perform multi-node tomography reconstruction on NERSC. + + :param file_path: Path to the file to reconstruct. + :param num_nodes: Number of nodes to use for reconstruction. :param config: Configuration object (if None, a default Config832 will be created). :return: True if successful, False otherwise. """ @@ -872,19 +964,16 @@ def nersc_recon_flow( config=config ) logger.info("NERSC reconstruction controller initialized") - num_nodes = config.nersc_recon_num_nodes + + if num_nodes is None: + num_nodes = config.nersc_recon_num_nodes logger.info(f"Configured to use {num_nodes} nodes for reconstruction") - if num_nodes == 1: - logger.info("Using single-node reconstruction") - nersc_reconstruction_success = controller.reconstruct( - file_path=file_path, - ) - else: - logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - nersc_reconstruction_success = controller.reconstruct_multinode( - file_path=file_path, - num_nodes=num_nodes - ) + + logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") + nersc_reconstruction_success = controller.reconstruct_multinode( + file_path=file_path, + num_nodes=num_nodes + ) if isinstance(nersc_reconstruction_success, dict): success = nersc_reconstruction_success.get('success', False) @@ -1040,23 +1129,23 @@ def pull_shifter_image_flow( return success -if __name__ == "__main__": +# if __name__ == "__main__": - config = Config832() +# config = Config832() # pull_shifter_image_flow(config=config) # # Fibers ------------------------------------------ - 
start = time.time() - nersc_recon_flow( - file_path="dabramov/20251218_111600_silkraw.h5", - num_nodes=4, - config=config - ) - end = time.time() - logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - print(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # start = time.time() + # nersc_recon_flow( + # file_path="dabramov/20251218_111600_silkraw.h5", + # num_nodes=4, + # config=config + # ) + # end = time.time() + # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") + # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") # start = time.time() # nersc_recon_flow( From 8ad972f34b82d1a009aa56340f04132fcacfd766 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 28 Jan 2026 12:58:39 -0800 Subject: [PATCH 09/72] Making a spearate deployment for the nersc multinode reconstruction flow --- orchestration/flows/bl832/prefect.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/orchestration/flows/bl832/prefect.yaml b/orchestration/flows/bl832/prefect.yaml index a1d4613b..3b93b154 100644 --- a/orchestration/flows/bl832/prefect.yaml +++ b/orchestration/flows/bl832/prefect.yaml @@ -43,6 +43,12 @@ deployments: name: nersc_recon_flow_pool work_queue_name: nersc_recon_flow_queue +- name: nersc_recon_multinode_flow + entrypoint: orchestration/flows/bl832/nersc.py:nersc_recon_multinode_flow + work_pool: + name: nersc_recon_flow_pool + work_queue_name: nersc_recon_multinode_flow_queue + - name: nersc_streaming_flow entrypoint: orchestration/flows/bl832/nersc.py:nersc_streaming_flow work_pool: From 81dd6c9a837748adf663e1880f9f340d0a872b48 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 28 Jan 2026 12:59:05 -0800 Subject: [PATCH 10/72] Creating option to turn on/off the nersc multinode reconstruction flow for the dispatcher --- orchestration/flows/bl832/dispatcher.py | 32 +++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git 
a/orchestration/flows/bl832/dispatcher.py b/orchestration/flows/bl832/dispatcher.py index cf1d0c64..60fefb38 100644 --- a/orchestration/flows/bl832/dispatcher.py +++ b/orchestration/flows/bl832/dispatcher.py @@ -25,6 +25,10 @@ class FlowParameterMapper: # From nersc.py "nersc_recon_flow/nersc_recon_flow": [ "file_path", + "config"], + "nersc_recon_multinode_flow/nersc_recon_multinode_flow": [ + "file_path", + "num_nodes", "config"] } @@ -51,27 +55,37 @@ class DecisionFlowInputModel(BaseModel): """ file_path: Optional[str] = Field(default=None) is_export_control: Optional[bool] = Field(default=False) + num_nodes: Optional[int] = Field(default=4) config: Optional[Union[dict, Any]] = Field(default_factory=dict) @task(name="setup_decision_settings") -def setup_decision_settings(alcf_recon: bool, nersc_recon: bool, new_file_832: bool) -> dict: +def setup_decision_settings( + alcf_recon: bool, + nersc_recon: bool, + nersc_recon_multinode: bool, + new_file_832: bool +) -> dict: """ This task is used to define the settings for the decision making process of the BL832 beamline. :param alcf_recon: Boolean indicating whether to run the ALCF reconstruction flow. :param nersc_recon: Boolean indicating whether to run the NERSC reconstruction flow. - :param nersc_move: Boolean indicating whether to move files to NERSC. + :param nersc_recon_multinode: Boolean indicating whether to run the NERSC multinode reconstruction flow. + :param new_file_832: Boolean indicating whether to move files to NERSC. :return: A dictionary containing the settings for each flow. 
""" logger = get_run_logger() try: logger.info(f"Setting up decision settings: alcf_recon={alcf_recon}, " - f"nersc_recon={nersc_recon}, new_file_832={new_file_832}") + f"nersc_recon={nersc_recon}, " + f"nersc_recon_multinode={nersc_recon_multinode}, " + f"new_file_832={new_file_832}") # Define which flows to run based on the input settings settings = { "alcf_recon_flow/alcf_recon_flow": alcf_recon, "nersc_recon_flow/nersc_recon_flow": nersc_recon, + "nersc_recon_multinode_flow/nersc_recon_multinode_flow": nersc_recon_multinode, "new_832_file_flow/new_file_832": new_file_832 } # Save the settings in a JSON block for later retrieval by other flows @@ -149,6 +163,11 @@ async def dispatcher( nersc_params = FlowParameterMapper.get_flow_parameters("nersc_recon_flow/nersc_recon_flow", available_params) tasks.append(run_recon_flow_async("nersc_recon_flow/nersc_recon_flow", nersc_params)) + if decision_settings.get("nersc_recon_multinode_flow/nersc_recon_multinode_flow"): + nersc_multinode_params = FlowParameterMapper.get_flow_parameters( + "nersc_recon_multinode_flow/nersc_recon_multinode_flow", available_params) + tasks.append(run_recon_flow_async("nersc_recon_multinode_flow/nersc_recon_multinode_flow", nersc_multinode_params)) + # Run ALCF and NERSC flows in parallel, if any if tasks: try: @@ -169,7 +188,12 @@ async def dispatcher( """ try: # Setup decision settings based on input parameters - setup_decision_settings(alcf_recon=True, nersc_recon=True, new_file_832=True) + setup_decision_settings( + alcf_recon=True, + nersc_recon=True, + nersc_recon_multinode=True, + new_file_832=True + ) # Run the main decision flow with the specified parameters # asyncio.run(dispatcher( # config={}, # PYTEST, ALCF, NERSC From 141a5a6e367f6dfdde2e058fdd913fd46d1b4e41 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 9 Feb 2026 11:29:12 -0800 Subject: [PATCH 11/72] Updating segmentation to use inference_v4. 
--- orchestration/flows/bl832/nersc.py | 772 ++++++++++++++++++++++++++++- 1 file changed, 771 insertions(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 9345ecb6..c83f046a 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -8,7 +8,7 @@ import time from authlib.jose import JsonWebKey -from prefect import flow, get_run_logger +from prefect import flow, get_run_logger, task from prefect.variables import Variable from sfapi_client import Client from sfapi_client.compute import Machine @@ -581,6 +581,707 @@ def build_multi_resolution( return False else: return False + +# def segmentation( +# self, +# recon_folder_path: str = "", +# num_nodes: int = 4, +# ) -> dict: +# """ +# Run SAM3 segmentation at NERSC Perlmutter (optimized). +# """ +# logger.info("Starting NERSC segmentation process.") + +# user = self.client.user() +# pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" +# cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" +# conda_env_path = f"{cfs_path}/envs/sam3" + +# # Paths +# # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/forge_feb_seg_model_demo_v2/forge_feb_seg_model_demo/" +# seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" +# checkpoints_dir = f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" +# hf_cache_dir = f"{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" + +# bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" +# original_checkpoint = f"{checkpoints_dir}/sam3.pt" +# finetuned_checkpoint = f"{checkpoints_dir}/checkpoint.pt" + +# input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" +# output_folder = recon_folder_path.replace('/rec', '/seg') +# output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}" + +# logger.info(f"Input directory: {input_dir}") +# logger.info(f"Output directory: {output_dir}") +# 
logger.info(f"HuggingFace cache: {hf_cache_dir}") + +# batch_size = 8 +# nproc_per_node = 4 + +# prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", +# "Water-based Pith cells", "Xylem vessels"] +# prompts_str = " ".join([f'"{p}"' for p in prompts]) + +# if num_nodes <= 4: +# qos = "realtime" +# else: +# qos = "regular" + +# walltime = "00:15:00" + +# job_name = f"seg_{Path(recon_folder_path).name}" + +# job_script = f"""#!/bin/bash +# #SBATCH -q {qos} +# #SBATCH -A als +# #SBATCH -N {num_nodes} # 4 nodes = 16 GPUs total +# #SBATCH -C gpu +# #SBATCH --job-name={job_name} +# #SBATCH --time={walltime} # Reduce to 1 hour (500 images takes ~10 min) +# #SBATCH --ntasks-per-node=1 +# #SBATCH --gpus-per-node=4 +# #SBATCH --cpus-per-task=128 +# #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +# #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err + +# # Get master node +# export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +# export MASTER_PORT=29500 + +# # Create output and log directories +# mkdir -p {output_dir} +# # mkdir -p logs + +# # Load your conda environment +# module load conda +# # conda activate /pscratch/sd/x/xchong/envs/sam3 +# conda activate {conda_env_path} + +# echo "============================================================" +# echo "JOB STARTED: $(date)" +# echo "============================================================" +# echo "Master: $MASTER_ADDR:$MASTER_PORT" +# echo "Nodes: $SLURM_JOB_NODELIST" +# echo "Job ID: $SLURM_JOB_ID" +# echo "GPUs: $((SLURM_NNODES * 4))" + +# # Count actual images +# NUM_IMAGES=$(ls {input_dir}/*.tif* 2>/dev/null | wc -l) +# echo "Images to process: $NUM_IMAGES" +# echo "============================================================" + +# # Record start time +# START_TIME=$(date +%s) + +# # Change to script directory +# cd {seg_scripts_dir} + +# # Run inference (no nsys for production) +# srun --ntasks-per-node=1 --gpus-per-task=4 bash -c " +# torchrun \ +# 
--nnodes=\$SLURM_NNODES \ +# --nproc_per_node={nproc_per_node} \ +# --rdzv_id=\$SLURM_JOB_ID \ +# --rdzv_backend=c10d \ +# --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ +# src/inference_v4.py \ +# --input-dir {input_dir} \ +# --output-dir {output_dir} \ +# --patch-size 640 \ +# --batch-size {batch_size} \ +# --confidence 0.5 \ +# --prompts {prompts_str} \\ +# --bpe-path {bpe_path} \\ +# --original-checkpoint {original_checkpoint} \\ +# --finetuned-checkpoint {finetuned_checkpoint} +# " + +# # Record end time and calculate duration +# END_TIME=$(date +%s) +# DURATION=$((END_TIME - START_TIME)) +# MINUTES=$((DURATION / 60)) +# SECONDS=$((DURATION % 60)) +# TIME_PER_IMAGE=$(echo "scale=3; $DURATION / $NUM_IMAGES" | bc) +# THROUGHPUT=$(echo "scale=2; $NUM_IMAGES / $DURATION * 60" | bc) +# SEG_STATUS=$? # ← Capture torchrun's exit status + +# echo "" +# echo "============================================================" +# echo "JOB COMPLETED: $(date)" +# echo "============================================================" +# echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" +# echo "Images processed: $NUM_IMAGES" +# echo "Time per image: ${{TIME_PER_IMAGE}}s" +# echo "Throughput: ${{THROUGHPUT}} images/minute" +# echo "Results saved to: {output_dir}" +# echo "Exit status: $SEG_STATUS" +# exit $SEG_STATUS +# echo "============================================================" +# """ +# # #SBATCH -q {qos} +# # #SBATCH -A als +# # #SBATCH -C gpu +# # #SBATCH --job-name=seg_{Path(recon_folder_path).name} +# # #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +# # #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err +# # #SBATCH -N {num_nodes} +# # #SBATCH --ntasks-per-node=4 +# # #SBATCH --gpus-per-node=4 +# # #SBATCH --cpus-per-task=8 +# # #SBATCH --time={walltime} + +# # set -e + +# # TIMING_FILE="{pscratch_path}/tomo_seg_logs/timing_$SLURM_JOB_ID.txt" +# # echo "JOB_START=$(date +%s)" > $TIMING_FILE + +# # # Load PyTorch module (NERSC recommended) +# 
# module load pytorch + +# # # Install additional dependencies +# # pip install --user --quiet \\ +# # einops decord pycocotools psutil \\ +# # "timm>=1.0.17" "numpy>=1.26,<2" \\ +# # tqdm ftfy==6.1.1 regex \\ +# # iopath>=0.1.10 python-dotenv qlty \\ +# # git+https://github.com/facebookresearch/sam3.git 2>/dev/null || true + +# # mkdir -p {output_dir} +# # mkdir -p {pscratch_path}/tomo_seg_logs + +# # # Environment +# # export PYTHONPATH={seg_scripts_dir}:$PYTHONPATH +# # export HF_HOME={hf_cache_dir} +# # export HF_HUB_CACHE={hf_cache_dir} + +# # # Distributed settings +# # export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +# # export MASTER_PORT=29500 + +# # echo "============================================== +# # Job Configuration +# # ============================================== +# # Job ID: $SLURM_JOB_ID +# # Nodes: {num_nodes} +# # GPUs/node: 4 +# # Total GPUs: $(({num_nodes} * 4)) +# # Master: $MASTER_ADDR:$MASTER_PORT +# # Node list: $SLURM_JOB_NODELIST +# # Input: {input_dir} +# # Output: {output_dir} +# # ==============================================" + +# # NUM_FILES=$(ls -1 {input_dir}/*.tif {input_dir}/*.tiff 2>/dev/null | wc -l) +# # echo "Found $NUM_FILES TIFF files to process" + +# # echo "SEG_START=$(date +%s)" >> $TIMING_FILE + +# # cd {seg_scripts_dir} + +# # # NERSC recommended: srun with one task per GPU +# # # Each task gets SLURM_PROCID (global rank) and SLURM_LOCALID (local rank) +# # srun --export=ALL \\ +# # python -u -m src.inference_v2_optimized3 \\ +# # --input-dir {input_dir} \\ +# # --output-dir {output_dir} \\ +# # --bpe-path {checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz \\ +# # --finetuned-checkpoint {checkpoints_dir}/checkpoint.pt \\ +# # --original-checkpoint {checkpoints_dir}/sam3.pt \\ +# # --patch-size 640 \\ +# # --batch-size {batch_size} \\ +# # --confidence 0.5 \\ +# # --prompts {prompts_str} + +# # SEG_STATUS=$? 
+# # echo "SEG_END=$(date +%s)" >> $TIMING_FILE + +# # if [ $SEG_STATUS -ne 0 ]; then +# # echo "Segmentation failed with exit code $SEG_STATUS" +# # echo "JOB_STATUS=FAILED" >> $TIMING_FILE +# # exit 1 +# # fi + +# # chmod -R 2775 {output_dir} 2>/dev/null || true + +# # echo "JOB_STATUS=SUCCESS" >> $TIMING_FILE +# # echo "JOB_END=$(date +%s)" >> $TIMING_FILE +# # """ +# try: +# logger.info("Submitting segmentation job to Perlmutter.") +# perlmutter = self.client.compute(Machine.perlmutter) + +# # Ensure directories exist +# logger.info("Creating necessary directories...") +# perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") +# perlmutter.run(f"mkdir -p {cfs_path}/envs") + +# # Submit job +# job = perlmutter.submit_job(job_script) +# logger.info(f"Submitted job ID: {job.jobid}") + +# # Initial update +# try: +# job.update() +# except Exception as update_err: +# logger.warning(f"Initial job update failed, continuing: {update_err}") + +# # Wait briefly before polling +# time.sleep(60) +# logger.info(f"Job {job.jobid} current state: {job.state}") + +# # Wait for completion +# job.complete() +# logger.info("Segmentation job completed successfully.") + +# # Fetch timing data +# timing = self._fetch_seg_timing_data(perlmutter, pscratch_path, job.jobid) + +# if timing: +# logger.info("=" * 60) +# logger.info("SEGMENTATION TIMING BREAKDOWN") +# logger.info("=" * 60) +# logger.info(f" Total job time: {timing.get('total', 'N/A')}s") + +# if 'env_setup' in timing: +# logger.info(f" Environment setup: {timing['env_setup']}s") + +# logger.info(f" SEGMENTATION: {timing.get('segmentation', 'N/A')}s <-- inference time") +# logger.info(f" Job status: {timing.get('job_status', 'UNKNOWN')}") +# logger.info("=" * 60) + +# return { +# "success": True, +# "job_id": job.jobid, +# "timing": timing, +# "output_dir": output_dir +# } + +# except Exception as e: +# logger.error(f"Error during segmentation job: {e}") +# import traceback +# logger.error(traceback.format_exc()) + +# # 
Attempt recovery +# match = re.search(r"Job not found:\s*(\d+)", str(e)) +# if match: +# jobid = match.group(1) +# logger.info(f"Attempting to recover job {jobid}.") +# try: +# job = self.client.compute(Machine.perlmutter).job(jobid=jobid) +# time.sleep(30) +# job.complete() +# logger.info("Segmentation job completed after recovery.") + +# timing = self._fetch_seg_timing_data(perlmutter, pscratch_path, jobid) +# return { +# "success": True, +# "job_id": jobid, +# "timing": timing, +# "output_dir": output_dir +# } +# except Exception as recovery_err: +# logger.error(f"Failed to recover job {jobid}: {recovery_err}") + +# return { +# "success": False, +# "job_id": None, +# "timing": None, +# "output_dir": None +# } + + +# def _fetch_seg_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: +# """ +# Fetch and parse timing data from the segmentation job. + +# :param perlmutter: SFAPI compute object for Perlmutter +# :param pscratch_path: Path to the user's pscratch directory +# :param job_id: SLURM job ID +# :return: Dictionary with timing breakdown +# """ +# timing_file = f"{pscratch_path}/tomo_seg_logs/timing_{job_id}.txt" + +# try: +# # Use SFAPI to read the timing file +# result = perlmutter.run(f"cat {timing_file}") + +# # Handle different result types +# if isinstance(result, str): +# output = result +# elif hasattr(result, 'output'): +# output = result.output +# elif hasattr(result, 'stdout'): +# output = result.stdout +# else: +# output = str(result) + +# logger.info(f"Timing file contents:\n{output}") + +# # Parse timing data +# timing = {} +# for line in output.strip().split('\n'): +# if '=' in line: +# key, value = line.split('=', 1) +# timing[key] = value.strip() + +# # Calculate durations +# breakdown = {} + +# if 'JOB_START' in timing and 'JOB_END' in timing: +# breakdown['total'] = int(timing['JOB_END']) - int(timing['JOB_START']) + +# if 'ENV_SETUP_START' in timing and 'ENV_SETUP_END' in timing: +# breakdown['env_setup'] = 
int(timing['ENV_SETUP_END']) - int(timing['ENV_SETUP_START']) + +# if 'SEG_START' in timing and 'SEG_END' in timing: +# breakdown['segmentation'] = int(timing['SEG_END']) - int(timing['SEG_START']) + +# breakdown['job_status'] = timing.get('JOB_STATUS', 'UNKNOWN') + +# return breakdown + +# except Exception as e: +# logger.warning(f"Error fetching timing data: {e}") +# import traceback +# logger.warning(traceback.format_exc()) +# return None + + def segmentation( + self, + recon_folder_path: str = "", + num_nodes: int = 4, + ) -> dict: + """ + Run SAM3 segmentation at NERSC Perlmutter (v4 with overlap + max confidence stitching). + """ + logger.info("Starting NERSC segmentation process (inference_v4).") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" + conda_env_path = f"{cfs_path}/envs/sam3" + + # Paths + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" + checkpoints_dir = f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" + + bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" + original_checkpoint = f"{checkpoints_dir}/sam3.pt" + finetuned_checkpoint = f"{checkpoints_dir}/checkpoint.pt" + + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + output_folder = recon_folder_path.replace('/rec', '/seg') + output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}" + + logger.info(f"Input directory: {input_dir}") + logger.info(f"Output directory: {output_dir}") + logger.info(f"Conda environment: {conda_env_path}") + + batch_size = 8 + nproc_per_node = 4 + + prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", + "Water-based Pith cells", "Xylem vessels"] + prompts_str = " ".join([f'"{p}"' for p in prompts]) + + if num_nodes <= 4: + qos = "realtime" + else: + qos = "regular" + + walltime = "00:59:00" + + job_name = f"seg_{Path(recon_folder_path).name}" + + job_script = 
f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A als +#SBATCH -N {num_nodes} +#SBATCH -C gpu +#SBATCH --job-name={job_name} +#SBATCH --time={walltime} +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=128 +#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err + +# Create output and log directories +mkdir -p {output_dir} +mkdir -p {pscratch_path}/tomo_seg_logs + +# Load conda module +module load conda + +# Check if environment exists, create if it doesn't +if [ ! -d "{conda_env_path}" ]; then + echo "Conda environment not found at {conda_env_path}" + echo "Creating new environment..." + + # Check if Xiaoya's environment exists as a reference + if [ -d "/pscratch/sd/x/xchong/envs/sam3" ]; then + echo "Cloning from Xiaoya's environment..." + conda create --prefix {conda_env_path} --clone /pscratch/sd/x/xchong/envs/sam3 -y + else + echo "Creating fresh environment..." + conda create --prefix {conda_env_path} python=3.10 -y + conda activate {conda_env_path} + + # Install dependencies + pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 + pip install einops decord pycocotools psutil + pip install "timm>=1.0.17" "numpy>=1.26,<2" + pip install tqdm ftfy==6.1.1 regex + pip install iopath>=0.1.10 python-dotenv qlty + pip install transformers + pip install git+https://github.com/facebookresearch/sam3.git + + conda deactivate + fi + + echo "Environment setup complete" +else + echo "Using existing conda environment at {conda_env_path}" +fi + +# Activate the environment +conda activate {conda_env_path} + +# Get master node +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export MASTER_PORT=29500 + +echo "============================================================" +echo "JOB STARTED: $(date)" +echo "============================================================" +echo "Master: $MASTER_ADDR:$MASTER_PORT" +echo "Nodes: 
$SLURM_JOB_NODELIST" +echo "Job ID: $SLURM_JOB_ID" +echo "GPUs: $((SLURM_NNODES * 4))" +echo "Input: {input_dir}" +echo "Output: {output_dir}" + +# Count actual images +NUM_IMAGES=$(ls {input_dir}/*.tif* 2>/dev/null | wc -l) +echo "Images to process: $NUM_IMAGES" +echo "============================================================" + +# Record start time +START_TIME=$(date +%s) + +# Change to script directory +cd {seg_scripts_dir} + +# Run inference with v4 +srun --ntasks-per-node=1 --gpus-per-task=4 bash -c " +torchrun \\ + --nnodes=\$SLURM_NNODES \\ + --nproc_per_node={nproc_per_node} \\ + --rdzv_id=\$SLURM_JOB_ID \\ + --rdzv_backend=c10d \\ + --rdzv_endpoint=\$MASTER_ADDR:\$MASTER_PORT \\ + src/inference_v4.py \\ + --input-dir {input_dir} \\ + --output-dir {output_dir} \\ + --patch-size 640 \\ + --batch-size {batch_size} \\ + --confidence 0.5 \\ + --prompts {prompts_str} \\ + --bpe-path {bpe_path} \\ + --original-checkpoint {original_checkpoint} \\ + --finetuned-checkpoint {finetuned_checkpoint} +" + +SEG_STATUS=$? 
+ +# Record end time and calculate duration +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) +MINUTES=$((DURATION / 60)) +SECONDS=$((DURATION % 60)) + +if [ $NUM_IMAGES -gt 0 ]; then + TIME_PER_IMAGE=$(echo "scale=3; $DURATION / $NUM_IMAGES" | bc) + THROUGHPUT=$(echo "scale=2; $NUM_IMAGES / $DURATION * 60" | bc) +else + TIME_PER_IMAGE="N/A" + THROUGHPUT="N/A" +fi + +echo "" +echo "============================================================" +echo "JOB COMPLETED: $(date)" +echo "============================================================" +echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" +echo "Images processed: $NUM_IMAGES" +echo "Time per image: ${{TIME_PER_IMAGE}}s" +echo "Throughput: ${{THROUGHPUT}} images/minute" +echo "Results saved to: {output_dir}" +echo "Exit status: $SEG_STATUS" +echo "============================================================" + +# Set permissions +chmod -R 2775 {output_dir} 2>/dev/null || true + +exit $SEG_STATUS +""" + + try: + logger.info("Submitting segmentation job to Perlmutter (v4).") + perlmutter = self.client.compute(Machine.perlmutter) + + # Ensure directories exist + logger.info("Creating necessary directories...") + perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") + perlmutter.run(f"mkdir -p {output_dir}") + + # Submit job + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + # Initial update + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + # Wait briefly before polling + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + # Wait for completion + job.complete() + logger.info("Segmentation job completed successfully.") + + # Fetch timing data from output file + timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, job.jobid, job_name) + + if timing: + logger.info("=" * 60) + logger.info("SEGMENTATION TIMING BREAKDOWN") + 
logger.info("=" * 60) + logger.info(f" Total time: {timing.get('total_time', 'N/A')}") + logger.info(f" Images processed: {timing.get('num_images', 'N/A')}") + logger.info(f" Time per image: {timing.get('time_per_image', 'N/A')}") + logger.info(f" Throughput: {timing.get('throughput', 'N/A')} images/min") + logger.info(f" Exit status: {timing.get('exit_status', 'N/A')}") + logger.info("=" * 60) + + return { + "success": True, + "job_id": job.jobid, + "timing": timing, + "output_dir": output_dir + } + + except Exception as e: + logger.error(f"Error during segmentation job: {e}") + import traceback + logger.error(traceback.format_exc()) + + # Attempt recovery + match = re.search(r"Job not found:\s*(\d+)", str(e)) + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.compute(Machine.perlmutter).job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("Segmentation job completed after recovery.") + + timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, jobid, job_name) + return { + "success": True, + "job_id": jobid, + "timing": timing, + "output_dir": output_dir + } + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + + return { + "success": False, + "job_id": None, + "timing": None, + "output_dir": None + } + + + def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: str, job_name: str) -> dict: + """ + Fetch and parse timing data from the SLURM output file. 
+ + :param perlmutter: SFAPI compute object for Perlmutter + :param pscratch_path: Path to the user's pscratch directory + :param job_id: SLURM job ID + :param job_name: Job name for finding output file + :return: Dictionary with timing breakdown + """ + output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" + + try: + # Use SFAPI to read the output file + result = perlmutter.run(f"cat {output_file}") + + # Handle different result types + if isinstance(result, str): + output = result + elif hasattr(result, 'output'): + output = result.output + elif hasattr(result, 'stdout'): + output = result.stdout + else: + output = str(result) + + logger.info(f"Job output file contents (last 50 lines):") + lines = output.strip().split('\n') + for line in lines[-50:]: + logger.info(f" {line}") + + # Parse timing data from the output + timing = {} + + for line in lines: + if "Total time:" in line: + # Extract: "Total time: 5m 23s (323s)" + match = re.search(r'(\d+)m\s+(\d+)s\s+\((\d+)s\)', line) + if match: + timing['total_time'] = f"{match.group(1)}m {match.group(2)}s" + timing['total_seconds'] = int(match.group(3)) + + elif "Images processed:" in line: + # Extract: "Images processed: 100" + match = re.search(r'Images processed:\s+(\d+)', line) + if match: + timing['num_images'] = int(match.group(1)) + + elif "Time per image:" in line: + # Extract: "Time per image: 3.230s" + match = re.search(r'Time per image:\s+([\d.]+)s', line) + if match: + timing['time_per_image'] = f"{match.group(1)}s" + + elif "Throughput:" in line: + # Extract: "Throughput: 18.58 images/minute" + match = re.search(r'Throughput:\s+([\d.]+)\s+images/minute', line) + if match: + timing['throughput'] = float(match.group(1)) + + elif "Exit status:" in line: + # Extract: "Exit status: 0" + match = re.search(r'Exit status:\s+(\d+)', line) + if match: + timing['exit_status'] = int(match.group(1)) + + return timing if timing else None + + except Exception as e: + logger.warning(f"Error fetching 
timing data from output: {e}") + import traceback + logger.warning(traceback.format_exc()) + return None + def start_streaming_service( self, @@ -1129,6 +1830,75 @@ def pull_shifter_image_flow( return success +@task(name="nersc_segmentation_task") +def nersc_segmentation_task( + recon_folder_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Run segmentation task at NERSC. + + :param recon_folder_path: Path to the reconstructed data folder to be processed. + :param config: Configuration object for the flow. + :return: True if the task completed successfully, False otherwise. + """ + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + + # Initialize the Tomography Controller and run the segmentation + logger.info("Initializing NERSC Tomography HPC Controller.") + tomography_controller = get_controller( + hpc_type=HPC.NERSC, + config=config + ) + logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") + nersc_segmentation_success = tomography_controller.segmentation( + recon_folder_path=recon_folder_path, + ) + if not nersc_segmentation_success: + logger.error("Segmentation Failed.") + else: + logger.info("Segmentation Successful.") + return nersc_segmentation_success + + +@flow(name="nersc_segmentation_integration_test", flow_run_name="nersc_segmentation_integration_test") +def nersc_segmentation_integration_test() -> bool: + """ + Integration test for the NERSC segmentation task. + + :return: True if the segmentation task completed successfully, False otherwise. 
+ """ + logger = get_run_logger() + logger.info("Starting NERSC segmentation integration test.") + recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # + flow_success = nersc_segmentation_task( + recon_folder_path=recon_folder_path, + config=Config832() + ) + logger.info(f"Flow success: {flow_success}") + return flow_success + + +if __name__ == "__main__": + # Run the integration test flow + + # from sfapi_client import Client + # from sfapi_client.compute import Machine + + # # Use your existing client setup + # client = NERSCTomographyHPCController.create_sfapi_client() + # perlmutter = client.compute(Machine.perlmutter) + + # job.cancel() + # job = perlmutter.job(jobid=48570063) + # print(f"Job {job.jobid} cancelled, state: {job.state}") + + result = nersc_segmentation_integration_test() + print(f"Integration test result: {result}") + # if __name__ == "__main__": # config = Config832() From b4cab67ee62f9fa84c8026ae1b56bea9d7a158a8 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 9 Feb 2026 14:33:03 -0800 Subject: [PATCH 12/72] removing comments. 
segmentation still isn't working --- orchestration/flows/bl832/nersc.py | 720 ++++++++++++----------------- 1 file changed, 285 insertions(+), 435 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c83f046a..5cb59aba 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -16,6 +16,7 @@ from orchestration.flows.bl832.config import Config832 from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController +from orchestration.prune_controller import get_prune_controller, PruneMethod from orchestration.transfer_controller import get_transfer_controller, CopyMethod from orchestration.flows.bl832.streaming_mixin import ( NerscStreamingMixin, SlurmJobBlock, cancellation_hook, monitor_streaming_job, save_block @@ -582,369 +583,6 @@ def build_multi_resolution( else: return False -# def segmentation( -# self, -# recon_folder_path: str = "", -# num_nodes: int = 4, -# ) -> dict: -# """ -# Run SAM3 segmentation at NERSC Perlmutter (optimized). 
-# """ -# logger.info("Starting NERSC segmentation process.") - -# user = self.client.user() -# pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" -# cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" -# conda_env_path = f"{cfs_path}/envs/sam3" - -# # Paths -# # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/forge_feb_seg_model_demo_v2/forge_feb_seg_model_demo/" -# seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" -# checkpoints_dir = f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" -# hf_cache_dir = f"{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" - -# bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" -# original_checkpoint = f"{checkpoints_dir}/sam3.pt" -# finetuned_checkpoint = f"{checkpoints_dir}/checkpoint.pt" - -# input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" -# output_folder = recon_folder_path.replace('/rec', '/seg') -# output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}" - -# logger.info(f"Input directory: {input_dir}") -# logger.info(f"Output directory: {output_dir}") -# logger.info(f"HuggingFace cache: {hf_cache_dir}") - -# batch_size = 8 -# nproc_per_node = 4 - -# prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", -# "Water-based Pith cells", "Xylem vessels"] -# prompts_str = " ".join([f'"{p}"' for p in prompts]) - -# if num_nodes <= 4: -# qos = "realtime" -# else: -# qos = "regular" - -# walltime = "00:15:00" - -# job_name = f"seg_{Path(recon_folder_path).name}" - -# job_script = f"""#!/bin/bash -# #SBATCH -q {qos} -# #SBATCH -A als -# #SBATCH -N {num_nodes} # 4 nodes = 16 GPUs total -# #SBATCH -C gpu -# #SBATCH --job-name={job_name} -# #SBATCH --time={walltime} # Reduce to 1 hour (500 images takes ~10 min) -# #SBATCH --ntasks-per-node=1 -# #SBATCH --gpus-per-node=4 -# #SBATCH --cpus-per-task=128 -# #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out -# #SBATCH 
--error={pscratch_path}/tomo_seg_logs/%x_%j.err - -# # Get master node -# export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) -# export MASTER_PORT=29500 - -# # Create output and log directories -# mkdir -p {output_dir} -# # mkdir -p logs - -# # Load your conda environment -# module load conda -# # conda activate /pscratch/sd/x/xchong/envs/sam3 -# conda activate {conda_env_path} - -# echo "============================================================" -# echo "JOB STARTED: $(date)" -# echo "============================================================" -# echo "Master: $MASTER_ADDR:$MASTER_PORT" -# echo "Nodes: $SLURM_JOB_NODELIST" -# echo "Job ID: $SLURM_JOB_ID" -# echo "GPUs: $((SLURM_NNODES * 4))" - -# # Count actual images -# NUM_IMAGES=$(ls {input_dir}/*.tif* 2>/dev/null | wc -l) -# echo "Images to process: $NUM_IMAGES" -# echo "============================================================" - -# # Record start time -# START_TIME=$(date +%s) - -# # Change to script directory -# cd {seg_scripts_dir} - -# # Run inference (no nsys for production) -# srun --ntasks-per-node=1 --gpus-per-task=4 bash -c " -# torchrun \ -# --nnodes=\$SLURM_NNODES \ -# --nproc_per_node={nproc_per_node} \ -# --rdzv_id=\$SLURM_JOB_ID \ -# --rdzv_backend=c10d \ -# --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ -# src/inference_v4.py \ -# --input-dir {input_dir} \ -# --output-dir {output_dir} \ -# --patch-size 640 \ -# --batch-size {batch_size} \ -# --confidence 0.5 \ -# --prompts {prompts_str} \\ -# --bpe-path {bpe_path} \\ -# --original-checkpoint {original_checkpoint} \\ -# --finetuned-checkpoint {finetuned_checkpoint} -# " - -# # Record end time and calculate duration -# END_TIME=$(date +%s) -# DURATION=$((END_TIME - START_TIME)) -# MINUTES=$((DURATION / 60)) -# SECONDS=$((DURATION % 60)) -# TIME_PER_IMAGE=$(echo "scale=3; $DURATION / $NUM_IMAGES" | bc) -# THROUGHPUT=$(echo "scale=2; $NUM_IMAGES / $DURATION * 60" | bc) -# SEG_STATUS=$? 
# ← Capture torchrun's exit status - -# echo "" -# echo "============================================================" -# echo "JOB COMPLETED: $(date)" -# echo "============================================================" -# echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" -# echo "Images processed: $NUM_IMAGES" -# echo "Time per image: ${{TIME_PER_IMAGE}}s" -# echo "Throughput: ${{THROUGHPUT}} images/minute" -# echo "Results saved to: {output_dir}" -# echo "Exit status: $SEG_STATUS" -# exit $SEG_STATUS -# echo "============================================================" -# """ -# # #SBATCH -q {qos} -# # #SBATCH -A als -# # #SBATCH -C gpu -# # #SBATCH --job-name=seg_{Path(recon_folder_path).name} -# # #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out -# # #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err -# # #SBATCH -N {num_nodes} -# # #SBATCH --ntasks-per-node=4 -# # #SBATCH --gpus-per-node=4 -# # #SBATCH --cpus-per-task=8 -# # #SBATCH --time={walltime} - -# # set -e - -# # TIMING_FILE="{pscratch_path}/tomo_seg_logs/timing_$SLURM_JOB_ID.txt" -# # echo "JOB_START=$(date +%s)" > $TIMING_FILE - -# # # Load PyTorch module (NERSC recommended) -# # module load pytorch - -# # # Install additional dependencies -# # pip install --user --quiet \\ -# # einops decord pycocotools psutil \\ -# # "timm>=1.0.17" "numpy>=1.26,<2" \\ -# # tqdm ftfy==6.1.1 regex \\ -# # iopath>=0.1.10 python-dotenv qlty \\ -# # git+https://github.com/facebookresearch/sam3.git 2>/dev/null || true - -# # mkdir -p {output_dir} -# # mkdir -p {pscratch_path}/tomo_seg_logs - -# # # Environment -# # export PYTHONPATH={seg_scripts_dir}:$PYTHONPATH -# # export HF_HOME={hf_cache_dir} -# # export HF_HUB_CACHE={hf_cache_dir} - -# # # Distributed settings -# # export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) -# # export MASTER_PORT=29500 - -# # echo "============================================== -# # Job Configuration -# # 
============================================== -# # Job ID: $SLURM_JOB_ID -# # Nodes: {num_nodes} -# # GPUs/node: 4 -# # Total GPUs: $(({num_nodes} * 4)) -# # Master: $MASTER_ADDR:$MASTER_PORT -# # Node list: $SLURM_JOB_NODELIST -# # Input: {input_dir} -# # Output: {output_dir} -# # ==============================================" - -# # NUM_FILES=$(ls -1 {input_dir}/*.tif {input_dir}/*.tiff 2>/dev/null | wc -l) -# # echo "Found $NUM_FILES TIFF files to process" - -# # echo "SEG_START=$(date +%s)" >> $TIMING_FILE - -# # cd {seg_scripts_dir} - -# # # NERSC recommended: srun with one task per GPU -# # # Each task gets SLURM_PROCID (global rank) and SLURM_LOCALID (local rank) -# # srun --export=ALL \\ -# # python -u -m src.inference_v2_optimized3 \\ -# # --input-dir {input_dir} \\ -# # --output-dir {output_dir} \\ -# # --bpe-path {checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz \\ -# # --finetuned-checkpoint {checkpoints_dir}/checkpoint.pt \\ -# # --original-checkpoint {checkpoints_dir}/sam3.pt \\ -# # --patch-size 640 \\ -# # --batch-size {batch_size} \\ -# # --confidence 0.5 \\ -# # --prompts {prompts_str} - -# # SEG_STATUS=$? 
-# # echo "SEG_END=$(date +%s)" >> $TIMING_FILE - -# # if [ $SEG_STATUS -ne 0 ]; then -# # echo "Segmentation failed with exit code $SEG_STATUS" -# # echo "JOB_STATUS=FAILED" >> $TIMING_FILE -# # exit 1 -# # fi - -# # chmod -R 2775 {output_dir} 2>/dev/null || true - -# # echo "JOB_STATUS=SUCCESS" >> $TIMING_FILE -# # echo "JOB_END=$(date +%s)" >> $TIMING_FILE -# # """ -# try: -# logger.info("Submitting segmentation job to Perlmutter.") -# perlmutter = self.client.compute(Machine.perlmutter) - -# # Ensure directories exist -# logger.info("Creating necessary directories...") -# perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") -# perlmutter.run(f"mkdir -p {cfs_path}/envs") - -# # Submit job -# job = perlmutter.submit_job(job_script) -# logger.info(f"Submitted job ID: {job.jobid}") - -# # Initial update -# try: -# job.update() -# except Exception as update_err: -# logger.warning(f"Initial job update failed, continuing: {update_err}") - -# # Wait briefly before polling -# time.sleep(60) -# logger.info(f"Job {job.jobid} current state: {job.state}") - -# # Wait for completion -# job.complete() -# logger.info("Segmentation job completed successfully.") - -# # Fetch timing data -# timing = self._fetch_seg_timing_data(perlmutter, pscratch_path, job.jobid) - -# if timing: -# logger.info("=" * 60) -# logger.info("SEGMENTATION TIMING BREAKDOWN") -# logger.info("=" * 60) -# logger.info(f" Total job time: {timing.get('total', 'N/A')}s") - -# if 'env_setup' in timing: -# logger.info(f" Environment setup: {timing['env_setup']}s") - -# logger.info(f" SEGMENTATION: {timing.get('segmentation', 'N/A')}s <-- inference time") -# logger.info(f" Job status: {timing.get('job_status', 'UNKNOWN')}") -# logger.info("=" * 60) - -# return { -# "success": True, -# "job_id": job.jobid, -# "timing": timing, -# "output_dir": output_dir -# } - -# except Exception as e: -# logger.error(f"Error during segmentation job: {e}") -# import traceback -# logger.error(traceback.format_exc()) - -# # 
Attempt recovery -# match = re.search(r"Job not found:\s*(\d+)", str(e)) -# if match: -# jobid = match.group(1) -# logger.info(f"Attempting to recover job {jobid}.") -# try: -# job = self.client.compute(Machine.perlmutter).job(jobid=jobid) -# time.sleep(30) -# job.complete() -# logger.info("Segmentation job completed after recovery.") - -# timing = self._fetch_seg_timing_data(perlmutter, pscratch_path, jobid) -# return { -# "success": True, -# "job_id": jobid, -# "timing": timing, -# "output_dir": output_dir -# } -# except Exception as recovery_err: -# logger.error(f"Failed to recover job {jobid}: {recovery_err}") - -# return { -# "success": False, -# "job_id": None, -# "timing": None, -# "output_dir": None -# } - - -# def _fetch_seg_timing_data(self, perlmutter, pscratch_path: str, job_id: str) -> dict: -# """ -# Fetch and parse timing data from the segmentation job. - -# :param perlmutter: SFAPI compute object for Perlmutter -# :param pscratch_path: Path to the user's pscratch directory -# :param job_id: SLURM job ID -# :return: Dictionary with timing breakdown -# """ -# timing_file = f"{pscratch_path}/tomo_seg_logs/timing_{job_id}.txt" - -# try: -# # Use SFAPI to read the timing file -# result = perlmutter.run(f"cat {timing_file}") - -# # Handle different result types -# if isinstance(result, str): -# output = result -# elif hasattr(result, 'output'): -# output = result.output -# elif hasattr(result, 'stdout'): -# output = result.stdout -# else: -# output = str(result) - -# logger.info(f"Timing file contents:\n{output}") - -# # Parse timing data -# timing = {} -# for line in output.strip().split('\n'): -# if '=' in line: -# key, value = line.split('=', 1) -# timing[key] = value.strip() - -# # Calculate durations -# breakdown = {} - -# if 'JOB_START' in timing and 'JOB_END' in timing: -# breakdown['total'] = int(timing['JOB_END']) - int(timing['JOB_START']) - -# if 'ENV_SETUP_START' in timing and 'ENV_SETUP_END' in timing: -# breakdown['env_setup'] = 
int(timing['ENV_SETUP_END']) - int(timing['ENV_SETUP_START']) - -# if 'SEG_START' in timing and 'SEG_END' in timing: -# breakdown['segmentation'] = int(timing['SEG_END']) - int(timing['SEG_START']) - -# breakdown['job_status'] = timing.get('JOB_STATUS', 'UNKNOWN') - -# return breakdown - -# except Exception as e: -# logger.warning(f"Error fetching timing data: {e}") -# import traceback -# logger.warning(traceback.format_exc()) -# return None def segmentation( self, @@ -1006,51 +644,27 @@ def segmentation( #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err +# Get master node +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export MASTER_PORT=29500 + +# Set HuggingFace cache to pre-downloaded model files (avoid gated repo auth) +export HF_HOME="{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" +export HF_HUB_CACHE="{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" + +# Set parameters +export INPUT_DIR="{input_dir}" +export OUTPUT_DIR="{output_dir}" +export BATCH_SIZE={batch_size} + # Create output and log directories -mkdir -p {output_dir} +mkdir -p ${{OUTPUT_DIR}} mkdir -p {pscratch_path}/tomo_seg_logs -# Load conda module +# Load conda module and activate environment module load conda - -# Check if environment exists, create if it doesn't -if [ ! -d "{conda_env_path}" ]; then - echo "Conda environment not found at {conda_env_path}" - echo "Creating new environment..." - - # Check if Xiaoya's environment exists as a reference - if [ -d "/pscratch/sd/x/xchong/envs/sam3" ]; then - echo "Cloning from Xiaoya's environment..." - conda create --prefix {conda_env_path} --clone /pscratch/sd/x/xchong/envs/sam3 -y - else - echo "Creating fresh environment..." 
- conda create --prefix {conda_env_path} python=3.10 -y - conda activate {conda_env_path} - - # Install dependencies - pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118 - pip install einops decord pycocotools psutil - pip install "timm>=1.0.17" "numpy>=1.26,<2" - pip install tqdm ftfy==6.1.1 regex - pip install iopath>=0.1.10 python-dotenv qlty - pip install transformers - pip install git+https://github.com/facebookresearch/sam3.git - - conda deactivate - fi - - echo "Environment setup complete" -else - echo "Using existing conda environment at {conda_env_path}" -fi - -# Activate the environment conda activate {conda_env_path} -# Get master node -export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) -export MASTER_PORT=29500 - echo "============================================================" echo "JOB STARTED: $(date)" echo "============================================================" @@ -1058,12 +672,12 @@ def segmentation( echo "Nodes: $SLURM_JOB_NODELIST" echo "Job ID: $SLURM_JOB_ID" echo "GPUs: $((SLURM_NNODES * 4))" -echo "Input: {input_dir}" -echo "Output: {output_dir}" +echo "Input: ${{INPUT_DIR}}" +echo "Output: ${{OUTPUT_DIR}}" # Count actual images -NUM_IMAGES=$(ls {input_dir}/*.tif* 2>/dev/null | wc -l) -echo "Images to process: $NUM_IMAGES" +NUM_IMAGES=$(ls ${{INPUT_DIR}}/*.tif* 2>/dev/null | wc -l) +echo "Images to process: ${{NUM_IMAGES}}" echo "============================================================" # Record start time @@ -1073,24 +687,31 @@ def segmentation( cd {seg_scripts_dir} # Run inference with v4 -srun --ntasks-per-node=1 --gpus-per-task=4 bash -c " -torchrun \\ - --nnodes=\$SLURM_NNODES \\ - --nproc_per_node={nproc_per_node} \\ - --rdzv_id=\$SLURM_JOB_ID \\ - --rdzv_backend=c10d \\ - --rdzv_endpoint=\$MASTER_ADDR:\$MASTER_PORT \\ - src/inference_v4.py \\ - --input-dir {input_dir} \\ - --output-dir {output_dir} \\ - --patch-size 640 \\ - --batch-size {batch_size} \\ - --confidence 
0.5 \\ - --prompts {prompts_str} \\ - --bpe-path {bpe_path} \\ - --original-checkpoint {original_checkpoint} \\ - --finetuned-checkpoint {finetuned_checkpoint} -" + +export TORCH_DISTRIBUTED_DEBUG=DETAIL +export NCCL_DEBUG=INFO +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export BPE_PATH="{bpe_path}" +export ORIG_CKPT="{original_checkpoint}" +export FT_CKPT="{finetuned_checkpoint}" + +srun --ntasks-per-node=1 --gpus-per-task=4 \ + torchrun \ + --nnodes=$SLURM_NNODES \ + --nproc_per_node=4 \ + --rdzv_id=$SLURM_JOB_ID \ + --rdzv_backend=c10d \ + --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ + src/inference_v4.py \ + --input-dir "${{INPUT_DIR}}" \ + --output-dir "${{OUTPUT_DIR}}" \ + --patch-size 640 \ + --batch-size "${{BATCH_SIZE}}" \ + --confidence 0.5 \ + --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ + --bpe-path "${{BPE_PATH}}" \ + --original-checkpoint "${{ORIG_CKPT}}" \ + --finetuned-checkpoint "${{FT_CKPT}}" SEG_STATUS=$? @@ -1116,12 +737,12 @@ def segmentation( echo "Images processed: $NUM_IMAGES" echo "Time per image: ${{TIME_PER_IMAGE}}s" echo "Throughput: ${{THROUGHPUT}} images/minute" -echo "Results saved to: {output_dir}" +echo "Results saved to: ${{OUTPUT_DIR}}" echo "Exit status: $SEG_STATUS" echo "============================================================" # Set permissions -chmod -R 2775 {output_dir} 2>/dev/null || true +chmod -R 2775 ${{OUTPUT_DIR}} 2>/dev/null || true exit $SEG_STATUS """ @@ -1769,6 +1390,235 @@ def nersc_recon_multinode_flow( return False +@flow(name="nersc_forge_recon_segment_flow", flow_run_name="nersc_recon_seg-{file_path}") +def nersc_forge_recon_segment_flow( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Process and transfer a file from bl832 to NERSC and run reconstruction and segmentation. + + :param file_path: The path to the file to be processed. + :param config: Configuration object for the flow. 
+ :return: True if the flow completed successfully, False otherwise. + """ + logger = get_run_logger() + + # STEP 1: Setup Configuration + if config is None: + logger.info("Initializing Config") + config = Config832() + + # Paths + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + scratch_path_tiff = f"{folder_name}/rec{file_name}" + scratch_path_segment = f"{folder_name}/seg{file_name}" + + logger.info(f"Starting NERSC reconstruction + segmentation flow for {file_path=}") + logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") + logger.info(f"Segmented output will be at: {scratch_path_segment}") + + transfer_controller = get_transfer_controller( + transfer_type=CopyMethod.GLOBUS, + config=config + ) + + controller = get_controller( + hpc_type=HPC.NERSC, + config=config + ) + logger.info("NERSC reconstruction controller initialized") + + if num_nodes is None: + num_nodes = config.nersc_recon_num_nodes + logger.info(f"Configured to use {num_nodes} nodes for reconstruction") + + # Track success for pruning decisions + nersc_reconstruction_success = False + nersc_segmentation_success = False + data832_tiff_transfer_success = False + data832_segment_transfer_success = False + + # STEP 2: Run Multinode Reconstruction at NERSC + logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") + recon_result = controller.reconstruct_multinode( + file_path=file_path, + num_nodes=num_nodes + ) + + if isinstance(recon_result, dict): + nersc_reconstruction_success = recon_result.get('success', False) + timing = recon_result.get('timing') + + if timing: + logger.info("=" * 50) + logger.info("TIMING BREAKDOWN") + logger.info("=" * 50) + logger.info(f" Total job time: {timing.get('total', 'N/A')}s") + logger.info(f" Container pull: {timing.get('container_pull', 'N/A')}s") + logger.info( + f" File copy: {timing.get('file_copy', 'N/A')}s " + f"(skipped: {timing.get('copy_skipped', 'N/A')})" + ) + logger.info(f" Metadata 
detection: {timing.get('metadata', 'N/A')}s") + logger.info(f" RECONSTRUCTION: {timing.get('reconstruction', 'N/A')}s <-- actual recon time") + logger.info(f" Num slices: {timing.get('num_slices', 'N/A')}") + logger.info("=" * 50) + + # Calculate overhead + if all(k in timing for k in ['total', 'reconstruction']): + overhead = timing['total'] - timing['reconstruction'] + logger.info(f" Overhead: {overhead}s") + logger.info(f" Reconstruction %: {100 * timing['reconstruction'] / timing['total']:.1f}%") + logger.info("=" * 50) + else: + nersc_reconstruction_success = recon_result + + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + if not nersc_reconstruction_success: + logger.error("Reconstruction Failed.") + raise ValueError("Reconstruction at NERSC Failed") + else: + logger.info("Reconstruction Successful.") + + # STEP 3: Send reconstructed data (tiff) to data832 + logger.info(f"Transferring reconstructed TIFFs from NERSC pscratch to data832") + try: + data832_tiff_transfer_success = transfer_controller.copy( + file_path=scratch_path_tiff, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Transfer reconstructed TIFF data to data832 success: {data832_tiff_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer TIFFs to data832: {e}") + data832_tiff_transfer_success = False + + # STEP 4: Run the Segmentation Task at NERSC + logger.info(f"Starting NERSC segmentation task for {scratch_path_tiff=}") + seg_result = nersc_segmentation_task( + recon_folder_path=scratch_path_tiff, + config=config + ) + if isinstance(seg_result, dict): + nersc_segmentation_success = seg_result.get('success', False) + timing = seg_result.get('timing') + + if timing: + logger.info("=" * 50) + logger.info("SEGMENTATION TIMING BREAKDOWN") + logger.info("=" * 50) + logger.info(f" Total time: {timing.get('total_time', 'N/A')}") + logger.info(f" Images processed: 
{timing.get('num_images', 'N/A')}") + logger.info(f" Time per image: {timing.get('time_per_image', 'N/A')}") + logger.info(f" Throughput: {timing.get('throughput', 'N/A')} images/min") + logger.info(f" Exit status: {timing.get('exit_status', 'N/A')}") + logger.info("=" * 50) + else: + nersc_segmentation_success = bool(seg_result) + + if not nersc_segmentation_success: + logger.warning("Segmentation at NERSC Failed") + else: + logger.info("Segmentation at NERSC Successful") + + # STEP 5: Transfer segmented data to data832 + logger.info(f"Transferring segmented data from NERSC pscratch to data832") + try: + data832_segment_transfer_success = transfer_controller.copy( + file_path=scratch_path_segment, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Transfer segmented data to data832 success: {data832_segment_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer segmented data to data832: {e}") + data832_segment_transfer_success = False + + # STEP 6: Schedule Pruning of files + logger.info("Scheduling file pruning tasks.") + prune_controller = get_prune_controller( + prune_type=PruneMethod.GLOBUS, + config=config + ) + + # Prune raw from NERSC pscratch + logger.info("Scheduling pruning of NERSC pscratch raw data.") + try: + prune_controller.prune( + file_path=f"{folder_name}/{path.name}", + source_endpoint=config.nersc832_alsdev_pscratch_raw, + check_endpoint=None, + days_from_now=1.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule raw data pruning: {e}") + + # Prune TIFFs from NERSC pscratch/scratch + if nersc_reconstruction_success: + logger.info("Scheduling pruning of NERSC pscratch reconstruction data.") + try: + prune_controller.prune( + file_path=scratch_path_tiff, + source_endpoint=config.nersc832_alsdev_pscratch_scratch, + check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + days_from_now=1.0 + ) + except Exception as e: + 
logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + # Prune segmented data from NERSC pscratch/scratch + if nersc_segmentation_success: + logger.info("Scheduling pruning of NERSC pscratch segmentation data.") + try: + prune_controller.prune( + file_path=scratch_path_segment, + source_endpoint=config.nersc832_alsdev_pscratch_scratch, + check_endpoint=config.data832_scratch if data832_segment_transfer_success else None, + days_from_now=1.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + # Prune reconstructed TIFFs from data832 scratch (longer retention) + if data832_tiff_transfer_success: + logger.info("Scheduling pruning of data832 scratch reconstruction TIFF data.") + try: + prune_controller.prune( + file_path=scratch_path_tiff, + source_endpoint=config.data832_scratch, + check_endpoint=None, + days_from_now=30.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + # Prune segmented data from data832 scratch (longer retention) + if data832_segment_transfer_success: + logger.info("Scheduling pruning of data832 scratch segmentation data.") + try: + prune_controller.prune( + file_path=scratch_path_segment, + source_endpoint=config.data832_scratch, + check_endpoint=None, + days_from_now=30.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule data832 segment pruning: {e}") + + # TODO: ingest to scicat + + if nersc_reconstruction_success and nersc_segmentation_success: + logger.info("NERSC reconstruction + segmentation flow completed successfully.") + return True + else: + logger.warning(f"Flow completed with issues: recon={nersc_reconstruction_success}, seg={nersc_segmentation_success}") + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), @@ -1873,7 +1723,7 @@ def nersc_segmentation_integration_test() -> 
bool: """ logger = get_run_logger() logger.info("Starting NERSC segmentation integration test.") - recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # + recon_folder_path = 'synaps-i/rec_test' # rec20211222_125057_petiole4' # 'test' # flow_success = nersc_segmentation_task( recon_folder_path=recon_folder_path, config=Config832() @@ -1885,16 +1735,16 @@ def nersc_segmentation_integration_test() -> bool: if __name__ == "__main__": # Run the integration test flow - # from sfapi_client import Client - # from sfapi_client.compute import Machine + from sfapi_client import Client + from sfapi_client.compute import Machine - # # Use your existing client setup - # client = NERSCTomographyHPCController.create_sfapi_client() - # perlmutter = client.compute(Machine.perlmutter) + # Use your existing client setup + client = NERSCTomographyHPCController.create_sfapi_client() + perlmutter = client.compute(Machine.perlmutter) - # job.cancel() - # job = perlmutter.job(jobid=48570063) - # print(f"Job {job.jobid} cancelled, state: {job.state}") + job = perlmutter.job(jobid=48691530) + job.cancel() + print(f"Job {job.jobid} cancelled, state: {job.state}") result = nersc_segmentation_integration_test() print(f"Integration test result: {result}") From 271d2bf92da01a5e1277e1989e2d98ec5e09ccc0 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 10 Feb 2026 09:19:28 -0800 Subject: [PATCH 13/72] this configuration worked with 1 node for segmentation, testing with 4 now --- config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/config.yml b/config.yml index 85393502..bc8ffa20 100644 --- a/config.yml +++ b/config.yml @@ -46,6 +46,7 @@ globus: uri: beegfs.als.lbl.gov uuid: d33b5d6e-1603-414e-93cb-bcb732b7914a name: bl733-beegfs-data + # 8.3.2 ENDPOINTS spot832: From 9171ea5b345c9b127f212494d33cb2d7dd4bec47 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 10 Feb 2026 09:24:02 -0800 Subject: [PATCH 14/72] adding nersc_forge_recon_segment_flow to prefect.yaml for 
deployment --- orchestration/flows/bl832/prefect.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/orchestration/flows/bl832/prefect.yaml b/orchestration/flows/bl832/prefect.yaml index 3b93b154..f9a32997 100644 --- a/orchestration/flows/bl832/prefect.yaml +++ b/orchestration/flows/bl832/prefect.yaml @@ -49,6 +49,12 @@ deployments: name: nersc_recon_flow_pool work_queue_name: nersc_recon_multinode_flow_queue +- name: nersc_forge_recon_segment_flow + entrypoint: orchestration/flows/bl832/nersc.py:nersc_forge_recon_segment_flow + work_pool: + name: nersc_recon_flow_pool + work_queue_name: nersc_forge_recon_segment_flow_queue + - name: nersc_streaming_flow entrypoint: orchestration/flows/bl832/nersc.py:nersc_streaming_flow work_pool: From b5bd66d663c67c3c847816a699a4efc583f008ff Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 10 Feb 2026 09:24:16 -0800 Subject: [PATCH 15/72] removing comments --- orchestration/flows/bl832/nersc.py | 62 ++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 5cb59aba..b5c55373 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -597,7 +597,7 @@ def segmentation( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" - conda_env_path = f"{cfs_path}/envs/sam3" + conda_env_path = f"{cfs_path}/envs/sam3-py311" # Paths seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" @@ -628,7 +628,6 @@ def segmentation( qos = "regular" walltime = "00:59:00" - job_name = f"seg_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash @@ -648,9 +647,36 @@ def segmentation( export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) export MASTER_PORT=29500 -# Set HuggingFace cache to pre-downloaded model files (avoid gated repo auth) 
-export HF_HOME="{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" -export HF_HUB_CACHE="{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" +# Load conda module and activate environment +module load conda +conda activate {conda_env_path} + +# --------------------------- +# Hugging Face: cache bootstrap + token +# --------------------------- +HF_HOME_ROOT="{cfs_path}/.cache/huggingface" +mkdir -p "${{HF_HOME_ROOT}}/hub" "${{HF_HOME_ROOT}}/datasets" + +export HF_HOME="${{HF_HOME_ROOT}}" +export HF_HUB_CACHE="${{HF_HOME_ROOT}}/hub" +export TRANSFORMERS_CACHE="${{HF_HUB_CACHE}}" +export HF_DATASETS_CACHE="${{HF_HOME_ROOT}}/datasets" + +# prove what each rank sees +echo "[RANK=$SLURM_PROCID] HF_HOME=$HF_HOME" +echo "[RANK=$SLURM_PROCID] HF_HUB_CACHE=$HF_HUB_CACHE" + +# Best-effort perms (ignore if not allowed) +chmod -R 2775 "{cfs_path}/tomography_segmentation_scripts/.cache" 2>/dev/null || true +chmod -R 2775 "${{HF_HOME_ROOT}}" 2>/dev/null || true + +# # Set HuggingFace cache to pre-downloaded model files (avoid gated repo auth) +# export HF_HOME="{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" +# export HF_HUB_CACHE="$HF_HOME/hub" + +# export HF_HUB_OFFLINE=1 +# export TRANSFORMERS_OFFLINE=1 +# export HF_DATASETS_OFFLINE=1 # Set parameters export INPUT_DIR="{input_dir}" @@ -661,9 +687,6 @@ def segmentation( mkdir -p ${{OUTPUT_DIR}} mkdir -p {pscratch_path}/tomo_seg_logs -# Load conda module and activate environment -module load conda -conda activate {conda_env_path} echo "============================================================" echo "JOB STARTED: $(date)" @@ -687,7 +710,6 @@ def segmentation( cd {seg_scripts_dir} # Run inference with v4 - export TORCH_DISTRIBUTED_DEBUG=DETAIL export NCCL_DEBUG=INFO export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 @@ -697,12 +719,12 @@ def segmentation( srun --ntasks-per-node=1 --gpus-per-task=4 \ torchrun \ - --nnodes=$SLURM_NNODES \ + --nnodes={num_nodes} \ --nproc_per_node=4 \ 
--rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ - src/inference_v4.py \ + src/inference_v4_logs.py \ --input-dir "${{INPUT_DIR}}" \ --output-dir "${{OUTPUT_DIR}}" \ --patch-size 640 \ @@ -1723,7 +1745,7 @@ def nersc_segmentation_integration_test() -> bool: """ logger = get_run_logger() logger.info("Starting NERSC segmentation integration test.") - recon_folder_path = 'synaps-i/rec_test' # rec20211222_125057_petiole4' # 'test' # + recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # flow_success = nersc_segmentation_task( recon_folder_path=recon_folder_path, config=Config832() @@ -1735,16 +1757,16 @@ def nersc_segmentation_integration_test() -> bool: if __name__ == "__main__": # Run the integration test flow - from sfapi_client import Client - from sfapi_client.compute import Machine + # from sfapi_client import Client + # from sfapi_client.compute import Machine - # Use your existing client setup - client = NERSCTomographyHPCController.create_sfapi_client() - perlmutter = client.compute(Machine.perlmutter) + # # Use your existing client setup + # client = NERSCTomographyHPCController.create_sfapi_client() + # perlmutter = client.compute(Machine.perlmutter) - job = perlmutter.job(jobid=48691530) - job.cancel() - print(f"Job {job.jobid} cancelled, state: {job.state}") + # job = perlmutter.job(jobid=48700180) + # job.cancel() + # print(f"Job {job.jobid} cancelled, state: {job.state}") result = nersc_segmentation_integration_test() print(f"Integration test result: {result}") From 3a5b1d2f60dbc8904d4bbbe0c040b0ac36a0be4b Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 10 Feb 2026 09:27:35 -0800 Subject: [PATCH 16/72] making config.nersc_recon_num_nodes to set number of nodes for segmentation --- orchestration/flows/bl832/config.py | 1 + orchestration/flows/bl832/nersc.py | 15 ++------------- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/orchestration/flows/bl832/config.py 
b/orchestration/flows/bl832/config.py index 26727279..faeb1425 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -29,3 +29,4 @@ def _beam_specific_config(self) -> None: self.scicat = self.config["scicat"] self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_num_nodes = 4 + self.nersc_segment_num_nodes = 4 diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index b5c55373..f4c92028 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -651,9 +651,6 @@ def segmentation( module load conda conda activate {conda_env_path} -# --------------------------- -# Hugging Face: cache bootstrap + token -# --------------------------- HF_HOME_ROOT="{cfs_path}/.cache/huggingface" mkdir -p "${{HF_HOME_ROOT}}/hub" "${{HF_HOME_ROOT}}/datasets" @@ -670,14 +667,6 @@ def segmentation( chmod -R 2775 "{cfs_path}/tomography_segmentation_scripts/.cache" 2>/dev/null || true chmod -R 2775 "${{HF_HOME_ROOT}}" 2>/dev/null || true -# # Set HuggingFace cache to pre-downloaded model files (avoid gated repo auth) -# export HF_HOME="{cfs_path}/tomography_segmentation_scripts/.cache/huggingface" -# export HF_HUB_CACHE="$HF_HOME/hub" - -# export HF_HUB_OFFLINE=1 -# export TRANSFORMERS_OFFLINE=1 -# export HF_DATASETS_OFFLINE=1 - # Set parameters export INPUT_DIR="{input_dir}" export OUTPUT_DIR="{output_dir}" @@ -687,7 +676,6 @@ def segmentation( mkdir -p ${{OUTPUT_DIR}} mkdir -p {pscratch_path}/tomo_seg_logs - echo "============================================================" echo "JOB STARTED: $(date)" echo "============================================================" @@ -1416,6 +1404,7 @@ def nersc_recon_multinode_flow( def nersc_forge_recon_segment_flow( file_path: str, config: Optional[Config832] = None, + num_nodes: Optional[int] = None, ) -> bool: """ Process and transfer a file from bl832 to NERSC and run reconstruction and segmentation. 
@@ -1454,7 +1443,7 @@ def nersc_forge_recon_segment_flow( logger.info("NERSC reconstruction controller initialized") if num_nodes is None: - num_nodes = config.nersc_recon_num_nodes + num_nodes = config.nersc_segment_num_nodes logger.info(f"Configured to use {num_nodes} nodes for reconstruction") # Track success for pruning decisions From 9df2f97c9590b48ae11b65fd6986659783aa734a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 10 Feb 2026 15:14:55 -0800 Subject: [PATCH 17/72] Using the amsc006 reservation for recon+segmentation --- orchestration/flows/bl832/nersc.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f4c92028..c01934fb 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -254,10 +254,15 @@ def reconstruct_multinode( if num_nodes > 8: qos = "premium" +#SBATCH -q regular +#SBATCH -A amsc006 +#SBATCH --reservation=_CAP_SYNAPYIDINOSAM + # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash -#SBATCH -q {qos} -#SBATCH -A als +#SBATCH -q regular # {qos} +#SBATCH -A amsc006 # als +#SBATCH --reservation=_CAP_reconstruction #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -587,7 +592,7 @@ def build_multi_resolution( def segmentation( self, recon_folder_path: str = "", - num_nodes: int = 4, + num_nodes: int = 10, ) -> dict: """ Run SAM3 segmentation at NERSC Perlmutter (v4 with overlap + max confidence stitching). 
@@ -631,10 +636,11 @@ def segmentation( job_name = f"seg_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash -#SBATCH -q {qos} -#SBATCH -A als +#SBATCH -q regular +#SBATCH -A amsc006 +#SBATCH --reservation=_CAP_SYNAPYIDINOSAM #SBATCH -N {num_nodes} -#SBATCH -C gpu +#SBATCH -C gpu&hbm80g # gpu #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node=1 @@ -1756,9 +1762,9 @@ def nersc_segmentation_integration_test() -> bool: # job = perlmutter.job(jobid=48700180) # job.cancel() # print(f"Job {job.jobid} cancelled, state: {job.state}") - - result = nersc_segmentation_integration_test() - print(f"Integration test result: {result}") + nersc_forge_recon_segment_flow('/global/raw/synaps-i/20211222_122032_petiole3_scan2.h5') + # result = nersc_segmentation_integration_test() + # print(f"Integration test result: {result}") # if __name__ == "__main__": From 340fcb20b5fa0c266552cb5588bcb37f49d686f4 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 11 Feb 2026 09:47:41 -0800 Subject: [PATCH 18/72] Configuring to use all the nodes in the reservation --- orchestration/flows/bl832/config.py | 4 ++-- orchestration/flows/bl832/nersc.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index faeb1425..91667479 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -28,5 +28,5 @@ def _beam_specific_config(self) -> None: self.alcf832_scratch = self.endpoints["alcf832_scratch"] self.scicat = self.config["scicat"] self.ghcr_images832 = self.config["ghcr_images832"] - self.nersc_recon_num_nodes = 4 - self.nersc_segment_num_nodes = 4 + self.nersc_recon_num_nodes = 16 + self.nersc_segment_num_nodes = 26 diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c01934fb..80b8c2f2 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -207,7 +207,7 
@@ def reconstruct( def reconstruct_multinode( self, file_path: str = "", - num_nodes: int = 2, + num_nodes: int = 16, ) -> bool: """ @@ -592,7 +592,7 @@ def build_multi_resolution( def segmentation( self, recon_folder_path: str = "", - num_nodes: int = 10, + num_nodes: int = 26, ) -> dict: """ Run SAM3 segmentation at NERSC Perlmutter (v4 with overlap + max confidence stitching). @@ -610,7 +610,7 @@ def segmentation( bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" original_checkpoint = f"{checkpoints_dir}/sam3.pt" - finetuned_checkpoint = f"{checkpoints_dir}/checkpoint.pt" + finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v2.pt" input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') @@ -620,7 +620,7 @@ def segmentation( logger.info(f"Output directory: {output_dir}") logger.info(f"Conda environment: {conda_env_path}") - batch_size = 8 + batch_size = 16 nproc_per_node = 4 prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", From f2f8806db750b1f60a4236bff6ad7d3246f0775c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 11 Feb 2026 10:05:19 -0800 Subject: [PATCH 19/72] num_nodes fix --- orchestration/flows/bl832/nersc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 80b8c2f2..729e5abd 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1279,7 +1279,7 @@ def nersc_recon_flow( @flow(name="nersc_recon_multinode_flow", flow_run_name="nersc_recon_multinode-{file_path}") def nersc_recon_multinode_flow( file_path: str, - num_nodes: Optional[int] = 4, + num_nodes: Optional[int] = 16, config: Optional[Config832] = None, ) -> bool: """ @@ -1449,7 +1449,7 @@ def nersc_forge_recon_segment_flow( logger.info("NERSC reconstruction controller initialized") if num_nodes is None: - num_nodes = config.nersc_segment_num_nodes + num_nodes = 
config.nersc_recon_num_nodes logger.info(f"Configured to use {num_nodes} nodes for reconstruction") # Track success for pruning decisions From 347816dbbd6702ba641a23e60fa3cf6d32e0726f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 11 Feb 2026 15:46:24 -0800 Subject: [PATCH 20/72] changing segmentation confidence from 0.5 to 0.2 --- orchestration/flows/bl832/nersc.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 729e5abd..0869f8d4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -723,7 +723,7 @@ def segmentation( --output-dir "${{OUTPUT_DIR}}" \ --patch-size 640 \ --batch-size "${{BATCH_SIZE}}" \ - --confidence 0.5 \ + --confidence 0.2 \ --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ @@ -1753,16 +1753,26 @@ def nersc_segmentation_integration_test() -> bool: # Run the integration test flow # from sfapi_client import Client - # from sfapi_client.compute import Machine + from sfapi_client.compute import Machine - # # Use your existing client setup - # client = NERSCTomographyHPCController.create_sfapi_client() - # perlmutter = client.compute(Machine.perlmutter) + # Use your existing client setup + client = NERSCTomographyHPCController.create_sfapi_client() + perlmutter = client.compute(Machine.perlmutter) - # job = perlmutter.job(jobid=48700180) + job = perlmutter.job(jobid=48781402) + job.cancel() + print(f"Job {job.jobid} cancelled, state: {job.state}") + + # job = perlmutter.job(jobid=48778803) + # job.cancel() + # print(f"Job {job.jobid} cancelled, state: {job.state}") + + # job = perlmutter.job(jobid=48777760) # job.cancel() # print(f"Job {job.jobid} cancelled, state: {job.state}") - 
nersc_forge_recon_segment_flow('/global/raw/synaps-i/20211222_122032_petiole3_scan2.h5') + + + # nersc_forge_recon_segment_flow('/global/raw/synaps-i/20211222_122032_petiole3_scan2.h5') # result = nersc_segmentation_integration_test() # print(f"Integration test result: {result}") From dd0799337c74da28ecde762c4d628525e119eded Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 11 Feb 2026 16:19:26 -0800 Subject: [PATCH 21/72] Setting patch-size=400 and confidence=0.5 --- orchestration/flows/bl832/nersc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 0869f8d4..9d50775e 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -721,9 +721,9 @@ def segmentation( src/inference_v4_logs.py \ --input-dir "${{INPUT_DIR}}" \ --output-dir "${{OUTPUT_DIR}}" \ - --patch-size 640 \ + --patch-size 400 \ --batch-size "${{BATCH_SIZE}}" \ - --confidence 0.2 \ + --confidence 0.5 \ --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ From 54d832c4df9586d371aaae6fd03267996d1e3815 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 11 Feb 2026 16:46:08 -0800 Subject: [PATCH 22/72] confidence=0.2 --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 9d50775e..903cea57 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -723,7 +723,7 @@ def segmentation( --output-dir "${{OUTPUT_DIR}}" \ --patch-size 400 \ --batch-size "${{BATCH_SIZE}}" \ - --confidence 0.5 \ + --confidence 0.2 \ --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ From 
e9336dc9474ecefa0bbec1814f0a67bdff6c27b5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 12 Feb 2026 09:33:02 -0800 Subject: [PATCH 23/72] Adding prefect variable to override defaults for segmentation --- orchestration/flows/bl832/nersc.py | 66 ++++++++++++++++++++++++------ 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 903cea57..cc269868 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -619,28 +619,69 @@ def segmentation( logger.info(f"Input directory: {input_dir}") logger.info(f"Output directory: {output_dir}") logger.info(f"Conda environment: {conda_env_path}") + + # Default values (used when defaults=True or variable not found) + default_batch_size = 1 + default_patch_size = 400 + default_confidence = 0.5 + default_overlap = 0.25 # assuming this was your original default + default_qos = "demand" + default_account = "als" + default_constraint = "gpu" + + # Load options from Prefect variable + try: + seg_options = Variable.get("nersc-segmentation-options", default={}) + if isinstance(seg_options, str): + import json + seg_options = json.loads(seg_options) + except Exception as e: + logger.warning(f"Could not load nersc-segmentation-options variable: {e}. 
Using defaults.") + seg_options = {"defaults": True} + + # Determine which values to use + use_defaults = seg_options.get("defaults", True) - batch_size = 16 - nproc_per_node = 4 + if use_defaults: + logger.info("Using hardcoded default segmentation parameters") + batch_size = default_batch_size + patch_size = default_patch_size + confidence = default_confidence + overlap = default_overlap + qos = default_qos + account = default_account + constraint = default_constraint + else: + logger.info("Using parameters from nersc-segmentation-options variable") + batch_size = seg_options.get("batch_size", default_batch_size) + patch_size = seg_options.get("patch_size", default_patch_size) + confidence = seg_options.get("confidence", default_confidence) + overlap = seg_options.get("overlap", default_overlap) + qos = seg_options.get("qos", default_qos) + account = seg_options.get("account", default_account) + constraint = seg_options.get("constraint", default_constraint) + + # batch_size = 16 + # nproc_per_node = 4 prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", "Water-based Pith cells", "Xylem vessels"] - prompts_str = " ".join([f'"{p}"' for p in prompts]) + # prompts_str = " ".join([f'"{p}"' for p in prompts]) - if num_nodes <= 4: - qos = "realtime" - else: - qos = "regular" + # if num_nodes <= 4: + # qos = "realtime" + # else: + # qos = "regular" walltime = "00:59:00" job_name = f"seg_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash -#SBATCH -q regular -#SBATCH -A amsc006 +#SBATCH -q {qos} +#SBATCH -A {account} #SBATCH --reservation=_CAP_SYNAPYIDINOSAM #SBATCH -N {num_nodes} -#SBATCH -C gpu&hbm80g # gpu +#SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node=1 @@ -721,9 +762,10 @@ def segmentation( src/inference_v4_logs.py \ --input-dir "${{INPUT_DIR}}" \ --output-dir "${{OUTPUT_DIR}}" \ - --patch-size 400 \ + --patch-size {patch_size} \ --batch-size "${{BATCH_SIZE}}" \ - --confidence 0.2 \ + 
--confidence {confidence} \ + --overlap-ratio {overlap} \ --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ From 567899e1adcd61c4b028a9e903013ca39a633b8f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Thu, 12 Feb 2026 11:50:09 -0800 Subject: [PATCH 24/72] updating segmentation to v5 --- orchestration/flows/bl832/nersc.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index cc269868..0325fb7b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -587,7 +587,7 @@ def build_multi_resolution( return False else: return False - + def segmentation( self, @@ -605,7 +605,8 @@ def segmentation( conda_env_path = f"{cfs_path}/envs/sam3-py311" # Paths - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" + # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5/forge_feb_seg_model_demo/" checkpoints_dir = f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" @@ -744,7 +745,7 @@ def segmentation( # Change to script directory cd {seg_scripts_dir} -# Run inference with v4 +# Run inference with v5 export TORCH_DISTRIBUTED_DEBUG=DETAIL export NCCL_DEBUG=INFO export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 @@ -759,7 +760,7 @@ def segmentation( --rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ - src/inference_v4_logs.py \ + src/inference_v5.py \ --input-dir "${{INPUT_DIR}}" \ --output-dir "${{OUTPUT_DIR}}" \ --patch-size {patch_size} \ @@ -1795,15 +1796,15 @@ def nersc_segmentation_integration_test() -> bool: # Run the 
integration test flow # from sfapi_client import Client - from sfapi_client.compute import Machine + # from sfapi_client.compute import Machine - # Use your existing client setup - client = NERSCTomographyHPCController.create_sfapi_client() - perlmutter = client.compute(Machine.perlmutter) + # # Use your existing client setup + # client = NERSCTomographyHPCController.create_sfapi_client() + # perlmutter = client.compute(Machine.perlmutter) - job = perlmutter.job(jobid=48781402) - job.cancel() - print(f"Job {job.jobid} cancelled, state: {job.state}") + # job = perlmutter.job(jobid=48781402) + # job.cancel() + # print(f"Job {job.jobid} cancelled, state: {job.state}") # job = perlmutter.job(jobid=48778803) # job.cancel() @@ -1814,9 +1815,9 @@ def nersc_segmentation_integration_test() -> bool: # print(f"Job {job.jobid} cancelled, state: {job.state}") - # nersc_forge_recon_segment_flow('/global/raw/synaps-i/20211222_122032_petiole3_scan2.h5') - # result = nersc_segmentation_integration_test() - # print(f"Integration test result: {result}") + nersc_forge_recon_segment_flow('/global/raw/raw/DD-00842_hexemer/20260212_110324_petiole24.h5') + result = nersc_segmentation_integration_test() + print(f"Integration test result: {result}") # if __name__ == "__main__": From 68490dfd232d57d2f80fe10b12fee006e62df297 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Feb 2026 10:44:48 -0800 Subject: [PATCH 25/72] new checkpoint --- orchestration/flows/bl832/nersc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 0325fb7b..fb588426 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -595,9 +595,9 @@ def segmentation( num_nodes: int = 26, ) -> dict: """ - Run SAM3 segmentation at NERSC Perlmutter (v4 with overlap + max confidence stitching). + Run SAM3 segmentation at NERSC Perlmutter (v5 with overlap + max confidence stitching). 
""" - logger.info("Starting NERSC segmentation process (inference_v4).") + logger.info("Starting NERSC segmentation process (inference_v5).") user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" @@ -611,7 +611,7 @@ def segmentation( bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" original_checkpoint = f"{checkpoints_dir}/sam3.pt" - finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v2.pt" + finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v3.pt" input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') From 290a98341d0035a4e1cc097b736ecb87d022c47f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Feb 2026 10:47:18 -0800 Subject: [PATCH 26/72] adding checkpoint as part of the segmentation variable --- orchestration/flows/bl832/nersc.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index fb588426..927e5380 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -611,7 +611,7 @@ def segmentation( bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" original_checkpoint = f"{checkpoints_dir}/sam3.pt" - finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v3.pt" + # finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v3.pt" input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') @@ -629,6 +629,7 @@ def segmentation( default_qos = "demand" default_account = "als" default_constraint = "gpu" + default_checkpoint = "checkpoint_v3.pt" # Load options from Prefect variable try: @@ -652,6 +653,7 @@ def segmentation( qos = default_qos account = default_account constraint = default_constraint + checkpoint = default_checkpoint else: logger.info("Using parameters from nersc-segmentation-options variable") batch_size = seg_options.get("batch_size", 
default_batch_size) @@ -661,10 +663,12 @@ def segmentation( qos = seg_options.get("qos", default_qos) account = seg_options.get("account", default_account) constraint = seg_options.get("constraint", default_constraint) - + checkpoint = seg_options.get("checkpoint", default_checkpoint) # batch_size = 16 # nproc_per_node = 4 + finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" + prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", "Water-based Pith cells", "Xylem vessels"] # prompts_str = " ".join([f'"{p}"' for p in prompts]) From 1d5b8b4db032a30e30d4901ffd399a8b3c150a8e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Feb 2026 11:28:49 -0800 Subject: [PATCH 27/72] adding support for a list of confidence scores that map to the prompt list --- orchestration/flows/bl832/nersc.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 927e5380..a3659a40 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -624,7 +624,7 @@ def segmentation( # Default values (used when defaults=True or variable not found) default_batch_size = 1 default_patch_size = 400 - default_confidence = 0.5 + default_confidence = [0.5] default_overlap = 0.25 # assuming this was your original default default_qos = "demand" default_account = "als" @@ -669,6 +669,12 @@ def segmentation( finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" + # Format confidence for command line (handles both single value and list) + if isinstance(confidence, list): + confidence_str = " ".join(str(c) for c in confidence) + else: + confidence_str = str(confidence) + prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", "Water-based Pith cells", "Xylem vessels"] # prompts_str = " ".join([f'"{p}"' for p in prompts]) @@ -769,7 +775,7 @@ def segmentation( --output-dir "${{OUTPUT_DIR}}" \ --patch-size {patch_size} \ --batch-size "${{BATCH_SIZE}}" \ - 
--confidence {confidence} \ + --confidence {confidence_str} \ --overlap-ratio {overlap} \ --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ From b8d6cee5d89d860b691935caa6d137056c0ff99e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 11:31:11 -0800 Subject: [PATCH 28/72] updating nersc flows with multisegmentation flows --- orchestration/flows/bl832/nersc.py | 1028 +++++++++++++++++++++++- orchestration/flows/bl832/prefect.yaml | 6 + 2 files changed, 1020 insertions(+), 14 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index a3659a40..2a1d62cc 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -548,6 +548,124 @@ def build_multi_resolution( {multires_image} \ bash -c "python tiff_to_zarr.py {recon_path} --raw_file {raw_path}" +date +""" + try: + logger.info("Submitting Tiff to Zarr job script to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() # Wait until the job completes + logger.info("Reconstruction job completed successfully.") + + return True + + except Exception as e: + logger.warning(f"Error during job submission or completion: {e}") + match = re.search(r"Job not found:\s*(\d+)", str(e)) + + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.perlmutter.job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("Reconstruction job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: 
{recovery_err}") + return False + else: + return False + + def build_multi_resolution_optimize( + self, + file_path: str = "", + num_nodes: int = 4, + ) -> bool: + """ + Use NERSC to make multiresolution version of tomography results with multi-node scaling. + """ + logger.info("Starting NERSC multiresolution process (multi-node).") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + + multires_image = self.config.ghcr_images832["multires_image"] + recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + + recon_path = f"scratch/{folder_name}/rec{file_name}/" + raw_path = f"raw/{folder_name}/{file_name}.h5" + + # Scale time with nodes (less time needed with more workers) + walltime = "0:30:00" if num_nodes <= 4 else "0:15:00" + + job_script = f"""#!/bin/bash +#SBATCH -q realtime +#SBATCH -A als +#SBATCH -C cpu +#SBATCH --job-name=tomo_multires_{folder_name}_{file_name} +#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err +#SBATCH -N {num_nodes} +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=128 +#SBATCH --time={walltime} +#SBATCH --exclusive + +date +echo "Starting multi-node Zarr conversion with {num_nodes} nodes" + +# Get scheduler node +SCHEDULER_NODE=$(hostname) +SCHEDULER_PORT=8786 +SCHEDULER_ADDR="tcp://$SCHEDULER_NODE:$SCHEDULER_PORT" + +echo "Scheduler will run on: $SCHEDULER_ADDR" + +# Start Dask scheduler on first node +srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \\ + --volume {pscratch_path}/8.3.2:/alsdata \\ + {multires_image} \\ + dask scheduler --port $SCHEDULER_PORT & + +SCHEDULER_PID=$! 
+sleep 10 # Give scheduler time to start + +# Start Dask workers on all nodes +for i in $(seq 1 {num_nodes}); do + srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \\ + --env NUMEXPR_MAX_THREADS=128 \\ + --env OMP_NUM_THREADS=128 \\ + --volume {pscratch_path}/8.3.2:/alsdata \\ + {multires_image} \\ + dask worker $SCHEDULER_ADDR --nthreads 32 --nworkers 4 --memory-limit 60GB & +done + +sleep 15 # Give workers time to connect + +# Run the conversion script, connecting to the cluster +srun --nodes=1 --ntasks=1 podman-hpc run \\ + --volume {recon_scripts_dir}/tiff_to_zarr_multinode.py:/alsuser/tiff_to_zarr_multinode.py \\ + --volume {pscratch_path}/8.3.2:/alsdata \\ + --volume {pscratch_path}/8.3.2:/alsuser/ \\ + {multires_image} \\ + bash -c "python tiff_to_zarr_multinode.py {recon_path} --raw_file {raw_path} --scheduler $SCHEDULER_ADDR" +wait date """ try: @@ -606,7 +724,7 @@ def segmentation( # Paths # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5/forge_feb_seg_model_demo/" + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo/" checkpoints_dir = f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" @@ -629,7 +747,7 @@ def segmentation( default_qos = "demand" default_account = "als" default_constraint = "gpu" - default_checkpoint = "checkpoint_v3.pt" + default_checkpoint = "checkpoint_v5.pt" # Load options from Prefect variable try: @@ -690,7 +808,7 @@ def segmentation( job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} -#SBATCH --reservation=_CAP_SYNAPYIDINOSAM +#SBATCH --reservation=INC0249856 #SBATCH -N {num_nodes} #SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} @@ -889,14 +1007,551 @@ def segmentation( } except Exception as recovery_err: logger.error(f"Failed to recover 
job {jobid}: {recovery_err}") - - return { - "success": False, - "job_id": None, - "timing": None, - "output_dir": None - } - + + return { + "success": False, + "job_id": None, + "timing": None, + "output_dir": None + } + + def segmentation_dino( + self, + recon_folder_path: str = "", + ) -> bool: + """ + Run DINO segmentation at NERSC Perlmutter via SFAPI Slurm job. + + :param recon_folder_path: Relative path to the reconstructed data folder, + e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' + :return: True if the job completed successfully, False otherwise. + """ + logger.info("Starting NERSC DINO segmentation process.") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" + conda_env_path = f"{cfs_path}/envs/dino_demo" + + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo" + dino_checkpoint = f"{cfs_path}/tomography_segmentation_scripts/dino/best.ckpt" + + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + seg_folder = recon_folder_path.replace("/rec", "/seg") + output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/dino" + + logger.info(f"DINO input dir: {input_dir}") + logger.info(f"DINO output dir: {output_dir}") + + DINO_DEFAULTS = { + "defaults": True, + "batch_size": 4, + "num_nodes": 4, + "nproc_per_node": 4, + "qos": "regular", + "account": "amsc006", + "constraint": "gpu&hbm80g", + "walltime": "00:59:00", + } + try: + seg_options = Variable.get("nersc-dino-seg-options", default={"defaults": True}, _sync=True) + if isinstance(seg_options, str): + import json + seg_options = json.loads(seg_options) + except Exception as e: + logger.warning(f"Could not load nersc-dino-seg-options: {e}. 
Using defaults.") + seg_options = {"defaults": True} + + use_defaults = seg_options.get("defaults", True) + opts = DINO_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in DINO_DEFAULTS.items()} + + batch_size = opts["batch_size"] + num_nodes = opts["num_nodes"] + nproc_per_node = opts["nproc_per_node"] + qos = opts["qos"] + account = opts["account"] + constraint = opts["constraint"] + walltime = opts["walltime"] + + job_name = f"dino_{Path(recon_folder_path).name}" + + job_script = f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A {account} +#SBATCH -N {num_nodes} +#SBATCH -C {constraint} +#SBATCH --reservation=INC0249856 +#SBATCH --job-name={job_name} +#SBATCH --time={walltime} +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=128 +#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err + +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export MASTER_PORT=29500 + +module load conda +conda activate {conda_env_path} + +HF_HOME_ROOT="{cfs_path}/.cache/huggingface" +mkdir -p "${{HF_HOME_ROOT}}/hub" "${{HF_HOME_ROOT}}/datasets" +export HF_HOME="${{HF_HOME_ROOT}}" +export HF_HUB_CACHE="${{HF_HOME_ROOT}}/hub" +export TRANSFORMERS_CACHE="${{HF_HUB_CACHE}}" +export HF_DATASETS_CACHE="${{HF_HOME_ROOT}}/datasets" + +chmod -R 2775 "{cfs_path}/tomography_segmentation_scripts/.cache" 2>/dev/null || true +chmod -R 2775 "${{HF_HOME_ROOT}}" 2>/dev/null || true + +mkdir -p {output_dir} +mkdir -p {pscratch_path}/tomo_seg_logs + +echo "============================================================" +echo "DINO SEGMENTATION STARTED: $(date)" +echo "============================================================" +echo "Master: $MASTER_ADDR:$MASTER_PORT" +echo "Nodes: $SLURM_JOB_NODELIST" +echo "Job ID: $SLURM_JOB_ID" +echo "Input: {input_dir}" +echo "Output: {output_dir}" +echo "Parameters: batch_size={batch_size}" +echo 
"============================================================" + +NUM_IMAGES=$(ls {input_dir}/*.tif* 2>/dev/null | wc -l) +echo "Images to process: ${{NUM_IMAGES}}" + +START_TIME=$(date +%s) + +cd {seg_scripts_dir} + +export TORCH_DISTRIBUTED_DEBUG=DETAIL +export NCCL_DEBUG=INFO +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 + +srun --ntasks-per-node=1 --gpus-per-task=4 \\ + torchrun \\ + --nnodes={num_nodes} \\ + --nproc_per_node={nproc_per_node} \\ + --rdzv_id=$SLURM_JOB_ID \\ + --rdzv_backend=c10d \\ + --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \\ + -m src.inference_dino_v1 \\ + --input-dir "{input_dir}" \\ + --output-dir "{output_dir}" \\ + --batch-size {batch_size} \\ + --finetuned-checkpoint "{dino_checkpoint}" \\ + --save-overlay + +SEG_STATUS=$? + +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) +MINUTES=$((DURATION / 60)) +SECONDS=$((DURATION % 60)) + +echo "" +echo "============================================================" +echo "DINO SEGMENTATION COMPLETED: $(date)" +echo "============================================================" +echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" +echo "Images processed: ${{NUM_IMAGES}}" +echo "Exit status: $SEG_STATUS" +echo "============================================================" + +chmod -R 2775 {output_dir} 2>/dev/null || true + +exit $SEG_STATUS +""" + try: + logger.info("Submitting DINO segmentation job to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() + logger.info("DINO segmentation job completed successfully.") + return True + + except Exception as e: + logger.error(f"Error during DINO segmentation job submission or completion: {e}") + match = 
re.search(r"Job not found:\s*(\d+)", str(e)) + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.compute(Machine.perlmutter).job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("DINO segmentation job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + return False + else: + return False + + def segmentation_cellpose( + self, + recon_folder_path: str = "", + ) -> bool: + """ + Run Cellpose segmentation at NERSC Perlmutter via SFAPI Slurm job. + + :param recon_folder_path: Relative path to the reconstructed data folder, + e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' + :return: True if the job completed successfully, False otherwise. + """ + logger.info("Starting NERSC Cellpose segmentation process.") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" + conda_env_path = f"{cfs_path}/envs/dino_demo" + + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo" + cellpose_checkpoint = f"{cfs_path}/tomography_segmentation_scripts/cellpose/petiole_model_flow0" + + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + seg_folder = recon_folder_path.replace("/rec", "/seg") + output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/cellpose" + + logger.info(f"Cellpose input dir: {input_dir}") + logger.info(f"Cellpose output dir: {output_dir}") + + CELLPOSE_DEFAULTS = { + "defaults": True, + "num_nodes": 4, + "nproc_per_node": 4, + "qos": "regular", + "account": "amsc006", + "constraint": "gpu&hbm80g", + "walltime": "00:59:00", + } + try: + seg_options = Variable.get("nersc-cellpose-seg-options", default={"defaults": True}, _sync=True) + if isinstance(seg_options, str): + import json + seg_options = json.loads(seg_options) + except 
Exception as e: + logger.warning(f"Could not load nersc-cellpose-seg-options: {e}. Using defaults.") + seg_options = {"defaults": True} + + use_defaults = seg_options.get("defaults", True) + opts = CELLPOSE_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in CELLPOSE_DEFAULTS.items()} + + num_nodes = opts["num_nodes"] + nproc_per_node = opts["nproc_per_node"] + qos = opts["qos"] + account = opts["account"] + constraint = opts["constraint"] + walltime = opts["walltime"] + + job_name = f"cellpose_{Path(recon_folder_path).name}" + + job_script = f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A {account} +#SBATCH -N {num_nodes} +#SBATCH -C {constraint} +#SBATCH --reservation=INC0249856 +#SBATCH --job-name={job_name} +#SBATCH --time={walltime} +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=128 +#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err + +export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +export MASTER_PORT=29500 + +module load conda +conda activate {conda_env_path} + +HF_HOME_ROOT="{cfs_path}/.cache/huggingface" +mkdir -p "${{HF_HOME_ROOT}}/hub" "${{HF_HOME_ROOT}}/datasets" +export HF_HOME="${{HF_HOME_ROOT}}" +export HF_HUB_CACHE="${{HF_HOME_ROOT}}/hub" +export TRANSFORMERS_CACHE="${{HF_HUB_CACHE}}" +export HF_DATASETS_CACHE="${{HF_HOME_ROOT}}/datasets" + +chmod -R 2775 "{cfs_path}/tomography_segmentation_scripts/.cache" 2>/dev/null || true +chmod -R 2775 "${{HF_HOME_ROOT}}" 2>/dev/null || true + +mkdir -p {output_dir} +mkdir -p {pscratch_path}/tomo_seg_logs + +echo "============================================================" +echo "CELLPOSE SEGMENTATION STARTED: $(date)" +echo "============================================================" +echo "Master: $MASTER_ADDR:$MASTER_PORT" +echo "Nodes: $SLURM_JOB_NODELIST" +echo "Job ID: $SLURM_JOB_ID" +echo "Input: {input_dir}" +echo "Output: {output_dir}" +echo 
"============================================================" + +NUM_IMAGES=$(ls {input_dir}/*.tif* 2>/dev/null | wc -l) +echo "Images to process: ${{NUM_IMAGES}}" + +START_TIME=$(date +%s) + +cd {seg_scripts_dir} + +export TORCH_DISTRIBUTED_DEBUG=DETAIL +export NCCL_DEBUG=INFO +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 + +srun --ntasks-per-node=1 --gpus-per-task=4 \\ + torchrun \\ + --nnodes={num_nodes} \\ + --nproc_per_node={nproc_per_node} \\ + --rdzv_id=$SLURM_JOB_ID \\ + --rdzv_backend=c10d \\ + --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \\ + -m src.inference_cellpose_v3 \\ + --input-dir "{input_dir}" \\ + --output-dir "{output_dir}" \\ + --finetuned-checkpoint "{cellpose_checkpoint}" \\ + --save-overlay + +SEG_STATUS=$? + +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) +MINUTES=$((DURATION / 60)) +SECONDS=$((DURATION % 60)) + +echo "" +echo "============================================================" +echo "CELLPOSE SEGMENTATION COMPLETED: $(date)" +echo "============================================================" +echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" +echo "Images processed: ${{NUM_IMAGES}}" +echo "Exit status: $SEG_STATUS" +echo "============================================================" + +chmod -R 2775 {output_dir} 2>/dev/null || true + +exit $SEG_STATUS +""" + try: + logger.info("Submitting Cellpose segmentation job to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() + logger.info("Cellpose segmentation job completed successfully.") + return True + + except Exception as e: + logger.error(f"Error during Cellpose segmentation job submission or completion: {e}") + match = 
re.search(r"Job not found:\s*(\d+)", str(e)) + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.compute(Machine.perlmutter).job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("Cellpose segmentation job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + return False + else: + return False + + def combine_segmentations( + self, + recon_folder_path: str = "", + ) -> bool: + """ + Run CPU-based combination of Cellpose+DINO and SAM3+DINO segmentation results + at NERSC Perlmutter via SFAPI Slurm job. + + :param recon_folder_path: Relative path to the reconstructed data folder, + e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' + :return: True if the job completed successfully, False otherwise. + """ + logger.info("Starting NERSC segmentation combination process.") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" + conda_env_path = f"{cfs_path}/envs/dino_demo" + + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5/forge_feb_seg_model_demo" + + seg_folder = recon_folder_path.replace("/rec", "/seg") + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" + + sam3_results = f"{seg_base}/sam3" + cellpose_results = f"{seg_base}/cellpose" + dino_results = f"{seg_base}/dino" + combined_output = f"{seg_base}/combined" + + logger.info(f"Combine input dir: {input_dir}") + logger.info(f"Combine output dir: {combined_output}") + + COMBINE_DEFAULTS = { + "defaults": True, + "num_nodes": 1, + "qos": "regular", + "account": "amsc006", + "constraint": "cpu", + "walltime": "01:00:00", + } + try: + seg_options = Variable.get("nersc-combine-seg-options", default={"defaults": True}, _sync=True) + if 
isinstance(seg_options, str): + import json + seg_options = json.loads(seg_options) + except Exception as e: + logger.warning(f"Could not load nersc-combine-seg-options: {e}. Using defaults.") + seg_options = {"defaults": True} + + use_defaults = seg_options.get("defaults", True) + opts = COMBINE_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in COMBINE_DEFAULTS.items()} + + num_nodes = opts["num_nodes"] + qos = opts["qos"] + account = opts["account"] + constraint = opts["constraint"] + walltime = opts["walltime"] + + job_name = f"combine_{Path(recon_folder_path).name}" + + job_script = f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A {account} +#SBATCH -N {num_nodes} +#SBATCH -C {constraint} +#SBATCH --reservation=INC0249856 +#SBATCH --job-name={job_name} +#SBATCH --time={walltime} +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=128 +#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err + +module load conda +conda activate {conda_env_path} + +mkdir -p {combined_output}/cellpose_dino +mkdir -p {combined_output}/sam_dino +mkdir -p {pscratch_path}/tomo_seg_logs + +echo "============================================================" +echo "SEGMENTATION COMBINATION STARTED: $(date)" +echo "============================================================" +echo "Input: {input_dir}" +echo "SAM3: {sam3_results}" +echo "Cellpose: {cellpose_results}" +echo "DINO: {dino_results}" +echo "Output: {combined_output}" +echo "============================================================" + +START_TIME=$(date +%s) + +cd {seg_scripts_dir} + +echo "--- Running Cellpose + DINO combination ---" +python -m src.combine_cellpose_dino \\ + --input-dir "{input_dir}" \\ + --instance-masks-dir "{cellpose_results}/instance_masks" \\ + --semantic-masks-dir "{dino_results}/semantic_masks" \\ + --output-dir "{combined_output}/cellpose_dino" + +CELLPOSE_DINO_STATUS=$? 
+echo "Cellpose+DINO exit status: $CELLPOSE_DINO_STATUS" + +echo "--- Running SAM3 + DINO combination ---" +python -m src.combine_sam_dino \\ + --input-dir "{input_dir}" \\ + --instance-masks-dir "{sam3_results}" \\ + --semantic-masks-dir "{dino_results}/semantic_masks" \\ + --output-dir "{combined_output}/sam_dino" + +SAM_DINO_STATUS=$? +echo "SAM3+DINO exit status: $SAM_DINO_STATUS" + +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) +MINUTES=$((DURATION / 60)) +SECONDS=$((DURATION % 60)) + +echo "" +echo "============================================================" +echo "SEGMENTATION COMBINATION COMPLETED: $(date)" +echo "============================================================" +echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" +echo "Cellpose+DINO status: $CELLPOSE_DINO_STATUS" +echo "SAM3+DINO status: $SAM_DINO_STATUS" +echo "============================================================" + +chmod -R 2775 {combined_output} 2>/dev/null || true + +if [ $CELLPOSE_DINO_STATUS -ne 0 ] || [ $SAM_DINO_STATUS -ne 0 ]; then + exit 1 +fi +exit 0 +""" + try: + logger.info("Submitting segmentation combination job to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() + logger.info("Segmentation combination job completed successfully.") + return True + + except Exception as e: + logger.error(f"Error during segmentation combination job submission or completion: {e}") + match = re.search(r"Job not found:\s*(\d+)", str(e)) + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.compute(Machine.perlmutter).job(jobid=jobid) + time.sleep(30) + job.complete() + 
logger.info("Segmentation combination job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + return False + else: + return False def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: str, job_name: str) -> dict: """ @@ -1689,6 +2344,238 @@ def nersc_forge_recon_segment_flow( return False +@flow(name="nersc_forge_recon_multisegment_flow", + flow_run_name="nersc_recon_multiseg-{file_path}") +def nersc_forge_recon_multisegment_flow( + file_path: str, + config: Optional[Config832] = None, + num_nodes: Optional[int] = None, +) -> bool: + """ + Transfer raw data to NERSC, run reconstruction, then run SAM3, DINO, and Cellpose + segmentation concurrently, followed by a combination step. + + :param file_path: The path to the file to be processed. + :param config: Configuration object for the flow. + :param num_nodes: Number of nodes for reconstruction. + :return: True if reconstruction and at least one segmentation task succeeded. 
+ """ + logger = get_run_logger() + + if config is None: + logger.info("Initializing Config") + config = Config832() + + path = Path(file_path) + folder_name = path.parent.name + file_name = path.stem + scratch_path_tiff = f"{folder_name}/rec{file_name}" + scratch_path_segment = f"{folder_name}/seg{file_name}" + + logger.info(f"Starting NERSC reconstruction + multi-segmentation flow for {file_path=}") + logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") + logger.info(f"Segmented output will be at: {scratch_path_segment}") + + transfer_controller = get_transfer_controller( + transfer_type=CopyMethod.GLOBUS, + config=config + ) + controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info("NERSC controller initialized") + + if num_nodes is None: + num_nodes = config.nersc_recon_num_nodes + logger.info(f"Configured to use {num_nodes} nodes for reconstruction") + + nersc_reconstruction_success = False + sam3_success = False + dino_success = False + cellpose_success = False + data832_tiff_transfer_success = False + data832_segment_transfer_success = False + + # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── + logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") + recon_result = controller.reconstruct_multinode( + file_path=file_path, + num_nodes=num_nodes + ) + + if isinstance(recon_result, dict): + nersc_reconstruction_success = recon_result.get('success', False) + timing = recon_result.get('timing') + if timing: + logger.info("=" * 50) + logger.info("TIMING BREAKDOWN") + logger.info("=" * 50) + logger.info(f" Total job time: {timing.get('total', 'N/A')}s") + logger.info(f" Container pull: {timing.get('container_pull', 'N/A')}s") + logger.info( + f" File copy: {timing.get('file_copy', 'N/A')}s " + f"(skipped: {timing.get('copy_skipped', 'N/A')})" + ) + logger.info(f" Metadata detection: {timing.get('metadata', 'N/A')}s") + logger.info(f" RECONSTRUCTION: {timing.get('reconstruction', 'N/A')}s 
<-- actual recon time") + logger.info(f" Num slices: {timing.get('num_slices', 'N/A')}") + logger.info("=" * 50) + if all(k in timing for k in ['total', 'reconstruction']): + overhead = timing['total'] - timing['reconstruction'] + logger.info(f" Overhead: {overhead}s") + logger.info(f" Reconstruction %: {100 * timing['reconstruction'] / timing['total']:.1f}%") + logger.info("=" * 50) + else: + nersc_reconstruction_success = recon_result + + logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") + + if not nersc_reconstruction_success: + logger.error("Reconstruction Failed.") + raise ValueError("Reconstruction at NERSC Failed") + + logger.info("Reconstruction Successful.") + + # ── STEP 2: Transfer TIFFs to data832 ──────────────────────────────────── + logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") + try: + data832_tiff_transfer_success = transfer_controller.copy( + file_path=scratch_path_tiff, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Transfer reconstructed TIFF data to data832 success: {data832_tiff_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer TIFFs to data832: {e}") + data832_tiff_transfer_success = False + + # ── STEP 3: SAM3 / DINO / Cellpose concurrently ────────────────────────── + logger.info("Submitting SAM3, DINO, and Cellpose segmentation tasks concurrently.") + + sam3_future = nersc_segmentation_task.submit( + recon_folder_path=scratch_path_tiff, config=config + ) + dino_future = nersc_segmentation_dino_task.submit( + recon_folder_path=scratch_path_tiff, config=config + ) + cellpose_future = nersc_segmentation_cellpose_task.submit( + recon_folder_path=scratch_path_tiff, config=config + ) + + sam3_result = sam3_future.result() + dino_success = dino_future.result() + cellpose_success = cellpose_future.result() + + # nersc_segmentation_task (SAM3) returns a dict + if isinstance(sam3_result, dict): + 
sam3_success = sam3_result.get('success', False) + else: + sam3_success = bool(sam3_result) + + logger.info( + f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}" + ) + + any_seg_success = any([sam3_success, dino_success, cellpose_success]) + + # ── STEP 4: Combine (sync, after all three complete) ───────────────────── + if dino_success and (sam3_success or cellpose_success): + logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") + combine_success = controller.combine_segmentations( + recon_folder_path=scratch_path_tiff + ) + logger.info(f"Combination result: {combine_success}") + else: + logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") + + # ── STEP 5: Transfer segmentation outputs to data832 ───────────────────── + if any_seg_success: + logger.info("Transferring segmentation outputs from NERSC pscratch to data832") + try: + data832_segment_transfer_success = transfer_controller.copy( + file_path=scratch_path_segment, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Transfer segmented data to data832 success: {data832_segment_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer segmented data to data832: {e}") + data832_segment_transfer_success = False + + # ── STEP 6: Pruning ─────────────────────────────────────────────────────── + logger.info("Scheduling file pruning tasks.") + prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) + + logger.info("Scheduling pruning of NERSC pscratch raw data.") + try: + prune_controller.prune( + file_path=f"{folder_name}/{path.name}", + source_endpoint=config.nersc832_alsdev_pscratch_raw, + check_endpoint=None, + days_from_now=1.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule raw data pruning: {e}") + + if nersc_reconstruction_success: + logger.info("Scheduling pruning of NERSC 
pscratch reconstruction data.") + try: + prune_controller.prune( + file_path=scratch_path_tiff, + source_endpoint=config.nersc832_alsdev_pscratch_scratch, + check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + days_from_now=1.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + if any_seg_success: + logger.info("Scheduling pruning of NERSC pscratch segmentation data.") + try: + prune_controller.prune( + file_path=scratch_path_segment, + source_endpoint=config.nersc832_alsdev_pscratch_scratch, + check_endpoint=config.data832_scratch if data832_segment_transfer_success else None, + days_from_now=1.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + if data832_tiff_transfer_success: + logger.info("Scheduling pruning of data832 scratch reconstruction TIFF data.") + try: + prune_controller.prune( + file_path=scratch_path_tiff, + source_endpoint=config.data832_scratch, + check_endpoint=None, + days_from_now=30.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + if data832_segment_transfer_success: + logger.info("Scheduling pruning of data832 scratch segmentation data.") + try: + prune_controller.prune( + file_path=scratch_path_segment, + source_endpoint=config.data832_scratch, + check_endpoint=None, + days_from_now=30.0 + ) + except Exception as e: + logger.warning(f"Failed to schedule data832 segment pruning: {e}") + + # TODO: ingest to scicat + + if nersc_reconstruction_success and any_seg_success: + logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") + return True + else: + logger.warning( + f"Flow completed with issues: recon={nersc_reconstruction_success}, " + f"sam3={sam3_success}, dino={dino_success}, cellpose={cellpose_success}" + ) + return False + + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def 
nersc_streaming_flow( walltime: datetime.timedelta = datetime.timedelta(minutes=5), @@ -1750,6 +2637,58 @@ def pull_shifter_image_flow( return success +@task(name="nersc_multiresolution_task") +def nersc_multiresolution_task( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Run multiresolution task at NERSC. + + :param file_path: Path to the reconstructed data folder to be processed. + :param config: Configuration object for the flow. + :return: True if the task completed successfully, False otherwise. + """ + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + + # Initialize the Tomography Controller and run the segmentation + logger.info("Initializing NERSC Tomography HPC Controller.") + tomography_controller = get_controller( + hpc_type=HPC.NERSC, + config=config + ) + logger.info(f"Starting NERSC multiresolution task for {file_path=}") + nersc_multiresolution_success = tomography_controller.build_multi_resolution_optimize( + file_path=file_path, + ) + if not nersc_multiresolution_success: + logger.error("Multiresolution Failed.") + else: + logger.info("Multiresolution Successful.") + return nersc_multiresolution_success + + +@flow(name="nersc_multiresolution_integration_test", flow_run_name="nersc_multiresolution_integration_test") +def nersc_multiresolution_integration_test() -> bool: + """ + Integration test for the NERSC multiresolution task. + + :return: True if the multiresolution task completed successfully, False otherwise. 
+ """ + logger = get_run_logger() + logger.info("Starting NERSC multiresolution integration test.") + file_path = 'DD-00842_hexemer/20260213_155826_petiole49.h5' # 'test' # + flow_success = nersc_multiresolution_task( + file_path=file_path, + config=Config832() + ) + logger.info(f"Flow success: {flow_success}") + return flow_success + + @task(name="nersc_segmentation_task") def nersc_segmentation_task( recon_folder_path: str, @@ -1784,6 +2723,63 @@ def nersc_segmentation_task( return nersc_segmentation_success +@task(name="nersc_segmentation_dino_task") +def nersc_segmentation_dino_task( + recon_folder_path: str, + config: Optional[Config832] = None, +) -> bool: + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info(f"Starting NERSC DINO segmentation task for {recon_folder_path=}") + success = tomography_controller.segmentation_dino(recon_folder_path=recon_folder_path) + if not success: + logger.error("DINO segmentation failed.") + else: + logger.info("DINO segmentation successful.") + return success + + +@task(name="nersc_segmentation_cellpose_task") +def nersc_segmentation_cellpose_task( + recon_folder_path: str, + config: Optional[Config832] = None, +) -> bool: + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info(f"Starting NERSC Cellpose segmentation task for {recon_folder_path=}") + success = tomography_controller.segmentation_cellpose(recon_folder_path=recon_folder_path) + if not success: + logger.error("Cellpose segmentation failed.") + else: + logger.info("Cellpose segmentation successful.") + return success + + +@task(name="nersc_combine_segmentations_task") +def nersc_combine_segmentations_task( + recon_folder_path: str, 
+ config: Optional[Config832] = None, +) -> bool: + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info(f"Starting NERSC combine segmentations task for {recon_folder_path=}") + success = tomography_controller.combine_segmentations(recon_folder_path=recon_folder_path) + if not success: + logger.error("Combine segmentations failed.") + else: + logger.info("Combine segmentations successful.") + return success + + @flow(name="nersc_segmentation_integration_test", flow_run_name="nersc_segmentation_integration_test") def nersc_segmentation_integration_test() -> bool: """ @@ -1803,6 +2799,9 @@ def nersc_segmentation_integration_test() -> bool: if __name__ == "__main__": + + nersc_multiresolution_integration_test() + # Run the integration test flow # from sfapi_client import Client @@ -1824,10 +2823,11 @@ def nersc_segmentation_integration_test() -> bool: # job.cancel() # print(f"Job {job.jobid} cancelled, state: {job.state}") + # nersc_forge_recon_segment_flow('/global/raw/raw/DD-00842_hexemer/20260212_110324_petiole24.h5') + # result = nersc_segmentation_integration_test() + # print(f"Integration test result: {result}") + - nersc_forge_recon_segment_flow('/global/raw/raw/DD-00842_hexemer/20260212_110324_petiole24.h5') - result = nersc_segmentation_integration_test() - print(f"Integration test result: {result}") # if __name__ == "__main__": diff --git a/orchestration/flows/bl832/prefect.yaml b/orchestration/flows/bl832/prefect.yaml index f9a32997..8514cbdb 100644 --- a/orchestration/flows/bl832/prefect.yaml +++ b/orchestration/flows/bl832/prefect.yaml @@ -55,6 +55,12 @@ deployments: name: nersc_recon_flow_pool work_queue_name: nersc_forge_recon_segment_flow_queue +- name: nersc_forge_recon_multisegment_flow + entrypoint: orchestration/flows/bl832/nersc.py:nersc_forge_recon_multisegment_flow + work_pool: + 
name: nersc_recon_flow_pool + work_queue_name: nersc_forge_recon_multisegment_flow_queue + - name: nersc_streaming_flow entrypoint: orchestration/flows/bl832/nersc.py:nersc_streaming_flow work_pool: From 9ae81fc20f2b11fff1cc98c1c99d83b8687e657b Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 6 Jan 2026 10:36:40 -0800 Subject: [PATCH 29/72] updating nersc recon reservation name --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 2a1d62cc..7d1fcd40 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -262,7 +262,7 @@ def reconstruct_multinode( job_script = f"""#!/bin/bash #SBATCH -q regular # {qos} #SBATCH -A amsc006 # als -#SBATCH --reservation=_CAP_reconstruction +#SBATCH --reservation=INC0249856 #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out From 5532d128bb36e89d9e6a2d6c6c63a19c40ca2336 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 11:48:24 -0800 Subject: [PATCH 30/72] updating reservation name --- orchestration/flows/bl832/nersc.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 7d1fcd40..d1ee247c 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -256,13 +256,13 @@ def reconstruct_multinode( #SBATCH -q regular #SBATCH -A amsc006 -#SBATCH --reservation=_CAP_SYNAPYIDINOSAM +#SBATCH --reservation=_CAP_MarchModCon_CPU # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash #SBATCH -q regular # {qos} #SBATCH -A amsc006 # als -#SBATCH --reservation=INC0249856 +#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_CPU #SBATCH -C cpu #SBATCH 
--job-name=tomo_recon_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -808,7 +808,7 @@ def segmentation( job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} -#SBATCH --reservation=INC0249856 +#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_GPU #SBATCH -N {num_nodes} #SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} @@ -1080,7 +1080,7 @@ def segmentation_dino( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=INC0249856 +#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node=1 @@ -1263,7 +1263,7 @@ def segmentation_cellpose( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=INC0249856 +#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node=1 @@ -1447,7 +1447,7 @@ def combine_segmentations( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=INC0249856 +#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_CPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks=1 From 4dea9a7010fb37540f8facbb3dca31cb531f9bf5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 13:34:06 -0800 Subject: [PATCH 31/72] adjusting node numbers --- orchestration/flows/bl832/nersc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index d1ee247c..8140f0c6 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -935,7 +935,7 @@ def segmentation( """ try: - logger.info("Submitting segmentation job to Perlmutter (v4).") + logger.info("Submitting segmentation job to Perlmutter (v5).") perlmutter = self.client.compute(Machine.perlmutter) # Ensure directories exist @@ -1230,7 +1230,7 @@ def 
segmentation_cellpose( CELLPOSE_DEFAULTS = { "defaults": True, - "num_nodes": 4, + "num_nodes": 10, "nproc_per_node": 4, "qos": "regular", "account": "amsc006", @@ -1416,7 +1416,7 @@ def combine_segmentations( COMBINE_DEFAULTS = { "defaults": True, - "num_nodes": 1, + "num_nodes": 8, "qos": "regular", "account": "amsc006", "constraint": "cpu", From d37213f3ec2538189f9d9dad220c8f9caa0491b9 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 13:36:34 -0800 Subject: [PATCH 32/72] adjusting path for combine step scripts --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 8140f0c6..9258e5d4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1400,7 +1400,7 @@ def combine_segmentations( cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" conda_env_path = f"{cfs_path}/envs/dino_demo" - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5/forge_feb_seg_model_demo" + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo" seg_folder = recon_folder_path.replace("/rec", "/seg") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" From e88ad2d30321666ef38ef35a38a3aa6ffae83c2c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 14:51:04 -0800 Subject: [PATCH 33/72] Update prompt list for latest sam3 version --- orchestration/flows/bl832/nersc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 9258e5d4..4f6c7050 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -724,7 +724,7 @@ def segmentation( # Paths # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" - seg_scripts_dir = 
f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo/" + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo/" checkpoints_dir = f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" @@ -744,10 +744,10 @@ def segmentation( default_patch_size = 400 default_confidence = [0.5] default_overlap = 0.25 # assuming this was your original default - default_qos = "demand" + default_qos = "regular" default_account = "als" default_constraint = "gpu" - default_checkpoint = "checkpoint_v5.pt" + default_checkpoint = "checkpoint_v6.pt" # Load options from Prefect variable try: @@ -895,7 +895,7 @@ def segmentation( --batch-size "${{BATCH_SIZE}}" \ --confidence {confidence_str} \ --overlap-ratio {overlap} \ - --prompts 'Cortex' 'Phloem Fibers' 'Air-based Pith cells' 'Water-based Pith cells' 'Xylem vessels' \ + --prompts 'Cortex' 'Phloem Fibers' 'Phloem' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Water-based Pith cells' 'Dehydrated Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ --finetuned-checkpoint "${{FT_CKPT}}" From 7e1562c3f36b5ca2eaca49219d16a4b714407aaf Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 14:52:58 -0800 Subject: [PATCH 34/72] transferring segmented results to data832 as they complete rather than waiting --- orchestration/flows/bl832/nersc.py | 218 ++++++++++++++++++++++++----- 1 file changed, 181 insertions(+), 37 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 4f6c7050..2b554571 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2460,51 +2460,81 @@ def nersc_forge_recon_multisegment_flow( recon_folder_path=scratch_path_tiff, config=config ) + # ── STEP 4: Transfer each model's output as it completes ───────────────── sam3_result = 
sam3_future.result() - dino_success = dino_future.result() - cellpose_success = cellpose_future.result() + sam3_success = sam3_result.get('success', False) if isinstance(sam3_result, dict) else bool(sam3_result) + logger.info(f"SAM3 segmentation result: {sam3_success}") + if sam3_success: + logger.info("Transferring SAM3 segmentation outputs to data832") + sam3_segment_path = f"{folder_name}/seg{file_name}/sam3" + try: + data832_sam3_transfer_success = transfer_controller.copy( + file_path=sam3_segment_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"SAM3 transfer to data832 success: {data832_sam3_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer SAM3 outputs to data832: {e}") - # nersc_segmentation_task (SAM3) returns a dict - if isinstance(sam3_result, dict): - sam3_success = sam3_result.get('success', False) - else: - sam3_success = bool(sam3_result) + dino_success = dino_future.result() + logger.info(f"DINO segmentation result: {dino_success}") + if dino_success: + logger.info("Transferring DINO segmentation outputs to data832") + dino_segment_path = f"{folder_name}/seg{file_name}/dino" + try: + data832_dino_transfer_success = transfer_controller.copy( + file_path=dino_segment_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"DINO transfer to data832 success: {data832_dino_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer DINO outputs to data832: {e}") - logger.info( - f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}" - ) + cellpose_success = cellpose_future.result() + logger.info(f"Cellpose segmentation result: {cellpose_success}") + if cellpose_success: + logger.info("Transferring Cellpose segmentation outputs to data832") + cellpose_segment_path = f"{folder_name}/seg{file_name}/cellpose" + try: + 
data832_cellpose_transfer_success = transfer_controller.copy( + file_path=cellpose_segment_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Cellpose transfer to data832 success: {data832_cellpose_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer Cellpose outputs to data832: {e}") any_seg_success = any([sam3_success, dino_success, cellpose_success]) - # ── STEP 4: Combine (sync, after all three complete) ───────────────────── + logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}") + + # ── STEP 5: Combine (after all three complete) ──────────────────────────── if dino_success and (sam3_success or cellpose_success): logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") - combine_success = controller.combine_segmentations( - recon_folder_path=scratch_path_tiff - ) + combine_success = controller.combine_segmentations(recon_folder_path=scratch_path_tiff) logger.info(f"Combination result: {combine_success}") + if combine_success: + logger.info("Transferring combined segmentation outputs to data832") + combined_segment_path = f"{folder_name}/seg{file_name}/combined" + try: + data832_combined_transfer_success = transfer_controller.copy( + file_path=combined_segment_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer combined outputs to data832: {e}") else: logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") - # ── STEP 5: Transfer segmentation outputs to data832 ───────────────────── - if any_seg_success: - logger.info("Transferring segmentation outputs from NERSC pscratch to data832") - try: - data832_segment_transfer_success = transfer_controller.copy( - 
file_path=scratch_path_segment, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - logger.info(f"Transfer segmented data to data832 success: {data832_segment_transfer_success}") - except Exception as e: - logger.error(f"Failed to transfer segmented data to data832: {e}") - data832_segment_transfer_success = False - # ── STEP 6: Pruning ─────────────────────────────────────────────────────── logger.info("Scheduling file pruning tasks.") prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) - logger.info("Scheduling pruning of NERSC pscratch raw data.") try: prune_controller.prune( file_path=f"{folder_name}/{path.name}", @@ -2516,7 +2546,6 @@ def nersc_forge_recon_multisegment_flow( logger.warning(f"Failed to schedule raw data pruning: {e}") if nersc_reconstruction_success: - logger.info("Scheduling pruning of NERSC pscratch reconstruction data.") try: prune_controller.prune( file_path=scratch_path_tiff, @@ -2528,19 +2557,21 @@ def nersc_forge_recon_multisegment_flow( logger.warning(f"Failed to schedule reconstruction data pruning: {e}") if any_seg_success: - logger.info("Scheduling pruning of NERSC pscratch segmentation data.") try: prune_controller.prune( file_path=scratch_path_segment, source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if data832_segment_transfer_success else None, + check_endpoint=config.data832_scratch if any([ + data832_sam3_transfer_success, + data832_dino_transfer_success, + data832_cellpose_transfer_success + ]) else None, days_from_now=1.0 ) except Exception as e: logger.warning(f"Failed to schedule segmentation data pruning: {e}") if data832_tiff_transfer_success: - logger.info("Scheduling pruning of data832 scratch reconstruction TIFF data.") try: prune_controller.prune( file_path=scratch_path_tiff, @@ -2551,8 +2582,8 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.warning(f"Failed to schedule 
data832 tiff pruning: {e}") - if data832_segment_transfer_success: - logger.info("Scheduling pruning of data832 scratch segmentation data.") + if any([data832_sam3_transfer_success, data832_dino_transfer_success, + data832_cellpose_transfer_success, data832_combined_transfer_success]): try: prune_controller.prune( file_path=scratch_path_segment, @@ -2563,8 +2594,6 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.warning(f"Failed to schedule data832 segment pruning: {e}") - # TODO: ingest to scicat - if nersc_reconstruction_success and any_seg_success: logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") return True @@ -2575,6 +2604,121 @@ def nersc_forge_recon_multisegment_flow( ) return False + # sam3_result = sam3_future.result() + # dino_success = dino_future.result() + # cellpose_success = cellpose_future.result() + + # # nersc_segmentation_task (SAM3) returns a dict + # if isinstance(sam3_result, dict): + # sam3_success = sam3_result.get('success', False) + # else: + # sam3_success = bool(sam3_result) + + # logger.info( + # f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}" + # ) + + # any_seg_success = any([sam3_success, dino_success, cellpose_success]) + + # # ── STEP 4: Combine (sync, after all three complete) ───────────────────── + # if dino_success and (sam3_success or cellpose_success): + # logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") + # combine_success = controller.combine_segmentations( + # recon_folder_path=scratch_path_tiff + # ) + # logger.info(f"Combination result: {combine_success}") + # else: + # logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") + + # # ── STEP 5: Transfer segmentation outputs to data832 ───────────────────── + # if any_seg_success: + # logger.info("Transferring segmentation outputs from NERSC pscratch to data832") + # try: + # 
data832_segment_transfer_success = transfer_controller.copy( + # file_path=scratch_path_segment, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + # logger.info(f"Transfer segmented data to data832 success: {data832_segment_transfer_success}") + # except Exception as e: + # logger.error(f"Failed to transfer segmented data to data832: {e}") + # data832_segment_transfer_success = False + + # # ── STEP 6: Pruning ─────────────────────────────────────────────────────── + # logger.info("Scheduling file pruning tasks.") + # prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) + + # logger.info("Scheduling pruning of NERSC pscratch raw data.") + # try: + # prune_controller.prune( + # file_path=f"{folder_name}/{path.name}", + # source_endpoint=config.nersc832_alsdev_pscratch_raw, + # check_endpoint=None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule raw data pruning: {e}") + + # if nersc_reconstruction_success: + # logger.info("Scheduling pruning of NERSC pscratch reconstruction data.") + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule reconstruction data pruning: {e}") + + # if any_seg_success: + # logger.info("Scheduling pruning of NERSC pscratch segmentation data.") + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.nersc832_alsdev_pscratch_scratch, + # check_endpoint=config.data832_scratch if data832_segment_transfer_success else None, + # days_from_now=1.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule segmentation data pruning: {e}") + + # if data832_tiff_transfer_success: + # logger.info("Scheduling pruning 
of data832 scratch reconstruction TIFF data.") + # try: + # prune_controller.prune( + # file_path=scratch_path_tiff, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 tiff pruning: {e}") + + # if data832_segment_transfer_success: + # logger.info("Scheduling pruning of data832 scratch segmentation data.") + # try: + # prune_controller.prune( + # file_path=scratch_path_segment, + # source_endpoint=config.data832_scratch, + # check_endpoint=None, + # days_from_now=30.0 + # ) + # except Exception as e: + # logger.warning(f"Failed to schedule data832 segment pruning: {e}") + + # # TODO: ingest to scicat + + # if nersc_reconstruction_success and any_seg_success: + # logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") + # return True + # else: + # logger.warning( + # f"Flow completed with issues: recon={nersc_reconstruction_success}, " + # f"sam3={sam3_success}, dino={dino_success}, cellpose={cellpose_success}" + # ) + # return False + @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( From cd239b30e7eebe400bac8bb6dfbf7c6c8aa13b27 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 14:55:01 -0800 Subject: [PATCH 35/72] making sam3 results go into its own folder so its not messy --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 2b554571..c8f91b55 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -733,7 +733,7 @@ def segmentation( input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') - output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}" + output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}/sam3" 
logger.info(f"Input directory: {input_dir}") logger.info(f"Output directory: {output_dir}") From 0b4c709d4d700c155c3a7cc737aa7f7c301c7252 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 15:21:04 -0800 Subject: [PATCH 36/72] using the latest segmentation versions --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c8f91b55..8a7caf05 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1230,7 +1230,7 @@ def segmentation_cellpose( CELLPOSE_DEFAULTS = { "defaults": True, - "num_nodes": 10, + "num_nodes": 16, "nproc_per_node": 4, "qos": "regular", "account": "amsc006", From 3461d17d1ab2dc5caa46a478a8490c4e812efc51 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Feb 2026 16:21:22 -0800 Subject: [PATCH 37/72] using combine_sam_dino_v2 --- orchestration/flows/bl832/nersc.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 8a7caf05..b5b9ae5c 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -713,9 +713,9 @@ def segmentation( num_nodes: int = 26, ) -> dict: """ - Run SAM3 segmentation at NERSC Perlmutter (v5 with overlap + max confidence stitching). + Run SAM3 segmentation at NERSC Perlmutter (v6 with overlap + max confidence stitching). 
""" - logger.info("Starting NERSC segmentation process (inference_v5).") + logger.info("Starting NERSC segmentation process (inference_v6).") user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" @@ -793,8 +793,8 @@ def segmentation( else: confidence_str = str(confidence) - prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", - "Water-based Pith cells", "Xylem vessels"] + # prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", + # "Water-based Pith cells", "Xylem vessels"] # prompts_str = " ".join([f'"{p}"' for p in prompts]) # if num_nodes <= 4: @@ -888,7 +888,7 @@ def segmentation( --rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ - src/inference_v5.py \ + src/inference_v6.py \ --input-dir "${{INPUT_DIR}}" \ --output-dir "${{OUTPUT_DIR}}" \ --patch-size {patch_size} \ @@ -1230,7 +1230,7 @@ def segmentation_cellpose( CELLPOSE_DEFAULTS = { "defaults": True, - "num_nodes": 16, + "num_nodes": 10, "nproc_per_node": 4, "qos": "regular", "account": "amsc006", @@ -1400,7 +1400,7 @@ def combine_segmentations( cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" conda_env_path = f"{cfs_path}/envs/dino_demo" - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo" + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo" seg_folder = recon_folder_path.replace("/rec", "/seg") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" @@ -1487,7 +1487,7 @@ def combine_segmentations( echo "Cellpose+DINO exit status: $CELLPOSE_DINO_STATUS" echo "--- Running SAM3 + DINO combination ---" -python -m src.combine_sam_dino \\ +python -m src.combine_sam_dino_v2 \\ --input-dir "{input_dir}" \\ --instance-masks-dir "{sam3_results}" \\ --semantic-masks-dir "{dino_results}/semantic_masks" \\ From a869d02981c73512ac8bdc766c50b6c2b30203f3 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 
21 Feb 2026 08:54:24 -0800 Subject: [PATCH 38/72] removing cellpose from the multiseg workflow and increasing the number of nodes for dino/sam --- orchestration/flows/bl832/config.py | 2 +- orchestration/flows/bl832/nersc.py | 86 ++++++++++++++++------------- 2 files changed, 49 insertions(+), 39 deletions(-) diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 91667479..8d8f8682 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -29,4 +29,4 @@ def _beam_specific_config(self) -> None: self.scicat = self.config["scicat"] self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_num_nodes = 16 - self.nersc_segment_num_nodes = 26 + self.nersc_segment_num_nodes = 42 diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index b5b9ae5c..2de92af6 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1046,7 +1046,7 @@ def segmentation_dino( DINO_DEFAULTS = { "defaults": True, "batch_size": 4, - "num_nodes": 4, + "num_nodes": 8, "nproc_per_node": 4, "qos": "regular", "account": "amsc006", @@ -1407,7 +1407,7 @@ def combine_segmentations( seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" sam3_results = f"{seg_base}/sam3" - cellpose_results = f"{seg_base}/cellpose" + # cellpose_results = f"{seg_base}/cellpose" dino_results = f"{seg_base}/dino" combined_output = f"{seg_base}/combined" @@ -1458,7 +1458,7 @@ def combine_segmentations( module load conda conda activate {conda_env_path} -mkdir -p {combined_output}/cellpose_dino +# mkdir -p {combined_output}/cellpose_dino mkdir -p {combined_output}/sam_dino mkdir -p {pscratch_path}/tomo_seg_logs @@ -1467,7 +1467,7 @@ def combine_segmentations( echo "============================================================" echo "Input: {input_dir}" echo "SAM3: {sam3_results}" -echo "Cellpose: {cellpose_results}" +# echo "Cellpose: {cellpose_results}" echo "DINO: {dino_results}" 
echo "Output: {combined_output}" echo "============================================================" @@ -1476,15 +1476,15 @@ def combine_segmentations( cd {seg_scripts_dir} -echo "--- Running Cellpose + DINO combination ---" -python -m src.combine_cellpose_dino \\ - --input-dir "{input_dir}" \\ - --instance-masks-dir "{cellpose_results}/instance_masks" \\ - --semantic-masks-dir "{dino_results}/semantic_masks" \\ - --output-dir "{combined_output}/cellpose_dino" +# echo "--- Running Cellpose + DINO combination ---" +# python -m src.combine_cellpose_dino \\ +# --input-dir "{input_dir}" \\ +# --instance-masks-dir "{cellpose_results}/instance_masks" \\ +# --semantic-masks-dir "{dino_results}/semantic_masks" \\ +# --output-dir "{combined_output}/cellpose_dino" -CELLPOSE_DINO_STATUS=$? -echo "Cellpose+DINO exit status: $CELLPOSE_DINO_STATUS" +# CELLPOSE_DINO_STATUS=$? +# echo "Cellpose+DINO exit status: $CELLPOSE_DINO_STATUS" echo "--- Running SAM3 + DINO combination ---" python -m src.combine_sam_dino_v2 \\ @@ -1506,13 +1506,15 @@ def combine_segmentations( echo "SEGMENTATION COMBINATION COMPLETED: $(date)" echo "============================================================" echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" -echo "Cellpose+DINO status: $CELLPOSE_DINO_STATUS" +# echo "Cellpose+DINO status: $CELLPOSE_DINO_STATUS" echo "SAM3+DINO status: $SAM_DINO_STATUS" echo "============================================================" chmod -R 2775 {combined_output} 2>/dev/null || true -if [ $CELLPOSE_DINO_STATUS -ne 0 ] || [ $SAM_DINO_STATUS -ne 0 ]; then +# if [ $CELLPOSE_DINO_STATUS -ne 0 ] || [ $SAM_DINO_STATUS -ne 0 ]; then +if [ $SAM_DINO_STATUS -ne 0 ]; then + exit 1 fi exit 0 @@ -2456,9 +2458,9 @@ def nersc_forge_recon_multisegment_flow( dino_future = nersc_segmentation_dino_task.submit( recon_folder_path=scratch_path_tiff, config=config ) - cellpose_future = nersc_segmentation_cellpose_task.submit( - recon_folder_path=scratch_path_tiff, 
config=config - ) + # cellpose_future = nersc_segmentation_cellpose_task.submit( + # recon_folder_path=scratch_path_tiff, config=config + # ) # ── STEP 4: Transfer each model's output as it completes ───────────────── sam3_result = sam3_future.result() @@ -2492,27 +2494,31 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.error(f"Failed to transfer DINO outputs to data832: {e}") - cellpose_success = cellpose_future.result() - logger.info(f"Cellpose segmentation result: {cellpose_success}") - if cellpose_success: - logger.info("Transferring Cellpose segmentation outputs to data832") - cellpose_segment_path = f"{folder_name}/seg{file_name}/cellpose" - try: - data832_cellpose_transfer_success = transfer_controller.copy( - file_path=cellpose_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - logger.info(f"Cellpose transfer to data832 success: {data832_cellpose_transfer_success}") - except Exception as e: - logger.error(f"Failed to transfer Cellpose outputs to data832: {e}") + # cellpose_success = cellpose_future.result() + # logger.info(f"Cellpose segmentation result: {cellpose_success}") + # if cellpose_success: + # logger.info("Transferring Cellpose segmentation outputs to data832") + # cellpose_segment_path = f"{folder_name}/seg{file_name}/cellpose" + # try: + # data832_cellpose_transfer_success = transfer_controller.copy( + # file_path=cellpose_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + # logger.info(f"Cellpose transfer to data832 success: {data832_cellpose_transfer_success}") + # except Exception as e: + # logger.error(f"Failed to transfer Cellpose outputs to data832: {e}") + + # any_seg_success = any([sam3_success, dino_success, cellpose_success]) + any_seg_success = any([sam3_success, dino_success]) - any_seg_success = any([sam3_success, dino_success, cellpose_success]) - logger.info(f"Segmentation results — SAM3: 
{sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}") + # logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}") + logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}") # ── STEP 5: Combine (after all three complete) ──────────────────────────── - if dino_success and (sam3_success or cellpose_success): + # if dino_success and (sam3_success or cellpose_success): + if dino_success and sam3_success: logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") combine_success = controller.combine_segmentations(recon_folder_path=scratch_path_tiff) logger.info(f"Combination result: {combine_success}") @@ -2564,7 +2570,7 @@ def nersc_forge_recon_multisegment_flow( check_endpoint=config.data832_scratch if any([ data832_sam3_transfer_success, data832_dino_transfer_success, - data832_cellpose_transfer_success + # data832_cellpose_transfer_success ]) else None, days_from_now=1.0 ) @@ -2582,8 +2588,10 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.warning(f"Failed to schedule data832 tiff pruning: {e}") - if any([data832_sam3_transfer_success, data832_dino_transfer_success, - data832_cellpose_transfer_success, data832_combined_transfer_success]): + if any([data832_sam3_transfer_success, + data832_dino_transfer_success, + # data832_cellpose_transfer_success, + data832_combined_transfer_success]): try: prune_controller.prune( file_path=scratch_path_segment, @@ -2600,7 +2608,9 @@ def nersc_forge_recon_multisegment_flow( else: logger.warning( f"Flow completed with issues: recon={nersc_reconstruction_success}, " - f"sam3={sam3_success}, dino={dino_success}, cellpose={cellpose_success}" + # f"sam3={sam3_success}, dino={dino_success}, cellpose={cellpose_success}" + f"sam3={sam3_success}, dino={dino_success}" + ) return False From 790e1e0a0ecf9eb209358c9c9737ef58d84a5df9 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 
09:23:17 -0800 Subject: [PATCH 39/72] updating default number of num nodes for sam3 segmenation --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 2de92af6..410d7dfd 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -710,7 +710,7 @@ def build_multi_resolution_optimize( def segmentation( self, recon_folder_path: str = "", - num_nodes: int = 26, + num_nodes: int = 42, ) -> dict: """ Run SAM3 segmentation at NERSC Perlmutter (v6 with overlap + max confidence stitching). From a5d92c833984359a1b598e665aead434a847539c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 09:35:03 -0800 Subject: [PATCH 40/72] adding extract_regions task to the multiseg flow --- orchestration/flows/bl832/nersc.py | 235 ++++++++++++++++++++++++++++- 1 file changed, 231 insertions(+), 4 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 410d7dfd..9f8cb6fe 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1380,6 +1380,169 @@ def segmentation_cellpose( return False else: return False + + def seg_extract_regions( + self, + recon_folder_path: str = "", + ) -> bool: + """ + Extract Hydrated and Dehydrated Xylem regions using DINO semantic masks + and SAM3 instance masks at NERSC Perlmutter via SFAPI Slurm job. + + :param recon_folder_path: Relative path to the reconstructed data folder, + e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' + :return: True if the job completed successfully, False otherwise. 
+ """ + logger.info("Starting NERSC region extraction process.") + + user = self.client.user() + pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" + cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" + conda_env_path = f"{cfs_path}/envs/dino_demo" + + seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo" + + seg_folder = recon_folder_path.replace("/rec", "/seg") + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" + + sam3_results = f"{seg_base}/sam3" + dino_results = f"{seg_base}/dino" + combined_output = f"{seg_base}/combined" + extract_output = f"{combined_output}/extract_regions" + + logger.info(f"Extract regions input dir: {input_dir}") + logger.info(f"Extract regions output dir: {extract_output}") + + EXTRACT_DEFAULTS = { + "defaults": True, + "num_nodes": 8, + "qos": "regular", + "account": "amsc006", + "constraint": "cpu", + "walltime": "01:00:00", + "dilate_px": 5, + "xylem_mode": "all", + } + try: + seg_options = Variable.get("nersc-extract-regions-options", default={"defaults": True}, _sync=True) + if isinstance(seg_options, str): + import json + seg_options = json.loads(seg_options) + except Exception as e: + logger.warning(f"Could not load nersc-extract-regions-options: {e}. 
Using defaults.") + seg_options = {"defaults": True} + + use_defaults = seg_options.get("defaults", True) + opts = EXTRACT_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in EXTRACT_DEFAULTS.items()} + + num_nodes = opts["num_nodes"] + qos = opts["qos"] + account = opts["account"] + constraint = opts["constraint"] + walltime = opts["walltime"] + dilate_px = opts["dilate_px"] + xylem_mode = opts["xylem_mode"] + + job_name = f"extract_{Path(recon_folder_path).name}" + + job_script = f"""#!/bin/bash +#SBATCH -q {qos} +#SBATCH -A {account} +#SBATCH -N {num_nodes} +#SBATCH -C {constraint} +#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_CPU +#SBATCH --job-name={job_name} +#SBATCH --time={walltime} +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=128 +#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out +#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err + +module load conda +conda activate {conda_env_path} + +mkdir -p {extract_output} +mkdir -p {pscratch_path}/tomo_seg_logs + +echo "============================================================" +echo "REGION EXTRACTION STARTED: $(date)" +echo "============================================================" +echo "Input: {input_dir}" +echo "SAM3 masks: {sam3_results}" +echo "DINO masks: {dino_results}/semantic_masks" +echo "Output: {extract_output}" +echo "Xylem mode: {xylem_mode}" +echo "Dilate px: {dilate_px}" +echo "============================================================" + +START_TIME=$(date +%s) + +cd {seg_scripts_dir} + +# Extract Hydrated and Dehydrated Xylem regions +python -m src.extract_region \\ + --input-dir "{input_dir}" \\ + --semantic-masks-dir "{dino_results}/semantic_masks" \\ + --instance-masks-dir "{sam3_results}" \\ + --output-dir "{extract_output}" \\ + --xylem-mode {xylem_mode} \\ + --dilate-px {dilate_px} + +EXTRACT_STATUS=$? 
+ +END_TIME=$(date +%s) +DURATION=$((END_TIME - START_TIME)) +MINUTES=$((DURATION / 60)) +SECONDS=$((DURATION % 60)) + +echo "" +echo "============================================================" +echo "REGION EXTRACTION COMPLETED: $(date)" +echo "============================================================" +echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" +echo "Exit status: $EXTRACT_STATUS" +echo "============================================================" + +chmod -R 2775 {extract_output} 2>/dev/null || true + +exit $EXTRACT_STATUS +""" + try: + logger.info("Submitting region extraction job to Perlmutter.") + perlmutter = self.client.compute(Machine.perlmutter) + job = perlmutter.submit_job(job_script) + logger.info(f"Submitted job ID: {job.jobid}") + + try: + job.update() + except Exception as update_err: + logger.warning(f"Initial job update failed, continuing: {update_err}") + + time.sleep(60) + logger.info(f"Job {job.jobid} current state: {job.state}") + + job.complete() + logger.info("Region extraction job completed successfully.") + return True + + except Exception as e: + logger.error(f"Error during region extraction job submission or completion: {e}") + match = re.search(r"Job not found:\s*(\d+)", str(e)) + if match: + jobid = match.group(1) + logger.info(f"Attempting to recover job {jobid}.") + try: + job = self.client.compute(Machine.perlmutter).job(jobid=jobid) + time.sleep(30) + job.complete() + logger.info("Region extraction job completed successfully after recovery.") + return True + except Exception as recovery_err: + logger.error(f"Failed to recover job {jobid}: {recovery_err}") + return False + else: + return False def combine_segmentations( self, @@ -2518,13 +2681,42 @@ def nersc_forge_recon_multisegment_flow( # ── STEP 5: Combine (after all three complete) ──────────────────────────── # if dino_success and (sam3_success or cellpose_success): + # if dino_success and sam3_success: + # logger.info("Running segmentation 
combination (SAM3+DINO and Cellpose+DINO).") + # combine_success = controller.combine_segmentations(recon_folder_path=scratch_path_tiff) + # logger.info(f"Combination result: {combine_success}") + # if combine_success: + # logger.info("Transferring combined segmentation outputs to data832") + # combined_segment_path = f"{folder_name}/seg{file_name}/combined" + # try: + # data832_combined_transfer_success = transfer_controller.copy( + # file_path=combined_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + # logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") + # except Exception as e: + # logger.error(f"Failed to transfer combined outputs to data832: {e}") + # else: + # logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") + + # ── STEP 5: Combine + Extract Regions concurrently (after SAM3+DINO complete) ── + # if dino_success and (sam3_success or cellpose_success): if dino_success and sam3_success: - logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") - combine_success = controller.combine_segmentations(recon_folder_path=scratch_path_tiff) + logger.info("Running segmentation combination and region extraction concurrently.") + + combine_future = nersc_combine_segmentations_task.submit( + recon_folder_path=scratch_path_tiff, config=config + ) + extract_future = nersc_extract_regions_task.submit( + recon_folder_path=scratch_path_tiff, config=config + ) + + combine_success = combine_future.result() logger.info(f"Combination result: {combine_success}") if combine_success: logger.info("Transferring combined segmentation outputs to data832") - combined_segment_path = f"{folder_name}/seg{file_name}/combined" + combined_segment_path = f"{folder_name}/seg{file_name}/combined/sam_dino" try: data832_combined_transfer_success = transfer_controller.copy( file_path=combined_segment_path, @@ -2534,8 +2726,24 @@ def 
nersc_forge_recon_multisegment_flow( logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") + + extract_success = extract_future.result() + logger.info(f"Region extraction result: {extract_success}") + if extract_success: + logger.info("Transferring extracted region outputs to data832") + extract_segment_path = f"{folder_name}/seg{file_name}/combined/extract_regions" + try: + data832_extract_transfer_success = transfer_controller.copy( + file_path=extract_segment_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.data832_scratch + ) + logger.info(f"Extract regions transfer to data832 success: {data832_extract_transfer_success}") + except Exception as e: + logger.error(f"Failed to transfer extracted region outputs to data832: {e}") else: - logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") + logger.warning("Skipping combination and extraction: requires DINO plus SAM3.") + # ── STEP 6: Pruning ─────────────────────────────────────────────────────── logger.info("Scheduling file pruning tasks.") @@ -2915,6 +3123,25 @@ def nersc_segmentation_cellpose_task( return success +@task(name="nersc_extract_regions_task") +def nersc_extract_regions_task( + recon_folder_path: str, + config: Optional[Config832] = None, +) -> bool: + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + logger.info(f"Starting NERSC region extraction task for {recon_folder_path=}") + success = tomography_controller.seg_extract_regions(recon_folder_path=recon_folder_path) + if not success: + logger.error("Region extraction failed.") + else: + logger.info("Region extraction successful.") + return success + + @task(name="nersc_combine_segmentations_task") def 
nersc_combine_segmentations_task( recon_folder_path: str, From f0a91fd2a5b48c0df07ed3b1a994d4e05b9b235a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 09:39:52 -0800 Subject: [PATCH 41/72] removing some commented code --- orchestration/flows/bl832/nersc.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 9f8cb6fe..d8926539 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1570,7 +1570,6 @@ def combine_segmentations( seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" sam3_results = f"{seg_base}/sam3" - # cellpose_results = f"{seg_base}/cellpose" dino_results = f"{seg_base}/dino" combined_output = f"{seg_base}/combined" @@ -1630,7 +1629,6 @@ def combine_segmentations( echo "============================================================" echo "Input: {input_dir}" echo "SAM3: {sam3_results}" -# echo "Cellpose: {cellpose_results}" echo "DINO: {dino_results}" echo "Output: {combined_output}" echo "============================================================" @@ -1639,16 +1637,6 @@ def combine_segmentations( cd {seg_scripts_dir} -# echo "--- Running Cellpose + DINO combination ---" -# python -m src.combine_cellpose_dino \\ -# --input-dir "{input_dir}" \\ -# --instance-masks-dir "{cellpose_results}/instance_masks" \\ -# --semantic-masks-dir "{dino_results}/semantic_masks" \\ -# --output-dir "{combined_output}/cellpose_dino" - -# CELLPOSE_DINO_STATUS=$? 
-# echo "Cellpose+DINO exit status: $CELLPOSE_DINO_STATUS" - echo "--- Running SAM3 + DINO combination ---" python -m src.combine_sam_dino_v2 \\ --input-dir "{input_dir}" \\ From ba6dd585c112b97dd9b7d7a5c2c0e47e811c253c Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 09:55:49 -0800 Subject: [PATCH 42/72] copy recon/segment results from prscratch to cfs when the flow is done --- orchestration/flows/bl832/nersc.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index d8926539..009fba0e 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2732,6 +2732,17 @@ def nersc_forge_recon_multisegment_flow( else: logger.warning("Skipping combination and extraction: requires DINO plus SAM3.") + logger.info("Copying rec and seg folders from pscratch to NERSC CFS.") + for cfs_path in [scratch_path_tiff, scratch_path_segment]: + try: + transfer_controller.copy( + file_path=cfs_path, + source=config.nersc832_alsdev_pscratch_scratch, + destination=config.nersc832_alsdev_scratch + ) + logger.info(f"CFS transfer success: {cfs_path}") + except Exception as e: + logger.error(f"Failed to copy {cfs_path} to NERSC CFS: {e}") # ── STEP 6: Pruning ─────────────────────────────────────────────────────── logger.info("Scheduling file pruning tasks.") From 8e588de45f4e642fe9cfcbfa60b63ddba33577ae Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 13:15:41 -0800 Subject: [PATCH 43/72] using new code --- orchestration/flows/bl832/nersc.py | 37 +++++++++++++++--------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 009fba0e..0d03cb9b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -896,6 +896,7 @@ def segmentation( --confidence {confidence_str} \ --overlap-ratio {overlap} \ --prompts 'Cortex' 
'Phloem Fibers' 'Phloem' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Water-based Pith cells' 'Dehydrated Xylem vessels' \ + # --prompts 'Phloem Fibers' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Dehydrated Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ --finetuned-checkpoint "${{FT_CKPT}}" @@ -1638,7 +1639,7 @@ def combine_segmentations( cd {seg_scripts_dir} echo "--- Running SAM3 + DINO combination ---" -python -m src.combine_sam_dino_v2 \\ +python -m src.combine_sam_dino_v3 \\ --input-dir "{input_dir}" \\ --instance-masks-dir "{sam3_results}" \\ --semantic-masks-dir "{dino_results}/semantic_masks" \\ @@ -2696,9 +2697,9 @@ def nersc_forge_recon_multisegment_flow( combine_future = nersc_combine_segmentations_task.submit( recon_folder_path=scratch_path_tiff, config=config ) - extract_future = nersc_extract_regions_task.submit( - recon_folder_path=scratch_path_tiff, config=config - ) + # extract_future = nersc_extract_regions_task.submit( + # recon_folder_path=scratch_path_tiff, config=config + # ) combine_success = combine_future.result() logger.info(f"Combination result: {combine_success}") @@ -2715,20 +2716,20 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") - extract_success = extract_future.result() - logger.info(f"Region extraction result: {extract_success}") - if extract_success: - logger.info("Transferring extracted region outputs to data832") - extract_segment_path = f"{folder_name}/seg{file_name}/combined/extract_regions" - try: - data832_extract_transfer_success = transfer_controller.copy( - file_path=extract_segment_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - logger.info(f"Extract regions transfer to data832 success: {data832_extract_transfer_success}") - except Exception as e: - logger.error(f"Failed to transfer extracted region outputs to data832: {e}") + # 
extract_success = extract_future.result() + # logger.info(f"Region extraction result: {extract_success}") + # if extract_success: + # logger.info("Transferring extracted region outputs to data832") + # extract_segment_path = f"{folder_name}/seg{file_name}/combined/extract_regions" + # try: + # data832_extract_transfer_success = transfer_controller.copy( + # file_path=extract_segment_path, + # source=config.nersc832_alsdev_pscratch_scratch, + # destination=config.data832_scratch + # ) + # logger.info(f"Extract regions transfer to data832 success: {data832_extract_transfer_success}") + # except Exception as e: + # logger.error(f"Failed to transfer extracted region outputs to data832: {e}") else: logger.warning("Skipping combination and extraction: requires DINO plus SAM3.") From 42dce23682954a829768a24ab3f7d107916f08f9 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 13:30:10 -0800 Subject: [PATCH 44/72] shortened prompt list --- orchestration/flows/bl832/nersc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 0d03cb9b..ec32296b 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -801,7 +801,8 @@ def segmentation( # qos = "realtime" # else: # qos = "regular" - +# --prompts 'Cortex' 'Phloem Fibers' 'Phloem' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Water-based Pith cells' 'Dehydrated Xylem vessels' \ + walltime = "00:59:00" job_name = f"seg_{Path(recon_folder_path).name}" @@ -895,8 +896,7 @@ def segmentation( --batch-size "${{BATCH_SIZE}}" \ --confidence {confidence_str} \ --overlap-ratio {overlap} \ - --prompts 'Cortex' 'Phloem Fibers' 'Phloem' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Water-based Pith cells' 'Dehydrated Xylem vessels' \ - # --prompts 'Phloem Fibers' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Dehydrated Xylem vessels' \ + --prompts 'Phloem Fibers' 'Hydrated Xylem vessels' 
'Air-based Pith cells' 'Dehydrated Xylem vessels' \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ --finetuned-checkpoint "${{FT_CKPT}}" From 0a1682ff2258ef2e0c424fa78ebbd7a857aa5a91 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Sat, 21 Feb 2026 13:49:28 -0800 Subject: [PATCH 45/72] fixing combine step --- orchestration/flows/bl832/nersc.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index ec32296b..c15296a6 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1550,7 +1550,7 @@ def combine_segmentations( recon_folder_path: str = "", ) -> bool: """ - Run CPU-based combination of Cellpose+DINO and SAM3+DINO segmentation results + Run CPU-based combination of SAM3+DINO segmentation results at NERSC Perlmutter via SFAPI Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, @@ -1584,6 +1584,7 @@ def combine_segmentations( "account": "amsc006", "constraint": "cpu", "walltime": "01:00:00", + "dilate_px": 5, } try: seg_options = Variable.get("nersc-combine-seg-options", default={"defaults": True}, _sync=True) @@ -1602,6 +1603,7 @@ def combine_segmentations( account = opts["account"] constraint = opts["constraint"] walltime = opts["walltime"] + dilate_px = opts["dilate_px"] job_name = f"combine_{Path(recon_folder_path).name}" @@ -1621,7 +1623,6 @@ def combine_segmentations( module load conda conda activate {conda_env_path} -# mkdir -p {combined_output}/cellpose_dino mkdir -p {combined_output}/sam_dino mkdir -p {pscratch_path}/tomo_seg_logs @@ -1632,18 +1633,22 @@ def combine_segmentations( echo "SAM3: {sam3_results}" echo "DINO: {dino_results}" echo "Output: {combined_output}" +echo "Dilate: {dilate_px}px" echo "============================================================" START_TIME=$(date +%s) cd {seg_scripts_dir} -echo "--- Running SAM3 + DINO combination ---" 
+echo "--- Running SAM3 + DINO combination (v3) ---" python -m src.combine_sam_dino_v3 \\ --input-dir "{input_dir}" \\ --instance-masks-dir "{sam3_results}" \\ --semantic-masks-dir "{dino_results}/semantic_masks" \\ - --output-dir "{combined_output}/sam_dino" + --output-dir "{combined_output}/sam_dino" \\ + --dilate-px {dilate_px} \\ + --save-extracted \\ + --dino-trust Cortex Phloem_Fibers Phloem Air-based_Pith_cells Water-based_Pith_cells SAM_DINO_STATUS=$? echo "SAM3+DINO exit status: $SAM_DINO_STATUS" @@ -1658,15 +1663,12 @@ def combine_segmentations( echo "SEGMENTATION COMBINATION COMPLETED: $(date)" echo "============================================================" echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" -# echo "Cellpose+DINO status: $CELLPOSE_DINO_STATUS" echo "SAM3+DINO status: $SAM_DINO_STATUS" echo "============================================================" chmod -R 2775 {combined_output} 2>/dev/null || true -# if [ $CELLPOSE_DINO_STATUS -ne 0 ] || [ $SAM_DINO_STATUS -ne 0 ]; then if [ $SAM_DINO_STATUS -ne 0 ]; then - exit 1 fi exit 0 From 51c85f9582dcd453801234294ea43d76d37f4fff Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 24 Feb 2026 09:11:00 -0800 Subject: [PATCH 46/72] reservation --- orchestration/flows/bl832/nersc.py | 51 +++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c15296a6..12cd9f23 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -262,7 +262,7 @@ def reconstruct_multinode( job_script = f"""#!/bin/bash #SBATCH -q regular # {qos} #SBATCH -A amsc006 # als -#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_CPU +#SBATCH --reservation=_CAP_MarchModCon_CPU #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -809,7 +809,7 @@ def segmentation( job_script = f"""#!/bin/bash 
#SBATCH -q {qos} #SBATCH -A {account} -#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_GPU +#SBATCH --reservation=_CAP_MarchModCon_GPU #SBATCH -N {num_nodes} #SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} @@ -1081,7 +1081,7 @@ def segmentation_dino( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_GPU +#SBATCH --reservation=_CAP_MarchModCon_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node=1 @@ -1264,7 +1264,7 @@ def segmentation_cellpose( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_GPU +#SBATCH --reservation=_CAP_MarchModCon_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node=1 @@ -1452,7 +1452,7 @@ def seg_extract_regions( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_CPU +#SBATCH --reservation=_CAP_MarchModCon_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks=1 @@ -1612,7 +1612,7 @@ def combine_segmentations( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_March_ModCon_Dry_Run_CPU +#SBATCH --reservation=_CAP_MarchModCon_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks=1 @@ -3035,6 +3035,37 @@ def nersc_multiresolution_task( return nersc_multiresolution_success +@task(name="nersc_tiff_to_zarr_task") +def nersc_tiff_to_zarr_task( + file_path: str, + config: Optional[Config832] = None, +) -> bool: + """ + Run tiff-to-zarr (single-node) multiresolution task at NERSC. + + :param file_path: Path to the raw .h5 file (used to derive recon path). + :param config: Configuration object for the flow. + :return: True if the task completed successfully, False otherwise. 
+ """ + logger = get_run_logger() + if config is None: + logger.info("No config provided, using default Config832.") + config = Config832() + + logger.info("Initializing NERSC Tomography HPC Controller.") + tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) + + logger.info(f"Starting NERSC tiff-to-zarr task for {file_path=}") + success = tomography_controller.build_multi_resolution(file_path=file_path) + + if not success: + logger.error("Tiff-to-zarr failed.") + else: + logger.info("Tiff-to-zarr successful.") + return success + + + @flow(name="nersc_multiresolution_integration_test", flow_run_name="nersc_multiresolution_integration_test") def nersc_multiresolution_integration_test() -> bool: """ @@ -3183,8 +3214,12 @@ def nersc_segmentation_integration_test() -> bool: if __name__ == "__main__": - nersc_multiresolution_integration_test() - + # nersc_multiresolution_integration_test() + nersc_tiff_to_zarr_task( + file_path='DD-00842_hexemer/20260222_122341_petiole51.h5', + config=Config832() + ) + # Run the integration test flow # from sfapi_client import Client From 2aebe99e5d9cc869b7edccba2f597e4fa4f1ba2a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Tue, 24 Feb 2026 10:18:36 -0800 Subject: [PATCH 47/72] fixing combine step reservation (CPU) --- orchestration/flows/bl832/nersc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 12cd9f23..912c362e 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1612,7 +1612,7 @@ def combine_segmentations( #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_MarchModCon_GPU +#SBATCH --reservation=_CAP_MarchModCon_CPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks=1 From 5e1778ca770008e21cc32c33f22ced6dffe2f99a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 10:53:26 -0700 Subject: 
[PATCH 48/72] linting --- orchestration/flows/bl832/nersc.py | 115 ++++++++++++++--------------- 1 file changed, 54 insertions(+), 61 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 912c362e..647efb57 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -207,7 +207,7 @@ def reconstruct( def reconstruct_multinode( self, file_path: str = "", - num_nodes: int = 16, + num_nodes: int = 2, ) -> bool: """ @@ -254,15 +254,14 @@ def reconstruct_multinode( if num_nodes > 8: qos = "premium" -#SBATCH -q regular -#SBATCH -A amsc006 -#SBATCH --reservation=_CAP_MarchModCon_CPU - +# If using with a reservation: +# SBATCH -q regular +# SBATCH --reservation=_CAP_MarchModCon_CPU +# SBATCH -A amsc006 # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash -#SBATCH -q regular # {qos} -#SBATCH -A amsc006 # als -#SBATCH --reservation=_CAP_MarchModCon_CPU +#SBATCH -q {qos} +#SBATCH -A als #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -600,7 +599,7 @@ def build_multi_resolution_optimize( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - + multires_image = self.config.ghcr_images832["multires_image"] recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path @@ -706,7 +705,6 @@ def build_multi_resolution_optimize( else: return False - def segmentation( self, recon_folder_path: str = "", @@ -721,7 +719,7 @@ def segmentation( pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" conda_env_path = f"{cfs_path}/envs/sam3-py311" - + # Paths # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" seg_scripts_dir = 
f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo/" @@ -730,11 +728,11 @@ def segmentation( bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" original_checkpoint = f"{checkpoints_dir}/sam3.pt" # finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v3.pt" - + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}/sam3" - + logger.info(f"Input directory: {input_dir}") logger.info(f"Output directory: {output_dir}") logger.info(f"Conda environment: {conda_env_path}") @@ -748,7 +746,7 @@ def segmentation( default_account = "als" default_constraint = "gpu" default_checkpoint = "checkpoint_v6.pt" - + # Load options from Prefect variable try: seg_options = Variable.get("nersc-segmentation-options", default={}) @@ -758,10 +756,10 @@ def segmentation( except Exception as e: logger.warning(f"Could not load nersc-segmentation-options variable: {e}. 
Using defaults.") seg_options = {"defaults": True} - + # Determine which values to use use_defaults = seg_options.get("defaults", True) - + if use_defaults: logger.info("Using hardcoded default segmentation parameters") batch_size = default_batch_size @@ -782,9 +780,9 @@ def segmentation( account = seg_options.get("account", default_account) constraint = seg_options.get("constraint", default_constraint) checkpoint = seg_options.get("checkpoint", default_checkpoint) - # batch_size = 16 + # batch_size = 16 # nproc_per_node = 4 - + finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" # Format confidence for command line (handles both single value and list) @@ -793,15 +791,14 @@ def segmentation( else: confidence_str = str(confidence) - # prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", - # "Water-based Pith cells", "Xylem vessels"] + # prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", + # "Water-based Pith cells", "Xylem vessels"] # prompts_str = " ".join([f'"{p}"' for p in prompts]) - + # if num_nodes <= 4: # qos = "realtime" # else: # qos = "regular" -# --prompts 'Cortex' 'Phloem Fibers' 'Phloem' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Water-based Pith cells' 'Dehydrated Xylem vessels' \ walltime = "00:59:00" job_name = f"seg_{Path(recon_folder_path).name}" @@ -814,9 +811,9 @@ def segmentation( #SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} #SBATCH --time={walltime} -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --cpus-per-task=128 +#SBATCH --ntasks-per-node=1 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=128 #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err @@ -934,37 +931,37 @@ def segmentation( exit $SEG_STATUS """ - + try: logger.info("Submitting segmentation job to Perlmutter (v5).") perlmutter = self.client.compute(Machine.perlmutter) - + # Ensure directories exist logger.info("Creating necessary directories...") 
perlmutter.run(f"mkdir -p {pscratch_path}/tomo_seg_logs") perlmutter.run(f"mkdir -p {output_dir}") - + # Submit job job = perlmutter.submit_job(job_script) logger.info(f"Submitted job ID: {job.jobid}") - + # Initial update try: job.update() except Exception as update_err: logger.warning(f"Initial job update failed, continuing: {update_err}") - + # Wait briefly before polling time.sleep(60) logger.info(f"Job {job.jobid} current state: {job.state}") - + # Wait for completion job.complete() logger.info("Segmentation job completed successfully.") - + # Fetch timing data from output file timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, job.jobid, job_name) - + if timing: logger.info("=" * 60) logger.info("SEGMENTATION TIMING BREAKDOWN") @@ -975,19 +972,19 @@ def segmentation( logger.info(f" Throughput: {timing.get('throughput', 'N/A')} images/min") logger.info(f" Exit status: {timing.get('exit_status', 'N/A')}") logger.info("=" * 60) - + return { "success": True, "job_id": job.jobid, "timing": timing, "output_dir": output_dir } - + except Exception as e: logger.error(f"Error during segmentation job: {e}") import traceback logger.error(traceback.format_exc()) - + # Attempt recovery match = re.search(r"Job not found:\s*(\d+)", str(e)) if match: @@ -998,7 +995,7 @@ def segmentation( time.sleep(30) job.complete() logger.info("Segmentation job completed after recovery.") - + timing = self._fetch_seg_timing_from_output(perlmutter, pscratch_path, jobid, job_name) return { "success": True, @@ -1008,7 +1005,7 @@ def segmentation( } except Exception as recovery_err: logger.error(f"Failed to recover job {jobid}: {recovery_err}") - + return { "success": False, "job_id": None, @@ -1381,7 +1378,7 @@ def segmentation_cellpose( return False else: return False - + def seg_extract_regions( self, recon_folder_path: str = "", @@ -1712,7 +1709,7 @@ def combine_segmentations( def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: str, job_name: 
str) -> dict: """ Fetch and parse timing data from the SLURM output file. - + :param perlmutter: SFAPI compute object for Perlmutter :param pscratch_path: Path to the user's pscratch directory :param job_id: SLURM job ID @@ -1720,11 +1717,11 @@ def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: :return: Dictionary with timing breakdown """ output_file = f"{pscratch_path}/tomo_seg_logs/{job_name}_{job_id}.out" - + try: # Use SFAPI to read the output file result = perlmutter.run(f"cat {output_file}") - + # Handle different result types if isinstance(result, str): output = result @@ -1734,15 +1731,15 @@ def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: output = result.stdout else: output = str(result) - - logger.info(f"Job output file contents (last 50 lines):") + + logger.info("Job output file contents (last 50 lines):") lines = output.strip().split('\n') for line in lines[-50:]: logger.info(f" {line}") - + # Parse timing data from the output timing = {} - + for line in lines: if "Total time:" in line: # Extract: "Total time: 5m 23s (323s)" @@ -1750,40 +1747,39 @@ def _fetch_seg_timing_from_output(self, perlmutter, pscratch_path: str, job_id: if match: timing['total_time'] = f"{match.group(1)}m {match.group(2)}s" timing['total_seconds'] = int(match.group(3)) - + elif "Images processed:" in line: # Extract: "Images processed: 100" match = re.search(r'Images processed:\s+(\d+)', line) if match: timing['num_images'] = int(match.group(1)) - + elif "Time per image:" in line: # Extract: "Time per image: 3.230s" match = re.search(r'Time per image:\s+([\d.]+)s', line) if match: timing['time_per_image'] = f"{match.group(1)}s" - + elif "Throughput:" in line: # Extract: "Throughput: 18.58 images/minute" match = re.search(r'Throughput:\s+([\d.]+)\s+images/minute', line) if match: timing['throughput'] = float(match.group(1)) - + elif "Exit status:" in line: # Extract: "Exit status: 0" match = re.search(r'Exit 
status:\s+(\d+)', line) if match: timing['exit_status'] = int(match.group(1)) - + return timing if timing else None - + except Exception as e: logger.warning(f"Error fetching timing data from output: {e}") import traceback logger.warning(traceback.format_exc()) return None - def start_streaming_service( self, walltime: datetime.timedelta = datetime.timedelta(minutes=30), @@ -2366,7 +2362,7 @@ def nersc_forge_recon_segment_flow( logger.info("Reconstruction Successful.") # STEP 3: Send reconstructed data (tiff) to data832 - logger.info(f"Transferring reconstructed TIFFs from NERSC pscratch to data832") + logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") try: data832_tiff_transfer_success = transfer_controller.copy( file_path=scratch_path_tiff, @@ -2407,7 +2403,7 @@ def nersc_forge_recon_segment_flow( logger.info("Segmentation at NERSC Successful") # STEP 5: Transfer segmented data to data832 - logger.info(f"Transferring segmented data from NERSC pscratch to data832") + logger.info("Transferring segmented data from NERSC pscratch to data832") try: data832_segment_transfer_success = transfer_controller.copy( file_path=scratch_path_segment, @@ -2546,9 +2542,9 @@ def nersc_forge_recon_multisegment_flow( nersc_reconstruction_success = False sam3_success = False dino_success = False - cellpose_success = False + # cellpose_success = False data832_tiff_transfer_success = False - data832_segment_transfer_success = False + # data832_segment_transfer_success = False # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") @@ -2666,7 +2662,6 @@ def nersc_forge_recon_multisegment_flow( # any_seg_success = any([sam3_success, dino_success, cellpose_success]) any_seg_success = any([sam3_success, dino_success]) - # logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}") logger.info(f"Segmentation results — SAM3: 
{sam3_success}, DINO: {dino_success}") @@ -2691,7 +2686,7 @@ def nersc_forge_recon_multisegment_flow( # else: # logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") - # ── STEP 5: Combine + Extract Regions concurrently (after SAM3+DINO complete) ── + # ── STEP 5: Combine + Extract Regions concurrently (after SAM3+DINO complete) ── # if dino_success and (sam3_success or cellpose_success): if dino_success and sam3_success: logger.info("Running segmentation combination and region extraction concurrently.") @@ -3065,7 +3060,6 @@ def nersc_tiff_to_zarr_task( return success - @flow(name="nersc_multiresolution_integration_test", flow_run_name="nersc_multiresolution_integration_test") def nersc_multiresolution_integration_test() -> bool: """ @@ -3246,7 +3240,6 @@ def nersc_segmentation_integration_test() -> bool: # print(f"Integration test result: {result}") - # if __name__ == "__main__": # config = Config832() From dd78a6e9755cc61d8b7b24c541ad24f993b6016a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 10:59:33 -0700 Subject: [PATCH 49/72] removing cellpose --- orchestration/flows/bl832/nersc.py | 367 +---------------------------- 1 file changed, 3 insertions(+), 364 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 647efb57..c545d645 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1198,187 +1198,6 @@ def segmentation_dino( else: return False - def segmentation_cellpose( - self, - recon_folder_path: str = "", - ) -> bool: - """ - Run Cellpose segmentation at NERSC Perlmutter via SFAPI Slurm job. - - :param recon_folder_path: Relative path to the reconstructed data folder, - e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' - :return: True if the job completed successfully, False otherwise. 
- """ - logger.info("Starting NERSC Cellpose segmentation process.") - - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" - conda_env_path = f"{cfs_path}/envs/dino_demo" - - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo" - cellpose_checkpoint = f"{cfs_path}/tomography_segmentation_scripts/cellpose/petiole_model_flow0" - - input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" - seg_folder = recon_folder_path.replace("/rec", "/seg") - output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/cellpose" - - logger.info(f"Cellpose input dir: {input_dir}") - logger.info(f"Cellpose output dir: {output_dir}") - - CELLPOSE_DEFAULTS = { - "defaults": True, - "num_nodes": 10, - "nproc_per_node": 4, - "qos": "regular", - "account": "amsc006", - "constraint": "gpu&hbm80g", - "walltime": "00:59:00", - } - try: - seg_options = Variable.get("nersc-cellpose-seg-options", default={"defaults": True}, _sync=True) - if isinstance(seg_options, str): - import json - seg_options = json.loads(seg_options) - except Exception as e: - logger.warning(f"Could not load nersc-cellpose-seg-options: {e}. 
Using defaults.") - seg_options = {"defaults": True} - - use_defaults = seg_options.get("defaults", True) - opts = CELLPOSE_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in CELLPOSE_DEFAULTS.items()} - - num_nodes = opts["num_nodes"] - nproc_per_node = opts["nproc_per_node"] - qos = opts["qos"] - account = opts["account"] - constraint = opts["constraint"] - walltime = opts["walltime"] - - job_name = f"cellpose_{Path(recon_folder_path).name}" - - job_script = f"""#!/bin/bash -#SBATCH -q {qos} -#SBATCH -A {account} -#SBATCH -N {num_nodes} -#SBATCH -C {constraint} -#SBATCH --reservation=_CAP_MarchModCon_GPU -#SBATCH --job-name={job_name} -#SBATCH --time={walltime} -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --cpus-per-task=128 -#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out -#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err - -export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) -export MASTER_PORT=29500 - -module load conda -conda activate {conda_env_path} - -HF_HOME_ROOT="{cfs_path}/.cache/huggingface" -mkdir -p "${{HF_HOME_ROOT}}/hub" "${{HF_HOME_ROOT}}/datasets" -export HF_HOME="${{HF_HOME_ROOT}}" -export HF_HUB_CACHE="${{HF_HOME_ROOT}}/hub" -export TRANSFORMERS_CACHE="${{HF_HUB_CACHE}}" -export HF_DATASETS_CACHE="${{HF_HOME_ROOT}}/datasets" - -chmod -R 2775 "{cfs_path}/tomography_segmentation_scripts/.cache" 2>/dev/null || true -chmod -R 2775 "${{HF_HOME_ROOT}}" 2>/dev/null || true - -mkdir -p {output_dir} -mkdir -p {pscratch_path}/tomo_seg_logs - -echo "============================================================" -echo "CELLPOSE SEGMENTATION STARTED: $(date)" -echo "============================================================" -echo "Master: $MASTER_ADDR:$MASTER_PORT" -echo "Nodes: $SLURM_JOB_NODELIST" -echo "Job ID: $SLURM_JOB_ID" -echo "Input: {input_dir}" -echo "Output: {output_dir}" -echo "============================================================" - -NUM_IMAGES=$(ls 
{input_dir}/*.tif* 2>/dev/null | wc -l) -echo "Images to process: ${{NUM_IMAGES}}" - -START_TIME=$(date +%s) - -cd {seg_scripts_dir} - -export TORCH_DISTRIBUTED_DEBUG=DETAIL -export NCCL_DEBUG=INFO -export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 - -srun --ntasks-per-node=1 --gpus-per-task=4 \\ - torchrun \\ - --nnodes={num_nodes} \\ - --nproc_per_node={nproc_per_node} \\ - --rdzv_id=$SLURM_JOB_ID \\ - --rdzv_backend=c10d \\ - --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \\ - -m src.inference_cellpose_v3 \\ - --input-dir "{input_dir}" \\ - --output-dir "{output_dir}" \\ - --finetuned-checkpoint "{cellpose_checkpoint}" \\ - --save-overlay - -SEG_STATUS=$? - -END_TIME=$(date +%s) -DURATION=$((END_TIME - START_TIME)) -MINUTES=$((DURATION / 60)) -SECONDS=$((DURATION % 60)) - -echo "" -echo "============================================================" -echo "CELLPOSE SEGMENTATION COMPLETED: $(date)" -echo "============================================================" -echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" -echo "Images processed: ${{NUM_IMAGES}}" -echo "Exit status: $SEG_STATUS" -echo "============================================================" - -chmod -R 2775 {output_dir} 2>/dev/null || true - -exit $SEG_STATUS -""" - try: - logger.info("Submitting Cellpose segmentation job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Cellpose segmentation job completed successfully.") - return True - - except Exception as e: - logger.error(f"Error during Cellpose segmentation job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - 
logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Cellpose segmentation job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False - def seg_extract_regions( self, recon_folder_path: str = "", @@ -2504,7 +2323,7 @@ def nersc_forge_recon_multisegment_flow( num_nodes: Optional[int] = None, ) -> bool: """ - Transfer raw data to NERSC, run reconstruction, then run SAM3, DINO, and Cellpose + Transfer raw data to NERSC, run reconstruction, then run SAM3 and DINOv3 segmentation concurrently, followed by a combination step. :param file_path: The path to the file to be processed. @@ -2542,7 +2361,6 @@ def nersc_forge_recon_multisegment_flow( nersc_reconstruction_success = False sam3_success = False dino_success = False - # cellpose_success = False data832_tiff_transfer_success = False # data832_segment_transfer_success = False @@ -2599,8 +2417,8 @@ def nersc_forge_recon_multisegment_flow( logger.error(f"Failed to transfer TIFFs to data832: {e}") data832_tiff_transfer_success = False - # ── STEP 3: SAM3 / DINO / Cellpose concurrently ────────────────────────── - logger.info("Submitting SAM3, DINO, and Cellpose segmentation tasks concurrently.") + # ── STEP 3: SAM3 / DINOv3 ────────────────────────── + logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") sam3_future = nersc_segmentation_task.submit( recon_folder_path=scratch_path_tiff, config=config @@ -2608,9 +2426,6 @@ def nersc_forge_recon_multisegment_flow( dino_future = nersc_segmentation_dino_task.submit( recon_folder_path=scratch_path_tiff, config=config ) - # cellpose_future = nersc_segmentation_cellpose_task.submit( - # recon_folder_path=scratch_path_tiff, config=config - # ) # ── STEP 4: Transfer each model's output as it completes 
───────────────── sam3_result = sam3_future.result() @@ -2644,50 +2459,11 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.error(f"Failed to transfer DINO outputs to data832: {e}") - # cellpose_success = cellpose_future.result() - # logger.info(f"Cellpose segmentation result: {cellpose_success}") - # if cellpose_success: - # logger.info("Transferring Cellpose segmentation outputs to data832") - # cellpose_segment_path = f"{folder_name}/seg{file_name}/cellpose" - # try: - # data832_cellpose_transfer_success = transfer_controller.copy( - # file_path=cellpose_segment_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) - # logger.info(f"Cellpose transfer to data832 success: {data832_cellpose_transfer_success}") - # except Exception as e: - # logger.error(f"Failed to transfer Cellpose outputs to data832: {e}") - - # any_seg_success = any([sam3_success, dino_success, cellpose_success]) any_seg_success = any([sam3_success, dino_success]) - # logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}") logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}") - # ── STEP 5: Combine (after all three complete) ──────────────────────────── - # if dino_success and (sam3_success or cellpose_success): - # if dino_success and sam3_success: - # logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") - # combine_success = controller.combine_segmentations(recon_folder_path=scratch_path_tiff) - # logger.info(f"Combination result: {combine_success}") - # if combine_success: - # logger.info("Transferring combined segmentation outputs to data832") - # combined_segment_path = f"{folder_name}/seg{file_name}/combined" - # try: - # data832_combined_transfer_success = transfer_controller.copy( - # file_path=combined_segment_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - 
# ) - # logger.info(f"Combined transfer to data832 success: {data832_combined_transfer_success}") - # except Exception as e: - # logger.error(f"Failed to transfer combined outputs to data832: {e}") - # else: - # logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") - # ── STEP 5: Combine + Extract Regions concurrently (after SAM3+DINO complete) ── - # if dino_success and (sam3_success or cellpose_success): if dino_success and sam3_success: logger.info("Running segmentation combination and region extraction concurrently.") @@ -2775,7 +2551,6 @@ def nersc_forge_recon_multisegment_flow( check_endpoint=config.data832_scratch if any([ data832_sam3_transfer_success, data832_dino_transfer_success, - # data832_cellpose_transfer_success ]) else None, days_from_now=1.0 ) @@ -2795,7 +2570,6 @@ def nersc_forge_recon_multisegment_flow( if any([data832_sam3_transfer_success, data832_dino_transfer_success, - # data832_cellpose_transfer_success, data832_combined_transfer_success]): try: prune_controller.prune( @@ -2813,127 +2587,11 @@ def nersc_forge_recon_multisegment_flow( else: logger.warning( f"Flow completed with issues: recon={nersc_reconstruction_success}, " - # f"sam3={sam3_success}, dino={dino_success}, cellpose={cellpose_success}" f"sam3={sam3_success}, dino={dino_success}" ) return False - # sam3_result = sam3_future.result() - # dino_success = dino_future.result() - # cellpose_success = cellpose_future.result() - - # # nersc_segmentation_task (SAM3) returns a dict - # if isinstance(sam3_result, dict): - # sam3_success = sam3_result.get('success', False) - # else: - # sam3_success = bool(sam3_result) - - # logger.info( - # f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}, Cellpose: {cellpose_success}" - # ) - - # any_seg_success = any([sam3_success, dino_success, cellpose_success]) - - # # ── STEP 4: Combine (sync, after all three complete) ───────────────────── - # if dino_success and (sam3_success or 
cellpose_success): - # logger.info("Running segmentation combination (SAM3+DINO and Cellpose+DINO).") - # combine_success = controller.combine_segmentations( - # recon_folder_path=scratch_path_tiff - # ) - # logger.info(f"Combination result: {combine_success}") - # else: - # logger.warning("Skipping combination: requires DINO plus at least one of SAM3/Cellpose.") - - # # ── STEP 5: Transfer segmentation outputs to data832 ───────────────────── - # if any_seg_success: - # logger.info("Transferring segmentation outputs from NERSC pscratch to data832") - # try: - # data832_segment_transfer_success = transfer_controller.copy( - # file_path=scratch_path_segment, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) - # logger.info(f"Transfer segmented data to data832 success: {data832_segment_transfer_success}") - # except Exception as e: - # logger.error(f"Failed to transfer segmented data to data832: {e}") - # data832_segment_transfer_success = False - - # # ── STEP 6: Pruning ─────────────────────────────────────────────────────── - # logger.info("Scheduling file pruning tasks.") - # prune_controller = get_prune_controller(prune_type=PruneMethod.GLOBUS, config=config) - - # logger.info("Scheduling pruning of NERSC pscratch raw data.") - # try: - # prune_controller.prune( - # file_path=f"{folder_name}/{path.name}", - # source_endpoint=config.nersc832_alsdev_pscratch_raw, - # check_endpoint=None, - # days_from_now=1.0 - # ) - # except Exception as e: - # logger.warning(f"Failed to schedule raw data pruning: {e}") - - # if nersc_reconstruction_success: - # logger.info("Scheduling pruning of NERSC pscratch reconstruction data.") - # try: - # prune_controller.prune( - # file_path=scratch_path_tiff, - # source_endpoint=config.nersc832_alsdev_pscratch_scratch, - # check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, - # days_from_now=1.0 - # ) - # except Exception as e: - # logger.warning(f"Failed to 
schedule reconstruction data pruning: {e}") - - # if any_seg_success: - # logger.info("Scheduling pruning of NERSC pscratch segmentation data.") - # try: - # prune_controller.prune( - # file_path=scratch_path_segment, - # source_endpoint=config.nersc832_alsdev_pscratch_scratch, - # check_endpoint=config.data832_scratch if data832_segment_transfer_success else None, - # days_from_now=1.0 - # ) - # except Exception as e: - # logger.warning(f"Failed to schedule segmentation data pruning: {e}") - - # if data832_tiff_transfer_success: - # logger.info("Scheduling pruning of data832 scratch reconstruction TIFF data.") - # try: - # prune_controller.prune( - # file_path=scratch_path_tiff, - # source_endpoint=config.data832_scratch, - # check_endpoint=None, - # days_from_now=30.0 - # ) - # except Exception as e: - # logger.warning(f"Failed to schedule data832 tiff pruning: {e}") - - # if data832_segment_transfer_success: - # logger.info("Scheduling pruning of data832 scratch segmentation data.") - # try: - # prune_controller.prune( - # file_path=scratch_path_segment, - # source_endpoint=config.data832_scratch, - # check_endpoint=None, - # days_from_now=30.0 - # ) - # except Exception as e: - # logger.warning(f"Failed to schedule data832 segment pruning: {e}") - - # # TODO: ingest to scicat - - # if nersc_reconstruction_success and any_seg_success: - # logger.info("NERSC reconstruction + multi-segmentation flow completed successfully.") - # return True - # else: - # logger.warning( - # f"Flow completed with issues: recon={nersc_reconstruction_success}, " - # f"sam3={sam3_success}, dino={dino_success}, cellpose={cellpose_success}" - # ) - # return False - @flow(name="nersc_streaming_flow", on_cancellation=[cancellation_hook]) def nersc_streaming_flow( @@ -3131,25 +2789,6 @@ def nersc_segmentation_dino_task( return success -@task(name="nersc_segmentation_cellpose_task") -def nersc_segmentation_cellpose_task( - recon_folder_path: str, - config: Optional[Config832] = None, -) -> 
bool: - logger = get_run_logger() - if config is None: - logger.info("No config provided, using default Config832.") - config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) - logger.info(f"Starting NERSC Cellpose segmentation task for {recon_folder_path=}") - success = tomography_controller.segmentation_cellpose(recon_folder_path=recon_folder_path) - if not success: - logger.error("Cellpose segmentation failed.") - else: - logger.info("Cellpose segmentation successful.") - return success - - @task(name="nersc_extract_regions_task") def nersc_extract_regions_task( recon_folder_path: str, From b73bfe5cb5efad7a590dda3ac729016a70cc9ad4 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 11:01:03 -0700 Subject: [PATCH 50/72] removing extract_regions flow (replaced by the combine step) --- orchestration/flows/bl832/nersc.py | 199 ----------------------------- 1 file changed, 199 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c545d645..19274db4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1198,169 +1198,6 @@ def segmentation_dino( else: return False - def seg_extract_regions( - self, - recon_folder_path: str = "", - ) -> bool: - """ - Extract Hydrated and Dehydrated Xylem regions using DINO semantic masks - and SAM3 instance masks at NERSC Perlmutter via SFAPI Slurm job. - - :param recon_folder_path: Relative path to the reconstructed data folder, - e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' - :return: True if the job completed successfully, False otherwise. 
- """ - logger.info("Starting NERSC region extraction process.") - - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" - conda_env_path = f"{cfs_path}/envs/dino_demo" - - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo" - - seg_folder = recon_folder_path.replace("/rec", "/seg") - input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" - seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" - - sam3_results = f"{seg_base}/sam3" - dino_results = f"{seg_base}/dino" - combined_output = f"{seg_base}/combined" - extract_output = f"{combined_output}/extract_regions" - - logger.info(f"Extract regions input dir: {input_dir}") - logger.info(f"Extract regions output dir: {extract_output}") - - EXTRACT_DEFAULTS = { - "defaults": True, - "num_nodes": 8, - "qos": "regular", - "account": "amsc006", - "constraint": "cpu", - "walltime": "01:00:00", - "dilate_px": 5, - "xylem_mode": "all", - } - try: - seg_options = Variable.get("nersc-extract-regions-options", default={"defaults": True}, _sync=True) - if isinstance(seg_options, str): - import json - seg_options = json.loads(seg_options) - except Exception as e: - logger.warning(f"Could not load nersc-extract-regions-options: {e}. 
Using defaults.") - seg_options = {"defaults": True} - - use_defaults = seg_options.get("defaults", True) - opts = EXTRACT_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in EXTRACT_DEFAULTS.items()} - - num_nodes = opts["num_nodes"] - qos = opts["qos"] - account = opts["account"] - constraint = opts["constraint"] - walltime = opts["walltime"] - dilate_px = opts["dilate_px"] - xylem_mode = opts["xylem_mode"] - - job_name = f"extract_{Path(recon_folder_path).name}" - - job_script = f"""#!/bin/bash -#SBATCH -q {qos} -#SBATCH -A {account} -#SBATCH -N {num_nodes} -#SBATCH -C {constraint} -#SBATCH --reservation=_CAP_MarchModCon_GPU -#SBATCH --job-name={job_name} -#SBATCH --time={walltime} -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=128 -#SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out -#SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err - -module load conda -conda activate {conda_env_path} - -mkdir -p {extract_output} -mkdir -p {pscratch_path}/tomo_seg_logs - -echo "============================================================" -echo "REGION EXTRACTION STARTED: $(date)" -echo "============================================================" -echo "Input: {input_dir}" -echo "SAM3 masks: {sam3_results}" -echo "DINO masks: {dino_results}/semantic_masks" -echo "Output: {extract_output}" -echo "Xylem mode: {xylem_mode}" -echo "Dilate px: {dilate_px}" -echo "============================================================" - -START_TIME=$(date +%s) - -cd {seg_scripts_dir} - -# Extract Hydrated and Dehydrated Xylem regions -python -m src.extract_region \\ - --input-dir "{input_dir}" \\ - --semantic-masks-dir "{dino_results}/semantic_masks" \\ - --instance-masks-dir "{sam3_results}" \\ - --output-dir "{extract_output}" \\ - --xylem-mode {xylem_mode} \\ - --dilate-px {dilate_px} - -EXTRACT_STATUS=$? 
- -END_TIME=$(date +%s) -DURATION=$((END_TIME - START_TIME)) -MINUTES=$((DURATION / 60)) -SECONDS=$((DURATION % 60)) - -echo "" -echo "============================================================" -echo "REGION EXTRACTION COMPLETED: $(date)" -echo "============================================================" -echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" -echo "Exit status: $EXTRACT_STATUS" -echo "============================================================" - -chmod -R 2775 {extract_output} 2>/dev/null || true - -exit $EXTRACT_STATUS -""" - try: - logger.info("Submitting region extraction job to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() - logger.info("Region extraction job completed successfully.") - return True - - except Exception as e: - logger.error(f"Error during region extraction job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.compute(Machine.perlmutter).job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Region extraction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False - def combine_segmentations( self, recon_folder_path: str = "", @@ -2470,9 +2307,6 @@ def nersc_forge_recon_multisegment_flow( combine_future = nersc_combine_segmentations_task.submit( recon_folder_path=scratch_path_tiff, config=config ) - # extract_future = nersc_extract_regions_task.submit( - # 
recon_folder_path=scratch_path_tiff, config=config - # ) combine_success = combine_future.result() logger.info(f"Combination result: {combine_success}") @@ -2489,20 +2323,6 @@ def nersc_forge_recon_multisegment_flow( except Exception as e: logger.error(f"Failed to transfer combined outputs to data832: {e}") - # extract_success = extract_future.result() - # logger.info(f"Region extraction result: {extract_success}") - # if extract_success: - # logger.info("Transferring extracted region outputs to data832") - # extract_segment_path = f"{folder_name}/seg{file_name}/combined/extract_regions" - # try: - # data832_extract_transfer_success = transfer_controller.copy( - # file_path=extract_segment_path, - # source=config.nersc832_alsdev_pscratch_scratch, - # destination=config.data832_scratch - # ) - # logger.info(f"Extract regions transfer to data832 success: {data832_extract_transfer_success}") - # except Exception as e: - # logger.error(f"Failed to transfer extracted region outputs to data832: {e}") else: logger.warning("Skipping combination and extraction: requires DINO plus SAM3.") @@ -2789,25 +2609,6 @@ def nersc_segmentation_dino_task( return success -@task(name="nersc_extract_regions_task") -def nersc_extract_regions_task( - recon_folder_path: str, - config: Optional[Config832] = None, -) -> bool: - logger = get_run_logger() - if config is None: - logger.info("No config provided, using default Config832.") - config = Config832() - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) - logger.info(f"Starting NERSC region extraction task for {recon_folder_path=}") - success = tomography_controller.seg_extract_regions(recon_folder_path=recon_folder_path) - if not success: - logger.error("Region extraction failed.") - else: - logger.info("Region extraction successful.") - return success - - @task(name="nersc_combine_segmentations_task") def nersc_combine_segmentations_task( recon_folder_path: str, From 58fe7c53929f01130705f9ad7e411b64af56a2ae Mon 
Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 11:36:00 -0700 Subject: [PATCH 51/72] renaming segmantion flows/tasks to segmentation_sam3 to differentiate from dino --- orchestration/flows/bl832/nersc.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 19274db4..c8da4218 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -705,7 +705,7 @@ def build_multi_resolution_optimize( else: return False - def segmentation( + def segmentation_sam3( self, recon_folder_path: str = "", num_nodes: int = 42, @@ -2031,8 +2031,8 @@ def nersc_forge_recon_segment_flow( data832_tiff_transfer_success = False # STEP 4: Run the Segmentation Task at NERSC - logger.info(f"Starting NERSC segmentation task for {scratch_path_tiff=}") - seg_result = nersc_segmentation_task( + logger.info(f"Starting NERSC SAM3 segmentation task for {scratch_path_tiff=}") + seg_result = nersc_segmentation_sam3_task( recon_folder_path=scratch_path_tiff, config=config ) @@ -2257,7 +2257,7 @@ def nersc_forge_recon_multisegment_flow( # ── STEP 3: SAM3 / DINOv3 ────────────────────────── logger.info("Submitting SAM3 and DINOv3 segmentation tasks concurrently.") - sam3_future = nersc_segmentation_task.submit( + sam3_future = nersc_segmentation_sam3_task.submit( recon_folder_path=scratch_path_tiff, config=config ) dino_future = nersc_segmentation_dino_task.submit( @@ -2300,9 +2300,9 @@ def nersc_forge_recon_multisegment_flow( logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}") - # ── STEP 5: Combine + Extract Regions concurrently (after SAM3+DINO complete) ── + # ── STEP 5: Combine Segmentations (after SAM3+DINO complete) ── if dino_success and sam3_success: - logger.info("Running segmentation combination and region extraction concurrently.") + logger.info("Running segmentation combination.") combine_future = 
nersc_combine_segmentations_task.submit( recon_folder_path=scratch_path_tiff, config=config @@ -2556,8 +2556,8 @@ def nersc_multiresolution_integration_test() -> bool: return flow_success -@task(name="nersc_segmentation_task") -def nersc_segmentation_task( +@task(name="nersc_segmentation_sam3_task") +def nersc_segmentation_sam3_task( recon_folder_path: str, config: Optional[Config832] = None, ) -> bool: @@ -2580,7 +2580,7 @@ def nersc_segmentation_task( config=config ) logger.info(f"Starting NERSC segmentation task for {recon_folder_path=}") - nersc_segmentation_success = tomography_controller.segmentation( + nersc_segmentation_success = tomography_controller.segmentation_sam3( recon_folder_path=recon_folder_path, ) if not nersc_segmentation_success: @@ -2628,17 +2628,17 @@ def nersc_combine_segmentations_task( return success -@flow(name="nersc_segmentation_integration_test", flow_run_name="nersc_segmentation_integration_test") -def nersc_segmentation_integration_test() -> bool: +@flow(name="nersc_segmentation_sam3_integration_test", flow_run_name="nersc_segmentation_sam3_integration_test") +def nersc_segmentation_sam3_integration_test() -> bool: """ - Integration test for the NERSC segmentation task. + Integration test for the NERSC SAM3 segmentation task. :return: True if the segmentation task completed successfully, False otherwise. 
""" logger = get_run_logger() - logger.info("Starting NERSC segmentation integration test.") + logger.info("Starting NERSC SAM3 segmentation integration test.") recon_folder_path = 'synaps-i/rec20211222_125057_petiole4' # 'test' # - flow_success = nersc_segmentation_task( + flow_success = nersc_segmentation_sam3_task( recon_folder_path=recon_folder_path, config=Config832() ) From 827d7c79286d27f1180042fd8f3a64985d2d2a9a Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 11:40:36 -0700 Subject: [PATCH 52/72] removing commented code --- orchestration/flows/bl832/nersc.py | 184 +---------------------------- 1 file changed, 3 insertions(+), 181 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c8da4218..c80c1e87 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -871,7 +871,7 @@ def segmentation_sam3( # Change to script directory cd {seg_scripts_dir} -# Run inference with v5 +# Run inference with v6 export TORCH_DISTRIBUTED_DEBUG=DETAIL export NCCL_DEBUG=INFO export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 @@ -933,7 +933,7 @@ def segmentation_sam3( """ try: - logger.info("Submitting segmentation job to Perlmutter (v5).") + logger.info("Submitting segmentation job to Perlmutter (v6).") perlmutter = self.client.compute(Machine.perlmutter) # Ensure directories exist @@ -2408,7 +2408,6 @@ def nersc_forge_recon_multisegment_flow( logger.warning( f"Flow completed with issues: recon={nersc_reconstruction_success}, " f"sam3={sam3_success}, dino={dino_success}" - ) return False @@ -2425,7 +2424,7 @@ def nersc_streaming_flow( controller: NERSCTomographyHPCController = get_controller( hpc_type=HPC.NERSC, config=config - ) # type: ignore + ) job_id = controller.start_streaming_service(walltime=walltime) save_block(SlurmJobBlock(job_id=job_id)) @@ -2644,180 +2643,3 @@ def nersc_segmentation_sam3_integration_test() -> bool: ) logger.info(f"Flow success: {flow_success}") return 
flow_success - - -if __name__ == "__main__": - - # nersc_multiresolution_integration_test() - nersc_tiff_to_zarr_task( - file_path='DD-00842_hexemer/20260222_122341_petiole51.h5', - config=Config832() - ) - - # Run the integration test flow - - # from sfapi_client import Client - # from sfapi_client.compute import Machine - - # # Use your existing client setup - # client = NERSCTomographyHPCController.create_sfapi_client() - # perlmutter = client.compute(Machine.perlmutter) - - # job = perlmutter.job(jobid=48781402) - # job.cancel() - # print(f"Job {job.jobid} cancelled, state: {job.state}") - - # job = perlmutter.job(jobid=48778803) - # job.cancel() - # print(f"Job {job.jobid} cancelled, state: {job.state}") - - # job = perlmutter.job(jobid=48777760) - # job.cancel() - # print(f"Job {job.jobid} cancelled, state: {job.state}") - - # nersc_forge_recon_segment_flow('/global/raw/raw/DD-00842_hexemer/20260212_110324_petiole24.h5') - # result = nersc_segmentation_integration_test() - # print(f"Integration test result: {result}") - - -# if __name__ == "__main__": - -# config = Config832() - - # pull_shifter_image_flow(config=config) - - # # Fibers ------------------------------------------ - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20251218_111600_silkraw.h5", - # num_nodes=4, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=8, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # 
file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=4, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=2, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230215_135338_PET_Al_PP_Al2O3_fibers_in_glass_pipette.h5", - # num_nodes=1, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - # print(f"Total reconstruction time with 1 node: {end - start} seconds") - - # # # # Fungi ------------------------------------------ - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - # num_nodes=8, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - # num_nodes=4, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - # num_nodes=2, - # config=config - # ) - # end = time.time() - # 
logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20230606_151124_jong-seto_fungal-mycelia_roll-AQ_fungi1_fast.h5", - # num_nodes=1, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - # print(f"Total reconstruction time with 1 node: {end - start} seconds") - - # # # # Silk ------------------------------------------ - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20251218_111600_silkraw.h5", - # num_nodes=8, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 8 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 8 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20251218_111600_silkraw.h5", - # num_nodes=4, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 4 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 4 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20251218_111600_silkraw.h5", - # num_nodes=2, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 2 nodes: {end - start} seconds") - # print(f"Total reconstruction time with 2 nodes: {end - start} seconds") - - # start = time.time() - # nersc_recon_flow( - # file_path="dabramov/20251218_111600_silkraw.h5", - # num_nodes=1, - # config=config - # ) - # end = time.time() - # logger.info(f"Total reconstruction time with 1 node: {end - start} seconds") - # print(f"Total reconstruction time with 1 node: {end - start} seconds") From 6e039838b037d0beb2e6f50e569c8f183ff49a52 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 11:45:17 -0700 Subject: 
[PATCH 53/72] removing multiresolution multinode optimization efforts from this PR --- orchestration/flows/bl832/nersc.py | 150 +---------------------------- 1 file changed, 1 insertion(+), 149 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index c80c1e87..f5556edb 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -547,124 +547,6 @@ def build_multi_resolution( {multires_image} \ bash -c "python tiff_to_zarr.py {recon_path} --raw_file {raw_path}" -date -""" - try: - logger.info("Submitting Tiff to Zarr job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - - return True - - except Exception as e: - logger.warning(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - return False - - def build_multi_resolution_optimize( - self, - file_path: str = "", - num_nodes: int = 4, - ) -> bool: - """ - Use NERSC to make multiresolution version of tomography results with multi-node scaling. 
- """ - logger.info("Starting NERSC multiresolution process (multi-node).") - - user = self.client.user() - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - - multires_image = self.config.ghcr_images832["multires_image"] - recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path - - path = Path(file_path) - folder_name = path.parent.name - file_name = path.stem - - recon_path = f"scratch/{folder_name}/rec{file_name}/" - raw_path = f"raw/{folder_name}/{file_name}.h5" - - # Scale time with nodes (less time needed with more workers) - walltime = "0:30:00" if num_nodes <= 4 else "0:15:00" - - job_script = f"""#!/bin/bash -#SBATCH -q realtime -#SBATCH -A als -#SBATCH -C cpu -#SBATCH --job-name=tomo_multires_{folder_name}_{file_name} -#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out -#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err -#SBATCH -N {num_nodes} -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=128 -#SBATCH --time={walltime} -#SBATCH --exclusive - -date -echo "Starting multi-node Zarr conversion with {num_nodes} nodes" - -# Get scheduler node -SCHEDULER_NODE=$(hostname) -SCHEDULER_PORT=8786 -SCHEDULER_ADDR="tcp://$SCHEDULER_NODE:$SCHEDULER_PORT" - -echo "Scheduler will run on: $SCHEDULER_ADDR" - -# Start Dask scheduler on first node -srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \\ - --volume {pscratch_path}/8.3.2:/alsdata \\ - {multires_image} \\ - dask scheduler --port $SCHEDULER_PORT & - -SCHEDULER_PID=$! 
-sleep 10 # Give scheduler time to start - -# Start Dask workers on all nodes -for i in $(seq 1 {num_nodes}); do - srun --nodes=1 --ntasks=1 --exclusive podman-hpc run \\ - --env NUMEXPR_MAX_THREADS=128 \\ - --env OMP_NUM_THREADS=128 \\ - --volume {pscratch_path}/8.3.2:/alsdata \\ - {multires_image} \\ - dask worker $SCHEDULER_ADDR --nthreads 32 --nworkers 4 --memory-limit 60GB & -done - -sleep 15 # Give workers time to connect - -# Run the conversion script, connecting to the cluster -srun --nodes=1 --ntasks=1 podman-hpc run \\ - --volume {recon_scripts_dir}/tiff_to_zarr_multinode.py:/alsuser/tiff_to_zarr_multinode.py \\ - --volume {pscratch_path}/8.3.2:/alsdata \\ - --volume {pscratch_path}/8.3.2:/alsuser/ \\ - {multires_image} \\ - bash -c "python tiff_to_zarr_multinode.py {recon_path} --raw_file {raw_path} --scheduler $SCHEDULER_ADDR" -wait date """ try: @@ -2497,7 +2379,7 @@ def nersc_multiresolution_task( config=config ) logger.info(f"Starting NERSC multiresolution task for {file_path=}") - nersc_multiresolution_success = tomography_controller.build_multi_resolution_optimize( + nersc_multiresolution_success = tomography_controller.build_multi_resolution( file_path=file_path, ) if not nersc_multiresolution_success: @@ -2507,36 +2389,6 @@ def nersc_multiresolution_task( return nersc_multiresolution_success -@task(name="nersc_tiff_to_zarr_task") -def nersc_tiff_to_zarr_task( - file_path: str, - config: Optional[Config832] = None, -) -> bool: - """ - Run tiff-to-zarr (single-node) multiresolution task at NERSC. - - :param file_path: Path to the raw .h5 file (used to derive recon path). - :param config: Configuration object for the flow. - :return: True if the task completed successfully, False otherwise. 
- """ - logger = get_run_logger() - if config is None: - logger.info("No config provided, using default Config832.") - config = Config832() - - logger.info("Initializing NERSC Tomography HPC Controller.") - tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) - - logger.info(f"Starting NERSC tiff-to-zarr task for {file_path=}") - success = tomography_controller.build_multi_resolution(file_path=file_path) - - if not success: - logger.error("Tiff-to-zarr failed.") - else: - logger.info("Tiff-to-zarr successful.") - return success - - @flow(name="nersc_multiresolution_integration_test", flow_run_name="nersc_multiresolution_integration_test") def nersc_multiresolution_integration_test() -> bool: """ From eee373327b909af9dfd56a38cdd02086236198b8 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 13:32:55 -0700 Subject: [PATCH 54/72] Adding pytests for bl832/nersc.py: reconstruction, segmentation, multiresolution --- orchestration/flows/bl832/nersc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index f5556edb..592040c9 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -2081,7 +2081,9 @@ def nersc_forge_recon_multisegment_flow( sam3_success = False dino_success = False data832_tiff_transfer_success = False - # data832_segment_transfer_success = False + data832_sam3_transfer_success = False + data832_dino_transfer_success = False + data832_combined_transfer_success = False # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") From 87e107d74fdf247da79647d573ed0dae844f7936 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 13 Mar 2026 13:33:21 -0700 Subject: [PATCH 55/72] Adding pytests for bl832/nersc.py: reconstruction, segmentation, multiresolution --- orchestration/_tests/test_bl832/__init__.py | 0 
orchestration/_tests/test_bl832/test_nersc.py | 689 ++++++++++++++++++ 2 files changed, 689 insertions(+) create mode 100644 orchestration/_tests/test_bl832/__init__.py create mode 100644 orchestration/_tests/test_bl832/test_nersc.py diff --git a/orchestration/_tests/test_bl832/__init__.py b/orchestration/_tests/test_bl832/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py new file mode 100644 index 00000000..69400ab4 --- /dev/null +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -0,0 +1,689 @@ +# orchestration/_tests/bl832/test_nersc.py + +import pytest +from uuid import uuid4 + +from prefect.blocks.system import Secret +from prefect.testing.utilities import prefect_test_harness + + +# ────────────────────────────────────────────────────────────────────────────── +# Session fixture +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True, scope="session") +def prefect_test_fixture(): + """Set up Prefect test harness and required secrets for the full session.""" + with prefect_test_harness(): + Secret(value=str(uuid4())).save(name="globus-client-id", overwrite=True) + Secret(value=str(uuid4())).save(name="globus-client-secret", overwrite=True) + yield + + +# ────────────────────────────────────────────────────────────────────────────── +# Shared fixtures +# ────────────────────────────────────────────────────────────────────────────── + +@pytest.fixture +def mock_sfapi_client(mocker): + """Mock sfapi_client.Client with a completed job on Perlmutter.""" + mock_client = mocker.MagicMock() + + mock_user = mocker.MagicMock() + mock_user.name = "testuser" + mock_client.user.return_value = mock_user + + mock_job = mocker.MagicMock() + mock_job.jobid = "12345" + mock_job.state = "COMPLETED" + + mock_compute = mocker.MagicMock() + mock_compute.submit_job.return_value = mock_job + 
mock_client.compute.return_value = mock_compute + + mocker.patch("orchestration.flows.bl832.nersc.Client", return_value=mock_client) + return mock_client + + +@pytest.fixture +def mock_config832(mocker): + """ + Mock Config832 constructor so any call to Config832() returns our mock. + Tests that call flows must pass config=None so Prefect's type validation + is never given a MagicMock — the flow will call Config832() internally and + get our mock back. + """ + mock_config = mocker.MagicMock() + + mock_config.ghcr_images832 = { + "recon_image": "mock_recon_image", + "multires_image": "mock_multires_image", + } + + for attr in [ + "nersc832_alsdev_raw", + "nersc832_alsdev_scratch", + "nersc832_alsdev_recon_scripts", + "nersc832_alsdev_pscratch_scratch", + "nersc832_alsdev_pscratch_raw", + "data832_scratch", + ]: + ep = mocker.MagicMock() + ep.root_path = f"/mock/{attr}" + setattr(mock_config, attr, ep) + + mock_config.nersc_recon_num_nodes = 4 + + mocker.patch("orchestration.flows.bl832.nersc.Config832", return_value=mock_config) + return mock_config + + +@pytest.fixture +def mock_recon_success(): + return {"success": True, "job_id": "11111", "timing": None} + + +@pytest.fixture +def mock_seg_sam3_success(): + return {"success": True, "job_id": "22222", "timing": None, "output_dir": "/out/sam3"} + + +def _make_future(mocker, value): + """Return a mock Prefect future whose .result() yields the given value.""" + f = mocker.MagicMock() + f.result.return_value = value + return f + + +# ────────────────────────────────────────────────────────────────────────────── +# create_sfapi_client +# ────────────────────────────────────────────────────────────────────────────── + +def test_create_sfapi_client_success(mocker): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + 
}.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=True) + mocker.patch("builtins.open", side_effect=[ + mocker.mock_open(read_data="client_id_value")(), + mocker.mock_open(read_data='{"key": "value"}')(), + ]) + mocker.patch("orchestration.flows.bl832.nersc.JsonWebKey.import_key", return_value="mock_secret") + mock_client_cls = mocker.patch("orchestration.flows.bl832.nersc.Client") + + client = NERSCTomographyHPCController.create_sfapi_client() + + mock_client_cls.assert_called_once_with("client_id_value", "mock_secret") + assert client == mock_client_cls.return_value + + +def test_create_sfapi_client_missing_paths(mocker): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", return_value=None) + + with pytest.raises(ValueError, match="Missing NERSC credentials paths."): + NERSCTomographyHPCController.create_sfapi_client() + + +def test_create_sfapi_client_missing_files(mocker): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.os.getenv", side_effect=lambda x: { + "PATH_NERSC_CLIENT_ID": "/path/to/client_id", + "PATH_NERSC_PRI_KEY": "/path/to/client_secret", + }.get(x)) + mocker.patch("orchestration.flows.bl832.nersc.os.path.isfile", return_value=False) + + with pytest.raises(FileNotFoundError, match="NERSC credential files are missing."): + NERSCTomographyHPCController.create_sfapi_client() + + +# ────────────────────────────────────────────────────────────────────────────── +# reconstruct +# ────────────────────────────────────────────────────────────────────────────── + +def test_reconstruct_success(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + controller = 
NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.reconstruct(file_path="folder/file.h5") + + mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + assert result is True + + +def test_reconstruct_submission_failure(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.reconstruct(file_path="folder/file.h5") + + assert result is False + + +# ────────────────────────────────────────────────────────────────────────────── +# build_multi_resolution +# ────────────────────────────────────────────────────────────────────────────── + +def test_build_multi_resolution_success(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.build_multi_resolution(file_path="folder/file.h5") + + mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + assert result is True + + +def test_build_multi_resolution_submission_failure(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + 
mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Submission failed") + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.build_multi_resolution(file_path="folder/file.h5") + + assert result is False + + +# ────────────────────────────────────────────────────────────────────────────── +# segmentation_sam3 +# ────────────────────────────────────────────────────────────────────────────── + +def test_segmentation_sam3_success(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) + + result = controller.segmentation_sam3(recon_folder_path="folder/recfile") + + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + assert isinstance(result, dict) + assert result["success"] is True + assert result["job_id"] == "12345" + + +def test_segmentation_sam3_submission_failure(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("GPU queue full") + controller = NERSCTomographyHPCController(client=mock_sfapi_client, 
config=mock_config832) + + result = controller.segmentation_sam3(recon_folder_path="folder/recfile") + + assert isinstance(result, dict) + assert result["success"] is False + assert result["job_id"] is None + + +def test_segmentation_sam3_uses_variable_options(mocker, mock_sfapi_client, mock_config832): + """Custom Prefect variable options should be forwarded into the job script.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={ + "defaults": False, + "batch_size": 8, + "patch_size": 512, + "confidence": [0.6, 0.7], + "overlap": 0.5, + "qos": "debug", + "account": "als_test", + "constraint": "gpu", + "checkpoint": "checkpoint_v7.pt", + }) + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + mocker.patch.object(controller, "_fetch_seg_timing_from_output", return_value=None) + + captured_scripts = [] + original_return = mock_sfapi_client.compute.return_value.submit_job.return_value + + def capture_script(script): + captured_scripts.append(script) + return original_return + + mock_sfapi_client.compute.return_value.submit_job.side_effect = capture_script + + controller.segmentation_sam3(recon_folder_path="folder/recfile") + + assert captured_scripts, "submit_job was never called" + script = captured_scripts[0] + assert "checkpoint_v7.pt" in script + assert "--patch-size 512" in script + assert "0.6 0.7" in script + assert "--overlap-ratio 0.5" in script + assert "#SBATCH -q debug" in script + assert "#SBATCH -A als_test" in script + + +# ────────────────────────────────────────────────────────────────────────────── +# segmentation_dino +# ────────────────────────────────────────────────────────────────────────────── + +def test_segmentation_dino_success(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import 
NERSCTomographyHPCController + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.segmentation_dino(recon_folder_path="folder/recfile") + + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + assert result is True + + +def test_segmentation_dino_submission_failure(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("No GPU nodes") + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.segmentation_dino(recon_folder_path="folder/recfile") + + assert result is False + + +def test_segmentation_dino_output_paths(mocker, mock_sfapi_client, mock_config832): + """ + Output dir should swap /rec for /seg in the folder name and route to /dino. + + Given recon_folder_path="folder/recfile", the code does: + seg_folder = "folder/recfile".replace("/rec", "/seg") → "folder/segfile" + output_dir = ".../scratch/folder/segfile/dino" + So the script contains "segfile" and "/dino", not a literal "/seg/" segment. 
+ """ + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + + captured_scripts = [] + original_return = mock_sfapi_client.compute.return_value.submit_job.return_value + + def capture(script): + captured_scripts.append(script) + return original_return + + mock_sfapi_client.compute.return_value.submit_job.side_effect = capture + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller.segmentation_dino(recon_folder_path="folder/recfile") + + script = captured_scripts[0] + # The rec→seg substitution turns "recfile" into "segfile" + assert "segfile" in script + assert "/dino" in script + + +# ────────────────────────────────────────────────────────────────────────────── +# combine_segmentations +# ────────────────────────────────────────────────────────────────────────────── + +def test_combine_segmentations_success(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + from sfapi_client.compute import Machine + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.combine_segmentations(recon_folder_path="folder/recfile") + + mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) + mock_sfapi_client.compute.return_value.submit_job.assert_called_once() + mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() + assert result is True + + +def test_combine_segmentations_submission_failure(mocker, mock_sfapi_client, mock_config832): + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + 
mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("Cluster down") + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + + result = controller.combine_segmentations(recon_folder_path="folder/recfile") + + assert result is False + + +def test_combine_segmentations_script_references_sam3_and_dino(mocker, mock_sfapi_client, mock_config832): + """The combination job script should reference both model output directories.""" + from orchestration.flows.bl832.nersc import NERSCTomographyHPCController + + mocker.patch("orchestration.flows.bl832.nersc.time.sleep") + mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) + + captured_scripts = [] + original_return = mock_sfapi_client.compute.return_value.submit_job.return_value + + def capture(script): + captured_scripts.append(script) + return original_return + + mock_sfapi_client.compute.return_value.submit_job.side_effect = capture + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) + controller.combine_segmentations(recon_folder_path="folder/recfile") + + script = captured_scripts[0] + assert "/sam3" in script + assert "/dino" in script + assert "combine_sam_dino_v3" in script + assert "/combined" in script + + +# ────────────────────────────────────────────────────────────────────────────── +# Prefect tasks +# +# We call task.fn() to bypass Prefect's task runner machinery, but the task +# bodies call get_run_logger() which requires an active flow/task run context. +# Patching it at the module level avoids MissingContextError without needing +# a full prefect_test_harness flow run. 
+# ────────────────────────────────────────────────────────────────────────────── + +def test_nersc_segmentation_sam3_task_success(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_segmentation_sam3_task + + mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) + mock_controller = mocker.MagicMock() + mock_controller.segmentation_sam3.return_value = { + "success": True, "job_id": "99", "timing": None, "output_dir": "/out" + } + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + result = nersc_segmentation_sam3_task.fn( + recon_folder_path="folder/recfile", + config=mock_config832 + ) + + mock_controller.segmentation_sam3.assert_called_once_with(recon_folder_path="folder/recfile") + assert result == {"success": True, "job_id": "99", "timing": None, "output_dir": "/out"} + + +def test_nersc_segmentation_sam3_task_failure(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_segmentation_sam3_task + + mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) + mock_controller = mocker.MagicMock() + mock_controller.segmentation_sam3.return_value = { + "success": False, "job_id": None, "timing": None, "output_dir": None + } + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + result = nersc_segmentation_sam3_task.fn( + recon_folder_path="folder/recfile", + config=mock_config832 + ) + + assert result["success"] is False + + +def test_nersc_segmentation_dino_task_success(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_segmentation_dino_task + + mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) + mock_controller = mocker.MagicMock() + mock_controller.segmentation_dino.return_value = True + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) 
+ + result = nersc_segmentation_dino_task.fn( + recon_folder_path="folder/recfile", + config=mock_config832 + ) + + mock_controller.segmentation_dino.assert_called_once_with(recon_folder_path="folder/recfile") + assert result is True + + +def test_nersc_segmentation_dino_task_failure(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_segmentation_dino_task + + mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) + mock_controller = mocker.MagicMock() + mock_controller.segmentation_dino.return_value = False + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + result = nersc_segmentation_dino_task.fn( + recon_folder_path="folder/recfile", + config=mock_config832 + ) + + assert result is False + + +def test_nersc_combine_segmentations_task_success(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_combine_segmentations_task + + mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) + mock_controller = mocker.MagicMock() + mock_controller.combine_segmentations.return_value = True + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + result = nersc_combine_segmentations_task.fn( + recon_folder_path="folder/recfile", + config=mock_config832 + ) + + mock_controller.combine_segmentations.assert_called_once_with(recon_folder_path="folder/recfile") + assert result is True + + +def test_nersc_combine_segmentations_task_failure(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_combine_segmentations_task + + mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) + mock_controller = mocker.MagicMock() + mock_controller.combine_segmentations.return_value = False + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + result = 
nersc_combine_segmentations_task.fn( + recon_folder_path="folder/recfile", + config=mock_config832 + ) + + assert result is False + + +# ────────────────────────────────────────────────────────────────────────────── +# nersc_forge_recon_segment_flow +# +# Prefect validates the `config` parameter against Optional[Config832] at +# runtime, so passing a MagicMock raises ParameterTypeError. The fix is to +# pass config=None — the flow's `if config is None: config = Config832()` +# branch then runs, calling the already-mocked constructor and returning our +# mock_config832 instance. +# ────────────────────────────────────────────────────────────────────────────── + +def test_forge_recon_segment_flow_success(mocker, mock_config832, mock_recon_success, mock_seg_sam3_success): + from orchestration.flows.bl832.nersc import nersc_forge_recon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct_multinode.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_transfer = mocker.MagicMock() + mock_transfer.copy.return_value = True + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_seg_task = mocker.patch( + "orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task", + return_value=mock_seg_sam3_success + ) + + result = nersc_forge_recon_segment_flow(file_path="folder/file.h5", num_nodes=4) + + assert result is True + mock_controller.reconstruct_multinode.assert_called_once() + mock_seg_task.assert_called_once() + assert mock_transfer.copy.call_count >= 2 + + +def test_forge_recon_segment_flow_recon_failure(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_forge_recon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct_multinode.return_value = 
{"success": False, "job_id": None, "timing": None} + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mocker.MagicMock()) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + with pytest.raises(ValueError, match="Reconstruction at NERSC Failed"): + nersc_forge_recon_segment_flow(file_path="folder/file.h5", num_nodes=4) + + +def test_forge_recon_segment_flow_seg_failure(mocker, mock_config832, mock_recon_success): + """Flow should return False (not raise) when only segmentation fails.""" + from orchestration.flows.bl832.nersc import nersc_forge_recon_segment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct_multinode.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_transfer = mocker.MagicMock() + mock_transfer.copy.return_value = True + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + mocker.patch( + "orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task", + return_value={"success": False, "job_id": None, "timing": None, "output_dir": None} + ) + + result = nersc_forge_recon_segment_flow(file_path="folder/file.h5", num_nodes=4) + + assert result is False + + +# ────────────────────────────────────────────────────────────────────────────── +# nersc_forge_recon_multisegment_flow +# ────────────────────────────────────────────────────────────────────────────── + +def test_forge_recon_multisegment_flow_both_succeed(mocker, mock_config832, mock_recon_success): + from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow + + mock_controller = mocker.MagicMock() + 
mock_controller.reconstruct_multinode.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_transfer = mocker.MagicMock() + mock_transfer.copy.return_value = True + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") + mock_dino_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dino_task") + mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") + + mock_sam3_task.submit.return_value = _make_future( + mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out"} + ) + mock_dino_task.submit.return_value = _make_future(mocker, True) + mock_combine_task.submit.return_value = _make_future(mocker, True) + + result = nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + + assert result is True + mock_sam3_task.submit.assert_called_once() + mock_dino_task.submit.assert_called_once() + mock_combine_task.submit.assert_called_once() + + +def test_forge_recon_multisegment_flow_only_sam3_succeeds(mocker, mock_config832, mock_recon_success): + """When only SAM3 succeeds, combine should be skipped but flow returns True.""" + from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct_multinode.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_transfer = mocker.MagicMock() + mock_transfer.copy.return_value = True + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + 
mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") + mock_dino_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dino_task") + mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") + + mock_sam3_task.submit.return_value = _make_future( + mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out"} + ) + mock_dino_task.submit.return_value = _make_future(mocker, False) + + result = nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + + assert result is True + mock_combine_task.submit.assert_not_called() + + +def test_forge_recon_multisegment_flow_both_seg_fail(mocker, mock_config832, mock_recon_success): + from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct_multinode.return_value = mock_recon_success + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + + mock_transfer = mocker.MagicMock() + mock_transfer.copy.return_value = False + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") + mock_dino_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dino_task") + mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") + + mock_sam3_task.submit.return_value = _make_future( + mocker, {"success": False, "job_id": None, "timing": None, "output_dir": None} + ) + mock_dino_task.submit.return_value = _make_future(mocker, False) + + result = 
nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + + assert result is False + mock_combine_task.submit.assert_not_called() + + +def test_forge_recon_multisegment_flow_recon_failure(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow + + mock_controller = mocker.MagicMock() + mock_controller.reconstruct_multinode.return_value = {"success": False, "job_id": None, "timing": None} + mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) + mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mocker.MagicMock()) + mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) + + with pytest.raises(ValueError, match="Reconstruction at NERSC Failed"): + nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) From 950944715b3795736d5e95ba4808cdc9a0b54835 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 10:45:45 -0700 Subject: [PATCH 56/72] Moving recon/segmentation num_nodes configuration to config.yaml --- config.yml | 6 ++++++ orchestration/flows/bl832/config.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/config.yml b/config.yml index bc8ffa20..9196a68c 100644 --- a/config.yml +++ b/config.yml @@ -159,3 +159,9 @@ prefect: scicat: jobs_api_url: https://dataportal.als.lbl.gov/api/ingest/jobs + +hpc_submission_settings832: + nersc_reconstruction: + num_nodes: 16 + nersc_segmentation: + num_nodes: 42 diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 8d8f8682..e750f770 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -28,5 +28,5 @@ def _beam_specific_config(self) -> None: self.alcf832_scratch = self.endpoints["alcf832_scratch"] self.scicat = self.config["scicat"] self.ghcr_images832 = self.config["ghcr_images832"] - 
self.nersc_recon_num_nodes = 16 - self.nersc_segment_num_nodes = 42 + self.nersc_recon_num_nodes = self.config["hpc_submission_settings832"]["nersc_reconstruction"]["num_nodes"] + self.nersc_segment_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation"]["num_nodes"] From d637f5eefbb8392a2399c4c176f2d04cc642bb7f Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 11:01:40 -0700 Subject: [PATCH 57/72] Replacing the original reconstruct code with the multinode version throughout. Setting defaults to 4 nodes so we can always run without a reservation --- config.yml | 8 +- orchestration/flows/bl832/config.py | 3 +- orchestration/flows/bl832/dispatcher.py | 3 +- orchestration/flows/bl832/nersc.py | 228 +----------------------- 4 files changed, 13 insertions(+), 229 deletions(-) diff --git a/config.yml b/config.yml index 9196a68c..0a9f560c 100644 --- a/config.yml +++ b/config.yml @@ -162,6 +162,8 @@ scicat: hpc_submission_settings832: nersc_reconstruction: - num_nodes: 16 - nersc_segmentation: - num_nodes: 42 + num_nodes: 4 + nersc_segmentation_sam3: + num_nodes: 4 + nersc_segmentation_dino: + num_nodes: 4 diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index e750f770..c0812af0 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -29,4 +29,5 @@ def _beam_specific_config(self) -> None: self.scicat = self.config["scicat"] self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_num_nodes = self.config["hpc_submission_settings832"]["nersc_reconstruction"]["num_nodes"] - self.nersc_segment_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation"]["num_nodes"] + self.nersc_segment_sam3_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"]["num_nodes"] + self.nersc_segment_dino_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"]["num_nodes"] diff --git 
a/orchestration/flows/bl832/dispatcher.py b/orchestration/flows/bl832/dispatcher.py index 60fefb38..f3243c12 100644 --- a/orchestration/flows/bl832/dispatcher.py +++ b/orchestration/flows/bl832/dispatcher.py @@ -25,8 +25,9 @@ class FlowParameterMapper: # From nersc.py "nersc_recon_flow/nersc_recon_flow": [ "file_path", + "num_nodes", "config"], - "nersc_recon_multinode_flow/nersc_recon_multinode_flow": [ + "nersc_forge_recon_multisegment_flow/nersc_forge_recon_multisegment_flow": [ "file_path", "num_nodes", "config"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 592040c9..b47d9cfa 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -78,138 +78,8 @@ def create_sfapi_client() -> Client: def reconstruct( self, file_path: str = "", - ) -> bool: - """ - Use NERSC for tomography reconstruction - - :param file_path: Path to the file to reconstruct - :return: True if successful, False otherwise - """ - logger.info("Starting NERSC reconstruction process.") - - user = self.client.user() - - raw_path = self.config.nersc832_alsdev_raw.root_path - logger.info(f"{raw_path=}") - - recon_image = self.config.ghcr_images832["recon_image"] - logger.info(f"{recon_image=}") - - recon_scripts_dir = self.config.nersc832_alsdev_recon_scripts.root_path - logger.info(f"{recon_scripts_dir=}") - - scratch_path = self.config.nersc832_alsdev_scratch.root_path - logger.info(f"{scratch_path=}") - - pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - logger.info(f"{pscratch_path=}") - - path = Path(file_path) - folder_name = path.parent.name - if not folder_name: - folder_name = "" - - file_name = f"{path.stem}.h5" - - logger.info(f"File name: {file_name}") - logger.info(f"Folder name: {folder_name}") - - # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately - # Note: If q=debug, there is no minimum time limit - # However, if q=preempt, there is a minimum time limit 
of 2 hours. Otherwise the job won't run. - # The realtime queue can only be used for select accounts (e.g. ALS) - job_script = f"""#!/bin/bash -#SBATCH -q realtime -#SBATCH -A als -#SBATCH -C cpu -#SBATCH --job-name=tomo_recon_{folder_name}_{file_name} -#SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out -#SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err -#SBATCH -N 1 -#SBATCH --ntasks-per-node 1 -#SBATCH --cpus-per-task 128 -#SBATCH --time=0:15:00 -#SBATCH --exclusive - -date -echo "Creating directory {pscratch_path}/8.3.2/raw/{folder_name}" -mkdir -p {pscratch_path}/8.3.2/raw/{folder_name} -mkdir -p {pscratch_path}/8.3.2/scratch/{folder_name} - -echo "Copying file {raw_path}/{folder_name}/{file_name} to {pscratch_path}/8.3.2/raw/{folder_name}/" -cp {raw_path}/{folder_name}/{file_name} {pscratch_path}/8.3.2/raw/{folder_name} -if [ $? -ne 0 ]; then - echo "Failed to copy data to pscratch." - exit 1 -fi - -# chmod -R 2775 {pscratch_path}/8.3.2 -chmod 2775 {pscratch_path}/8.3.2/raw/{folder_name} -chmod 2775 {pscratch_path}/8.3.2/scratch/{folder_name} -chmod 664 {pscratch_path}/8.3.2/raw/{folder_name}/{file_name} - - -echo "Verifying copied files..." -ls -l {pscratch_path}/8.3.2/raw/{folder_name}/ - -echo "Running reconstruction container..." 
-srun podman-hpc run \ ---env NUMEXPR_MAX_THREADS=128 \\ ---env NUMEXPR_NUM_THREADS=128 \\ ---env OMP_NUM_THREADS=128 \\ ---env MKL_NUM_THREADS=128 \\ ---volume {recon_scripts_dir}/sfapi_reconstruction.py:/alsuser/sfapi_reconstruction.py \ ---volume {pscratch_path}/8.3.2:/alsdata \ ---volume {pscratch_path}/8.3.2:/alsuser/ \ -{recon_image} \ -bash -c "python sfapi_reconstruction.py {file_name} {folder_name}" -date -""" - - try: - logger.info("Submitting reconstruction job script to Perlmutter.") - perlmutter = self.client.compute(Machine.perlmutter) - job = perlmutter.submit_job(job_script) - logger.info(f"Submitted job ID: {job.jobid}") - - try: - job.update() - except Exception as update_err: - logger.warning(f"Initial job update failed, continuing: {update_err}") - - time.sleep(60) - logger.info(f"Job {job.jobid} current state: {job.state}") - - job.complete() # Wait until the job completes - logger.info("Reconstruction job completed successfully.") - return True - - except Exception as e: - logger.info(f"Error during job submission or completion: {e}") - match = re.search(r"Job not found:\s*(\d+)", str(e)) - - if match: - jobid = match.group(1) - logger.info(f"Attempting to recover job {jobid}.") - try: - job = self.client.perlmutter.job(jobid=jobid) - time.sleep(30) - job.complete() - logger.info("Reconstruction job completed successfully after recovery.") - return True - except Exception as recovery_err: - logger.error(f"Failed to recover job {jobid}: {recovery_err}") - return False - else: - # Unknown error: cannot recover - return False - - def reconstruct_multinode( - self, - file_path: str = "", num_nodes: int = 2, ) -> bool: - """ Use NERSC for tomography reconstruction @@ -1587,97 +1457,7 @@ def schedule_pruning( @flow(name="nersc_recon_flow", flow_run_name="nersc_recon-{file_path}") def nersc_recon_flow( file_path: str, - config: Optional[Config832] = None, -) -> bool: - """ - Perform tomography reconstruction on NERSC. 
- - :param file_path: Path to the file to reconstruct. - """ - logger = get_run_logger() - - if config is None: - logger.info("Initializing Config") - config = Config832() - - logger.info(f"Starting NERSC reconstruction flow for {file_path=}") - controller = get_controller( - hpc_type=HPC.NERSC, - config=config - ) - logger.info("NERSC reconstruction controller initialized") - - nersc_reconstruction_success = controller.reconstruct( - file_path=file_path, - ) - logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") - nersc_multi_res_success = controller.build_multi_resolution( - file_path=file_path, - ) - logger.info(f"NERSC multi-resolution success: {nersc_multi_res_success}") - - path = Path(file_path) - folder_name = path.parent.name - file_name = path.stem - - tiff_file_path = f"{folder_name}/rec{file_name}" - zarr_file_path = f"{folder_name}/rec{file_name}.zarr" - - logger.info(f"{tiff_file_path=}") - logger.info(f"{zarr_file_path=}") - - # Transfer reconstructed data - logger.info("Preparing transfer.") - transfer_controller = get_transfer_controller( - transfer_type=CopyMethod.GLOBUS, - config=config - ) - - logger.info("Copy from /pscratch/sd/a/alsdev/8.3.2 to /global/cfs/cdirs/als/data_mover/8.3.2/scratch.") - transfer_controller.copy( - file_path=tiff_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch - ) - - transfer_controller.copy( - file_path=zarr_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.nersc832_alsdev_scratch - ) - - logger.info("Copy from NERSC /global/cfs/cdirs/als/data_mover/8.3.2/scratch to data832") - transfer_controller.copy( - file_path=tiff_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - - transfer_controller.copy( - file_path=zarr_file_path, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - - logger.info("Scheduling pruning 
tasks.") - schedule_pruning( - config=config, - raw_file_path=file_path, - tiff_file_path=tiff_file_path, - zarr_file_path=zarr_file_path - ) - - # TODO: Ingest into SciCat - if nersc_reconstruction_success and nersc_multi_res_success: - return True - else: - return False - - -@flow(name="nersc_recon_multinode_flow", flow_run_name="nersc_recon_multinode-{file_path}") -def nersc_recon_multinode_flow( - file_path: str, - num_nodes: Optional[int] = 16, + num_nodes: Optional[int] = 4, config: Optional[Config832] = None, ) -> bool: """ @@ -1706,7 +1486,7 @@ def nersc_recon_multinode_flow( logger.info(f"Configured to use {num_nodes} nodes for reconstruction") logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - nersc_reconstruction_success = controller.reconstruct_multinode( + nersc_reconstruction_success = controller.reconstruct( file_path=file_path, num_nodes=num_nodes ) @@ -1858,7 +1638,7 @@ def nersc_forge_recon_segment_flow( # STEP 2: Run Multinode Reconstruction at NERSC logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - recon_result = controller.reconstruct_multinode( + recon_result = controller.reconstruct( file_path=file_path, num_nodes=num_nodes ) @@ -2087,7 +1867,7 @@ def nersc_forge_recon_multisegment_flow( # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - recon_result = controller.reconstruct_multinode( + recon_result = controller.reconstruct( file_path=file_path, num_nodes=num_nodes ) From 35c80197ca5faaa4a981a314e57df3550ff6f5c5 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 11:23:28 -0700 Subject: [PATCH 58/72] removing the sam3 forge segmentation flow, and renaming the nersc_forge_recon_segment_flow to nersc_petiole_segment_flow --- orchestration/flows/bl832/dispatcher.py | 21 ++- orchestration/flows/bl832/nersc.py | 236 +----------------------- 2 files changed, 15 insertions(+), 242 
deletions(-) diff --git a/orchestration/flows/bl832/dispatcher.py b/orchestration/flows/bl832/dispatcher.py index f3243c12..6f216f6d 100644 --- a/orchestration/flows/bl832/dispatcher.py +++ b/orchestration/flows/bl832/dispatcher.py @@ -27,7 +27,7 @@ class FlowParameterMapper: "file_path", "num_nodes", "config"], - "nersc_forge_recon_multisegment_flow/nersc_forge_recon_multisegment_flow": [ + "nersc_petiole_segment_flow/nersc_petiole_segment_flow": [ "file_path", "num_nodes", "config"] @@ -64,7 +64,7 @@ class DecisionFlowInputModel(BaseModel): def setup_decision_settings( alcf_recon: bool, nersc_recon: bool, - nersc_recon_multinode: bool, + nersc_petiole_segment: bool, new_file_832: bool ) -> dict: """ @@ -72,7 +72,7 @@ def setup_decision_settings( :param alcf_recon: Boolean indicating whether to run the ALCF reconstruction flow. :param nersc_recon: Boolean indicating whether to run the NERSC reconstruction flow. - :param nersc_recon_multinode: Boolean indicating whether to run the NERSC multinode reconstruction flow. + :param nersc_petiole_segment: Boolean indicating whether to run the NERSC petiole segmentation flow. :param new_file_832: Boolean indicating whether to move files to NERSC. :return: A dictionary containing the settings for each flow. 
""" @@ -80,13 +80,13 @@ def setup_decision_settings( try: logger.info(f"Setting up decision settings: alcf_recon={alcf_recon}, " f"nersc_recon={nersc_recon}, " - f"nersc_recon_multinode={nersc_recon_multinode}, " + f"nersc_petiole_segment={nersc_petiole_segment}, " f"new_file_832={new_file_832}") # Define which flows to run based on the input settings settings = { "alcf_recon_flow/alcf_recon_flow": alcf_recon, "nersc_recon_flow/nersc_recon_flow": nersc_recon, - "nersc_recon_multinode_flow/nersc_recon_multinode_flow": nersc_recon_multinode, + "nersc_petiole_segment_flow/nersc_petiole_segment_flow": nersc_petiole_segment, "new_832_file_flow/new_file_832": new_file_832 } # Save the settings in a JSON block for later retrieval by other flows @@ -164,10 +164,13 @@ async def dispatcher( nersc_params = FlowParameterMapper.get_flow_parameters("nersc_recon_flow/nersc_recon_flow", available_params) tasks.append(run_recon_flow_async("nersc_recon_flow/nersc_recon_flow", nersc_params)) - if decision_settings.get("nersc_recon_multinode_flow/nersc_recon_multinode_flow"): - nersc_multinode_params = FlowParameterMapper.get_flow_parameters( - "nersc_recon_multinode_flow/nersc_recon_multinode_flow", available_params) - tasks.append(run_recon_flow_async("nersc_recon_multinode_flow/nersc_recon_multinode_flow", nersc_multinode_params)) + if decision_settings.get("nersc_petiole_segment_flow/nersc_petiole_segment_flow"): + nersc_petiole_segment_params = FlowParameterMapper.get_flow_parameters( + "nersc_petiole_segment_flow/nersc_petiole_segment_flow", available_params + ) + tasks.append( + run_recon_flow_async("nersc_petiole_segment_flow/nersc_petiole_segment_flow", nersc_petiole_segment_params) + ) # Run ALCF and NERSC flows in parallel, if any if tasks: diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index b47d9cfa..47b2a377 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -1584,239 +1584,9 @@ def 
nersc_recon_flow( return False -@flow(name="nersc_forge_recon_segment_flow", flow_run_name="nersc_recon_seg-{file_path}") -def nersc_forge_recon_segment_flow( - file_path: str, - config: Optional[Config832] = None, - num_nodes: Optional[int] = None, -) -> bool: - """ - Process and transfer a file from bl832 to NERSC and run reconstruction and segmentation. - - :param file_path: The path to the file to be processed. - :param config: Configuration object for the flow. - :return: True if the flow completed successfully, False otherwise. - """ - logger = get_run_logger() - - # STEP 1: Setup Configuration - if config is None: - logger.info("Initializing Config") - config = Config832() - - # Paths - path = Path(file_path) - folder_name = path.parent.name - file_name = path.stem - scratch_path_tiff = f"{folder_name}/rec{file_name}" - scratch_path_segment = f"{folder_name}/seg{file_name}" - - logger.info(f"Starting NERSC reconstruction + segmentation flow for {file_path=}") - logger.info(f"Reconstructed TIFFs will be at: {scratch_path_tiff}") - logger.info(f"Segmented output will be at: {scratch_path_segment}") - - transfer_controller = get_transfer_controller( - transfer_type=CopyMethod.GLOBUS, - config=config - ) - - controller = get_controller( - hpc_type=HPC.NERSC, - config=config - ) - logger.info("NERSC reconstruction controller initialized") - - if num_nodes is None: - num_nodes = config.nersc_recon_num_nodes - logger.info(f"Configured to use {num_nodes} nodes for reconstruction") - - # Track success for pruning decisions - nersc_reconstruction_success = False - nersc_segmentation_success = False - data832_tiff_transfer_success = False - data832_segment_transfer_success = False - - # STEP 2: Run Multinode Reconstruction at NERSC - logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") - recon_result = controller.reconstruct( - file_path=file_path, - num_nodes=num_nodes - ) - - if isinstance(recon_result, dict): - nersc_reconstruction_success = 
recon_result.get('success', False) - timing = recon_result.get('timing') - - if timing: - logger.info("=" * 50) - logger.info("TIMING BREAKDOWN") - logger.info("=" * 50) - logger.info(f" Total job time: {timing.get('total', 'N/A')}s") - logger.info(f" Container pull: {timing.get('container_pull', 'N/A')}s") - logger.info( - f" File copy: {timing.get('file_copy', 'N/A')}s " - f"(skipped: {timing.get('copy_skipped', 'N/A')})" - ) - logger.info(f" Metadata detection: {timing.get('metadata', 'N/A')}s") - logger.info(f" RECONSTRUCTION: {timing.get('reconstruction', 'N/A')}s <-- actual recon time") - logger.info(f" Num slices: {timing.get('num_slices', 'N/A')}") - logger.info("=" * 50) - - # Calculate overhead - if all(k in timing for k in ['total', 'reconstruction']): - overhead = timing['total'] - timing['reconstruction'] - logger.info(f" Overhead: {overhead}s") - logger.info(f" Reconstruction %: {100 * timing['reconstruction'] / timing['total']:.1f}%") - logger.info("=" * 50) - else: - nersc_reconstruction_success = recon_result - - logger.info(f"NERSC reconstruction success: {nersc_reconstruction_success}") - - if not nersc_reconstruction_success: - logger.error("Reconstruction Failed.") - raise ValueError("Reconstruction at NERSC Failed") - else: - logger.info("Reconstruction Successful.") - - # STEP 3: Send reconstructed data (tiff) to data832 - logger.info("Transferring reconstructed TIFFs from NERSC pscratch to data832") - try: - data832_tiff_transfer_success = transfer_controller.copy( - file_path=scratch_path_tiff, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - logger.info(f"Transfer reconstructed TIFF data to data832 success: {data832_tiff_transfer_success}") - except Exception as e: - logger.error(f"Failed to transfer TIFFs to data832: {e}") - data832_tiff_transfer_success = False - - # STEP 4: Run the Segmentation Task at NERSC - logger.info(f"Starting NERSC SAM3 segmentation task for {scratch_path_tiff=}") - 
seg_result = nersc_segmentation_sam3_task( - recon_folder_path=scratch_path_tiff, - config=config - ) - if isinstance(seg_result, dict): - nersc_segmentation_success = seg_result.get('success', False) - timing = seg_result.get('timing') - - if timing: - logger.info("=" * 50) - logger.info("SEGMENTATION TIMING BREAKDOWN") - logger.info("=" * 50) - logger.info(f" Total time: {timing.get('total_time', 'N/A')}") - logger.info(f" Images processed: {timing.get('num_images', 'N/A')}") - logger.info(f" Time per image: {timing.get('time_per_image', 'N/A')}") - logger.info(f" Throughput: {timing.get('throughput', 'N/A')} images/min") - logger.info(f" Exit status: {timing.get('exit_status', 'N/A')}") - logger.info("=" * 50) - else: - nersc_segmentation_success = bool(seg_result) - - if not nersc_segmentation_success: - logger.warning("Segmentation at NERSC Failed") - else: - logger.info("Segmentation at NERSC Successful") - - # STEP 5: Transfer segmented data to data832 - logger.info("Transferring segmented data from NERSC pscratch to data832") - try: - data832_segment_transfer_success = transfer_controller.copy( - file_path=scratch_path_segment, - source=config.nersc832_alsdev_pscratch_scratch, - destination=config.data832_scratch - ) - logger.info(f"Transfer segmented data to data832 success: {data832_segment_transfer_success}") - except Exception as e: - logger.error(f"Failed to transfer segmented data to data832: {e}") - data832_segment_transfer_success = False - - # STEP 6: Schedule Pruning of files - logger.info("Scheduling file pruning tasks.") - prune_controller = get_prune_controller( - prune_type=PruneMethod.GLOBUS, - config=config - ) - - # Prune raw from NERSC pscratch - logger.info("Scheduling pruning of NERSC pscratch raw data.") - try: - prune_controller.prune( - file_path=f"{folder_name}/{path.name}", - source_endpoint=config.nersc832_alsdev_pscratch_raw, - check_endpoint=None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to 
schedule raw data pruning: {e}") - - # Prune TIFFs from NERSC pscratch/scratch - if nersc_reconstruction_success: - logger.info("Scheduling pruning of NERSC pscratch reconstruction data.") - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if data832_tiff_transfer_success else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule reconstruction data pruning: {e}") - - # Prune segmented data from NERSC pscratch/scratch - if nersc_segmentation_success: - logger.info("Scheduling pruning of NERSC pscratch segmentation data.") - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.nersc832_alsdev_pscratch_scratch, - check_endpoint=config.data832_scratch if data832_segment_transfer_success else None, - days_from_now=1.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule segmentation data pruning: {e}") - - # Prune reconstructed TIFFs from data832 scratch (longer retention) - if data832_tiff_transfer_success: - logger.info("Scheduling pruning of data832 scratch reconstruction TIFF data.") - try: - prune_controller.prune( - file_path=scratch_path_tiff, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 tiff pruning: {e}") - - # Prune segmented data from data832 scratch (longer retention) - if data832_segment_transfer_success: - logger.info("Scheduling pruning of data832 scratch segmentation data.") - try: - prune_controller.prune( - file_path=scratch_path_segment, - source_endpoint=config.data832_scratch, - check_endpoint=None, - days_from_now=30.0 - ) - except Exception as e: - logger.warning(f"Failed to schedule data832 segment pruning: {e}") - - # TODO: ingest to scicat - - if nersc_reconstruction_success and nersc_segmentation_success: - logger.info("NERSC 
reconstruction + segmentation flow completed successfully.") - return True - else: - logger.warning(f"Flow completed with issues: recon={nersc_reconstruction_success}, seg={nersc_segmentation_success}") - return False - - -@flow(name="nersc_forge_recon_multisegment_flow", - flow_run_name="nersc_recon_multiseg-{file_path}") -def nersc_forge_recon_multisegment_flow( +@flow(name="nersc_petiole_segment_flow", + flow_run_name="nersc_petiole_seg-{file_path}") +def nersc_petiole_segment_flow( file_path: str, config: Optional[Config832] = None, num_nodes: Optional[int] = None, From 8c25418dafa4b58c521340a20aa604cd092a9cdf Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 11:25:42 -0700 Subject: [PATCH 59/72] if to elif --- orchestration/flows/bl832/nersc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 47b2a377..4ee86b4f 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -119,9 +119,9 @@ def reconstruct( if num_nodes == 8: qos = "debug" - if num_nodes < 8: + elif num_nodes < 8: qos = "realtime" - if num_nodes > 8: + elif num_nodes > 8: qos = "premium" # If using with a reservation: From 1ea50f606e6038d6e080af0fc191e24cd57867f2 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 11:33:39 -0700 Subject: [PATCH 60/72] setting nersc account for slurm based on config settings --- config.yml | 1 + orchestration/flows/bl832/config.py | 4 ++++ orchestration/flows/bl832/nersc.py | 27 ++++++++++----------------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/config.yml b/config.yml index 0a9f560c..df931f26 100644 --- a/config.yml +++ b/config.yml @@ -167,3 +167,4 @@ hpc_submission_settings832: num_nodes: 4 nersc_segmentation_dino: num_nodes: 4 + nersc_account: als diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index c0812af0..81aba613 100644 --- 
a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -13,6 +13,7 @@ def _beam_specific_config(self) -> None: self.apps = transfer.build_apps(self.config) self.tc: TransferClient = transfer.init_transfer_client(self.apps["als_transfer"]) self.flow_client = flows.get_flows_client() + # Globus Endpoints self.spot832 = self.endpoints["spot832"] self.data832 = self.endpoints["data832"] self.data832_raw = self.endpoints["data832_raw"] @@ -26,8 +27,11 @@ def _beam_specific_config(self) -> None: self.nersc832_alsdev_recon_scripts = self.endpoints["nersc832_alsdev_recon_scripts"] self.alcf832_raw = self.endpoints["alcf832_raw"] self.alcf832_scratch = self.endpoints["alcf832_scratch"] + # SciCat self.scicat = self.config["scicat"] + # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_num_nodes = self.config["hpc_submission_settings832"]["nersc_reconstruction"]["num_nodes"] self.nersc_segment_sam3_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"]["num_nodes"] self.nersc_segment_dino_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"]["num_nodes"] + self.nersc_account = self.config["hpc_submission_settings832"]["nersc_account"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 4ee86b4f..45ce278e 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -124,6 +124,8 @@ def reconstruct( elif num_nodes > 8: qos = "premium" + account = self.config.nersc_account + # If using with a reservation: # SBATCH -q regular # SBATCH --reservation=_CAP_MarchModCon_CPU @@ -131,7 +133,7 @@ def reconstruct( # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash #SBATCH -q {qos} -#SBATCH -A als +#SBATCH -A {account} #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} #SBATCH 
--output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -393,10 +395,12 @@ def build_multi_resolution( raw_path = f"raw/{folder_name}/{file_name}.h5" logger.info(f"{raw_path=}") + account = self.config.nersc_account + # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash #SBATCH -q realtime -#SBATCH -A als +#SBATCH -A {account} #SBATCH -C cpu #SBATCH --job-name=tomo_multires_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -495,7 +499,7 @@ def segmentation_sam3( default_confidence = [0.5] default_overlap = 0.25 # assuming this was your original default default_qos = "regular" - default_account = "als" + default_account = self.config.nersc_account default_constraint = "gpu" default_checkpoint = "checkpoint_v6.pt" @@ -532,8 +536,6 @@ def segmentation_sam3( account = seg_options.get("account", default_account) constraint = seg_options.get("constraint", default_constraint) checkpoint = seg_options.get("checkpoint", default_checkpoint) - # batch_size = 16 - # nproc_per_node = 4 finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" @@ -543,15 +545,6 @@ def segmentation_sam3( else: confidence_str = str(confidence) - # prompts = ["Cortex", "Phloem Fibers", "Air-based Pith cells", - # "Water-based Pith cells", "Xylem vessels"] - # prompts_str = " ".join([f'"{p}"' for p in prompts]) - - # if num_nodes <= 4: - # qos = "realtime" - # else: - # qos = "regular" - walltime = "00:59:00" job_name = f"seg_{Path(recon_folder_path).name}" @@ -796,11 +789,11 @@ def segmentation_dino( DINO_DEFAULTS = { "defaults": True, "batch_size": 4, - "num_nodes": 8, + "num_nodes": 4, "nproc_per_node": 4, "qos": "regular", - "account": "amsc006", - "constraint": "gpu&hbm80g", + "account": self.config.nersc_account, # amsc006 + "constraint": "gpu", # "gpu&hbm80g", "walltime": "00:59:00", } try: From 28ce316684317edef923b098b44641fa4b97790f Mon Sep 17 00:00:00 2001 From: David Abramov Date: 
Wed, 18 Mar 2026 11:39:48 -0700 Subject: [PATCH 61/72] Loading cpus-per-task from config for reconstruction slurm submission --- config.yml | 1 + orchestration/flows/bl832/config.py | 2 +- orchestration/flows/bl832/nersc.py | 7 ++++--- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/config.yml b/config.yml index df931f26..9b1e4514 100644 --- a/config.yml +++ b/config.yml @@ -163,6 +163,7 @@ scicat: hpc_submission_settings832: nersc_reconstruction: num_nodes: 4 + cpus-per-task: 128 nersc_segmentation_sam3: num_nodes: 4 nersc_segmentation_dino: diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 81aba613..e9d339b3 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -31,7 +31,7 @@ def _beam_specific_config(self) -> None: self.scicat = self.config["scicat"] # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] - self.nersc_recon_num_nodes = self.config["hpc_submission_settings832"]["nersc_reconstruction"]["num_nodes"] + self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] self.nersc_segment_sam3_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"]["num_nodes"] self.nersc_segment_dino_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"]["num_nodes"] self.nersc_account = self.config["hpc_submission_settings832"]["nersc_account"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 45ce278e..7dd59c21 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -125,6 +125,7 @@ def reconstruct( qos = "premium" account = self.config.nersc_account + cpus_per_task = self.config.nersc_recon_settings.get("cpus-per-task", 128) # If using with a reservation: # SBATCH -q regular @@ -140,7 +141,7 @@ def reconstruct( #SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err #SBATCH -N 
{num_nodes} #SBATCH --ntasks={num_nodes} -#SBATCH --cpus-per-task=128 +#SBATCH --cpus-per-task={cpus_per_task} #SBATCH --time=0:30:00 #SBATCH --exclusive #SBATCH --image={recon_image} @@ -1475,7 +1476,7 @@ def nersc_recon_flow( logger.info("NERSC reconstruction controller initialized") if num_nodes is None: - num_nodes = config.nersc_recon_num_nodes + num_nodes = config.nersc_recon_settings.get("num_nodes", 4) logger.info(f"Configured to use {num_nodes} nodes for reconstruction") logger.info(f"Using multi-node reconstruction with {num_nodes} nodes") @@ -1617,7 +1618,7 @@ def nersc_petiole_segment_flow( logger.info("NERSC controller initialized") if num_nodes is None: - num_nodes = config.nersc_recon_num_nodes + num_nodes = config.nersc_recon_settings.get("num_nodes", 4) logger.info(f"Configured to use {num_nodes} nodes for reconstruction") nersc_reconstruction_success = False From 1e97c5557c2a7a43dab494cf5dcdf4704d1590e3 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 11:57:38 -0700 Subject: [PATCH 62/72] setting sam3 checkpoint/conda/model/vocab paths in config --- config.yml | 7 +++++++ orchestration/flows/bl832/config.py | 4 ++-- orchestration/flows/bl832/nersc.py | 25 +++++++++++-------------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/config.yml b/config.yml index 9b1e4514..9d37c4de 100644 --- a/config.yml +++ b/config.yml @@ -166,6 +166,13 @@ hpc_submission_settings832: cpus-per-task: 128 nersc_segmentation_sam3: num_nodes: 4 + cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 + conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311 + seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo/ + checkpoints_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/ + bpe_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz + 
original_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/sam3.pt + finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt nersc_segmentation_dino: num_nodes: 4 nersc_account: als diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index e9d339b3..3d06850d 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -32,6 +32,6 @@ def _beam_specific_config(self) -> None: # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] - self.nersc_segment_sam3_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"]["num_nodes"] - self.nersc_segment_dino_num_nodes = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"]["num_nodes"] + self.nersc_segment_sam3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"] + self.nersc_segment_dino_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"] self.nersc_account = self.config["hpc_submission_settings832"]["nersc_account"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 7dd59c21..546d26e3 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -474,17 +474,15 @@ def segmentation_sam3( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" - conda_env_path = f"{cfs_path}/envs/sam3-py311" - - # Paths - # seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v4/forge_feb_seg_model_demo/" - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo/" - checkpoints_dir = 
f"{cfs_path}/tomography_segmentation_scripts/sam3_finetune/sam3/" - bpe_path = f"{checkpoints_dir}/bpe_simple_vocab_16e6.txt.gz" - original_checkpoint = f"{checkpoints_dir}/sam3.pt" - # finetuned_checkpoint = f"{checkpoints_dir}/checkpoint_v3.pt" + sam3_settings = self.config.nersc_segment_sam3_settings + cfs_path = sam3_settings["cfs_path"] + conda_env_path = sam3_settings["conda_env_path"] + seg_scripts_dir = sam3_settings["seg_scripts_dir"] + checkpoints_dir = sam3_settings["checkpoints_dir"] + bpe_path = sam3_settings["bpe_path"] + original_checkpoint = sam3_settings["original_checkpoint_path"] + finetuned_checkpoint = sam3_settings["finetuned_checkpoint_path"] input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') @@ -502,7 +500,7 @@ def segmentation_sam3( default_qos = "regular" default_account = self.config.nersc_account default_constraint = "gpu" - default_checkpoint = "checkpoint_v6.pt" + default_checkpoint = finetuned_checkpoint # Load options from Prefect variable try: @@ -526,7 +524,7 @@ def segmentation_sam3( qos = default_qos account = default_account constraint = default_constraint - checkpoint = default_checkpoint + finetuned_checkpoint = default_checkpoint else: logger.info("Using parameters from nersc-segmentation-options variable") batch_size = seg_options.get("batch_size", default_batch_size) @@ -537,8 +535,7 @@ def segmentation_sam3( account = seg_options.get("account", default_account) constraint = seg_options.get("constraint", default_constraint) checkpoint = seg_options.get("checkpoint", default_checkpoint) - - finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" + finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" # Format confidence for command line (handles both single value and list) if isinstance(confidence, list): From 664671a92767f01bf166633e3d3076b0d9fa9a04 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 12:03:03 -0700 Subject: [PATCH 
63/72] adding prompts to config --- config.yml | 5 +++++ orchestration/flows/bl832/nersc.py | 11 ++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/config.yml b/config.yml index 9d37c4de..7c6d1a36 100644 --- a/config.yml +++ b/config.yml @@ -173,6 +173,11 @@ hpc_submission_settings832: bpe_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz original_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/sam3.pt finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt + prompts: + - "Phloem Fibers" + - "Hydrated Xylem vessels" + - "Air-based Pith cells" + - "Dehydrated Xylem vessels" nersc_segmentation_dino: num_nodes: 4 nersc_account: als diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 546d26e3..33d3aefc 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -465,7 +465,7 @@ def build_multi_resolution( def segmentation_sam3( self, recon_folder_path: str = "", - num_nodes: int = 42, + num_nodes: int = 4, ) -> dict: """ Run SAM3 segmentation at NERSC Perlmutter (v6 with overlap + max confidence stitching). 
@@ -484,6 +484,11 @@ def segmentation_sam3( original_checkpoint = sam3_settings["original_checkpoint_path"] finetuned_checkpoint = sam3_settings["finetuned_checkpoint_path"] + prompts = sam3_settings["prompts"] + if not isinstance(prompts, list) or not prompts: + raise ValueError("nersc_segmentation_sam3.prompts must be a non-empty list") + prompts_str = " ".join(f"'{p}'" for p in prompts)\ + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}/sam3" @@ -496,7 +501,7 @@ def segmentation_sam3( default_batch_size = 1 default_patch_size = 400 default_confidence = [0.5] - default_overlap = 0.25 # assuming this was your original default + default_overlap = 0.25 default_qos = "regular" default_account = self.config.nersc_account default_constraint = "gpu" @@ -636,7 +641,7 @@ def segmentation_sam3( --batch-size "${{BATCH_SIZE}}" \ --confidence {confidence_str} \ --overlap-ratio {overlap} \ - --prompts 'Phloem Fibers' 'Hydrated Xylem vessels' 'Air-based Pith cells' 'Dehydrated Xylem vessels' \ + --prompts {prompts_str} \ --bpe-path "${{BPE_PATH}}" \ --original-checkpoint "${{ORIG_CKPT}}" \ --finetuned-checkpoint "${{FT_CKPT}}" From 5f694d282a4ab7841b2c91ffd05871832eb0f5cf Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 18 Mar 2026 14:53:12 -0700 Subject: [PATCH 64/72] Adding dino and combine segmentations settings to config --- config.yml | 18 ++++++++++++ orchestration/flows/bl832/config.py | 1 + orchestration/flows/bl832/nersc.py | 45 +++++++++++++++++------------ 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/config.yml b/config.yml index 7c6d1a36..501d4c5a 100644 --- a/config.yml +++ b/config.yml @@ -166,6 +166,9 @@ hpc_submission_settings832: cpus-per-task: 128 nersc_segmentation_sam3: num_nodes: 4 + ntasks-per-node: 1 + gpus-per-node: 4 + cpus-per-task: 128 cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 
conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311 seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo/ @@ -180,4 +183,19 @@ hpc_submission_settings832: - "Dehydrated Xylem vessels" nersc_segmentation_dino: num_nodes: 4 + ntasks-per-node: 1 + gpus-per-node: 4 + cpus-per-task: 128 + cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 + conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo + seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo/ + dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt + nersc_combine_segmentations: + num_nodes: 4 + ntasks: 1 + cpus-per-task: 128 + cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 + conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo + seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo + nersc_account: als diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 3d06850d..856a0a6c 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -34,4 +34,5 @@ def _beam_specific_config(self) -> None: self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] self.nersc_segment_sam3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"] self.nersc_segment_dino_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"] + self.nersc_combine_segmentation_settings = self.config["hpc_submission_settings832"]["nersc_combine_segmentations"] self.nersc_account = self.config["hpc_submission_settings832"]["nersc_account"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 33d3aefc..b153c951 100644 --- 
a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -484,6 +484,10 @@ def segmentation_sam3( original_checkpoint = sam3_settings["original_checkpoint_path"] finetuned_checkpoint = sam3_settings["finetuned_checkpoint_path"] + ntasks_per_node = sam3_settings["ntasks-per-node"] + gpus_per_node = sam3_settings["gpus-per-node"] + cpus_per_task = sam3_settings["cpus-per-task"] + prompts = sam3_settings["prompts"] if not isinstance(prompts, list) or not prompts: raise ValueError("nersc_segmentation_sam3.prompts must be a non-empty list") @@ -559,9 +563,9 @@ def segmentation_sam3( #SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} #SBATCH --time={walltime} -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --cpus-per-task=128 +#SBATCH --ntasks-per-node={ntasks_per_node} +#SBATCH --gpus-per-node={gpus_per_node} +#SBATCH --cpus-per-task={cpus_per_task} #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err @@ -776,11 +780,16 @@ def segmentation_dino( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" - conda_env_path = f"{cfs_path}/envs/dino_demo" - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo" - dino_checkpoint = f"{cfs_path}/tomography_segmentation_scripts/dino/best.ckpt" + # Load from config + dino_settings = self.config.nersc_segment_dino_settings + cfs_path = dino_settings["cfs_path"] + conda_env_path = dino_settings["conda_env_path"] + seg_scripts_dir = dino_settings["seg_scripts_dir"] + dino_checkpoint = dino_settings["dino_checkpoint_path"] + cpus_per_task = dino_settings["cpus-per-task"] + gpus_per_node = dino_settings["gpus-per-node"] + ntasks_per_node = dino_settings["ntasks-per-node"] input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" seg_folder = recon_folder_path.replace("/rec", "/seg") @@ 
-829,9 +838,9 @@ def segmentation_dino( #SBATCH --reservation=_CAP_MarchModCon_GPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} -#SBATCH --ntasks-per-node=1 -#SBATCH --gpus-per-node=4 -#SBATCH --cpus-per-task=128 +#SBATCH --ntasks-per-node={ntasks_per_node} +#SBATCH --gpus-per-node={gpus_per_node} +#SBATCH --cpus-per-task={cpus_per_task} #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err @@ -962,10 +971,10 @@ def combine_segmentations( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - cfs_path = "/global/cfs/cdirs/als/data_mover/8.3.2" - conda_env_path = f"{cfs_path}/envs/dino_demo" - seg_scripts_dir = f"{cfs_path}/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo" + combine_settings = self.config.nersc_combine_segmentation_settings + conda_env_path = combine_settings["conda_env_path"] + seg_scripts_dir = combine_settings["seg_scripts_dir"] seg_folder = recon_folder_path.replace("/rec", "/seg") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" @@ -980,9 +989,9 @@ def combine_segmentations( COMBINE_DEFAULTS = { "defaults": True, - "num_nodes": 8, + "num_nodes": combine_settings["num_nodes"], "qos": "regular", - "account": "amsc006", + "account": self.config.nersc_account, # "amsc006", "constraint": "cpu", "walltime": "01:00:00", "dilate_px": 5, @@ -1008,16 +1017,16 @@ def combine_segmentations( job_name = f"combine_{Path(recon_folder_path).name}" +# #SBATCH --reservation=_CAP_MarchModCon_CPU job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_MarchModCon_CPU #SBATCH --job-name={job_name} #SBATCH --time={walltime} -#SBATCH --ntasks=1 -#SBATCH --cpus-per-task=128 +#SBATCH --ntasks={combine_settings["ntasks"]} +#SBATCH --cpus-per-task={combine_settings["cpus-per-task"]} #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out #SBATCH 
--error={pscratch_path}/tomo_seg_logs/%x_%j.err From 0e66b2ff63e09e6fa6cfff7b5e0523eaaf75145e Mon Sep 17 00:00:00 2001 From: David Abramov Date: Fri, 20 Mar 2026 13:52:32 -0700 Subject: [PATCH 65/72] Updating pytests --- config.yml | 4 +- orchestration/_tests/test_bl832/test_nersc.py | 164 ++++++++---------- orchestration/_tests/test_globus_flow.py | 4 +- orchestration/_tests/test_sfapi_flow.py | 4 +- orchestration/flows/bl832/nersc.py | 33 +++- 5 files changed, 107 insertions(+), 102 deletions(-) diff --git a/config.yml b/config.yml index 501d4c5a..ef698e08 100644 --- a/config.yml +++ b/config.yml @@ -165,6 +165,7 @@ hpc_submission_settings832: num_nodes: 4 cpus-per-task: 128 nersc_segmentation_sam3: + reservation: "" num_nodes: 4 ntasks-per-node: 1 gpus-per-node: 4 @@ -182,6 +183,7 @@ hpc_submission_settings832: - "Air-based Pith cells" - "Dehydrated Xylem vessels" nersc_segmentation_dino: + reservation: "" num_nodes: 4 ntasks-per-node: 1 gpus-per-node: 4 @@ -191,11 +193,11 @@ hpc_submission_settings832: seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo/ dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt nersc_combine_segmentations: + reservation: "" num_nodes: 4 ntasks: 1 cpus-per-task: 128 cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo - nersc_account: als diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 69400ab4..de3dff8d 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -49,6 +49,7 @@ def mock_sfapi_client(mocker): def mock_config832(mocker): """ Mock Config832 constructor so any call to 
Config832() returns our mock. + Tests that call flows must pass config=None so Prefect's type validation is never given a MagicMock — the flow will call Config832() internally and get our mock back. @@ -72,7 +73,43 @@ def mock_config832(mocker): ep.root_path = f"/mock/{attr}" setattr(mock_config, attr, ep) + mock_config.nersc_account = "mock_account" mock_config.nersc_recon_num_nodes = 4 + mock_config.nersc_recon_settings = { + "cpus-per-task": 128, + "num_nodes": 4, + } + mock_config.nersc_segment_sam3_settings = { + "cfs_path": "/mock/cfs", + "conda_env_path": "/mock/conda/sam3", + "seg_scripts_dir": "/mock/seg_scripts/sam3", + "checkpoints_dir": "/mock/checkpoints", + "bpe_path": "/mock/bpe.model", + "original_checkpoint_path": "/mock/original.pt", + "finetuned_checkpoint_path": "/mock/finetuned.pt", + "ntasks-per-node": 1, + "gpus-per-node": 4, + "cpus-per-task": 32, + "prompts": ["cell wall", "lumen"], + } + mock_config.nersc_segment_dino_settings = { + "cfs_path": "/mock/cfs", + "conda_env_path": "/mock/conda/dino", + "seg_scripts_dir": "/mock/seg_scripts/dino", + "dino_checkpoint_path": "/mock/dino.pt", + "cpus-per-task": 32, + "gpus-per-node": 4, + "ntasks-per-node": 1, + "reservation": "", + } + mock_config.nersc_combine_segmentation_settings = { + "conda_env_path": "/mock/conda/combine", + "seg_scripts_dir": "/mock/seg_scripts/combine", + "num_nodes": 1, + "ntasks": 128, + "cpus-per-task": 1, + "reservation": "", + } mocker.patch("orchestration.flows.bl832.nersc.Config832", return_value=mock_config) return mock_config @@ -158,7 +195,9 @@ def test_reconstruct_success(mocker, mock_sfapi_client, mock_config832): mock_sfapi_client.compute.assert_called_once_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() - assert result is True + assert isinstance(result, dict) + assert result["success"] is True + assert result["job_id"] 
== "12345" def test_reconstruct_submission_failure(mocker, mock_sfapi_client, mock_config832): @@ -342,7 +381,6 @@ def capture(script): controller.segmentation_dino(recon_folder_path="folder/recfile") script = captured_scripts[0] - # The rec→seg substitution turns "recfile" into "segfile" assert "segfile" in script assert "/dino" in script @@ -430,7 +468,9 @@ def test_nersc_segmentation_sam3_task_success(mocker, mock_config832): ) mock_controller.segmentation_sam3.assert_called_once_with(recon_folder_path="folder/recfile") - assert result == {"success": True, "job_id": "99", "timing": None, "output_dir": "/out"} + assert isinstance(result, dict) + assert result["success"] is True + assert result["job_id"] == "99" def test_nersc_segmentation_sam3_task_failure(mocker, mock_config832): @@ -448,6 +488,7 @@ def test_nersc_segmentation_sam3_task_failure(mocker, mock_config832): config=mock_config832 ) + assert isinstance(result, dict) assert result["success"] is False @@ -518,84 +559,24 @@ def test_nersc_combine_segmentations_task_failure(mocker, mock_config832): # ────────────────────────────────────────────────────────────────────────────── -# nersc_forge_recon_segment_flow +# nersc_petiole_segment_flow (recon + SAM3 + DINO + combine) +# +# Replaces the former nersc_forge_recon_multisegment_flow tests. +# The cleaned nersc.py exposes nersc_petiole_segment_flow as the canonical +# multi-segmentation flow; controller.reconstruct() is the correct method name +# (reconstruct_multinode no longer exists). # # Prefect validates the `config` parameter against Optional[Config832] at -# runtime, so passing a MagicMock raises ParameterTypeError. The fix is to -# pass config=None — the flow's `if config is None: config = Config832()` -# branch then runs, calling the already-mocked constructor and returning our -# mock_config832 instance. 
-# ────────────────────────────────────────────────────────────────────────────── - -def test_forge_recon_segment_flow_success(mocker, mock_config832, mock_recon_success, mock_seg_sam3_success): - from orchestration.flows.bl832.nersc import nersc_forge_recon_segment_flow - - mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = mock_recon_success - mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - - mock_transfer = mocker.MagicMock() - mock_transfer.copy.return_value = True - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) - mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) - - mock_seg_task = mocker.patch( - "orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task", - return_value=mock_seg_sam3_success - ) - - result = nersc_forge_recon_segment_flow(file_path="folder/file.h5", num_nodes=4) - - assert result is True - mock_controller.reconstruct_multinode.assert_called_once() - mock_seg_task.assert_called_once() - assert mock_transfer.copy.call_count >= 2 - - -def test_forge_recon_segment_flow_recon_failure(mocker, mock_config832): - from orchestration.flows.bl832.nersc import nersc_forge_recon_segment_flow - - mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = {"success": False, "job_id": None, "timing": None} - mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mocker.MagicMock()) - mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) - - with pytest.raises(ValueError, match="Reconstruction at NERSC Failed"): - nersc_forge_recon_segment_flow(file_path="folder/file.h5", num_nodes=4) - - -def test_forge_recon_segment_flow_seg_failure(mocker, mock_config832, 
mock_recon_success): - """Flow should return False (not raise) when only segmentation fails.""" - from orchestration.flows.bl832.nersc import nersc_forge_recon_segment_flow - - mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = mock_recon_success - mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - - mock_transfer = mocker.MagicMock() - mock_transfer.copy.return_value = True - mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mock_transfer) - mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) - mocker.patch( - "orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task", - return_value={"success": False, "job_id": None, "timing": None, "output_dir": None} - ) - - result = nersc_forge_recon_segment_flow(file_path="folder/file.h5", num_nodes=4) - - assert result is False - - -# ────────────────────────────────────────────────────────────────────────────── -# nersc_forge_recon_multisegment_flow +# runtime, so we pass config=None and let the flow call Config832() internally, +# which returns mock_config832 via the fixture patch. 
# ────────────────────────────────────────────────────────────────────────────── -def test_forge_recon_multisegment_flow_both_succeed(mocker, mock_config832, mock_recon_success): - from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow +def test_petiole_segment_flow_both_succeed(mocker, mock_config832, mock_recon_success): + """Recon + SAM3 + DINO all succeed → combine is called → flow returns True.""" + from orchestration.flows.bl832.nersc import nersc_petiole_segment_flow mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = mock_recon_success + mock_controller.reconstruct.return_value = mock_recon_success mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) mock_transfer = mocker.MagicMock() @@ -608,25 +589,26 @@ def test_forge_recon_multisegment_flow_both_succeed(mocker, mock_config832, mock mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") mock_sam3_task.submit.return_value = _make_future( - mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out"} + mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out/sam3"} ) mock_dino_task.submit.return_value = _make_future(mocker, True) mock_combine_task.submit.return_value = _make_future(mocker, True) - result = nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + result = nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) assert result is True + mock_controller.reconstruct.assert_called_once() mock_sam3_task.submit.assert_called_once() mock_dino_task.submit.assert_called_once() mock_combine_task.submit.assert_called_once() -def test_forge_recon_multisegment_flow_only_sam3_succeeds(mocker, mock_config832, mock_recon_success): +def test_petiole_segment_flow_only_sam3_succeeds(mocker, mock_config832, mock_recon_success): """When only SAM3 succeeds, combine should be skipped 
but flow returns True.""" - from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow + from orchestration.flows.bl832.nersc import nersc_petiole_segment_flow mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = mock_recon_success + mock_controller.reconstruct.return_value = mock_recon_success mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) mock_transfer = mocker.MagicMock() @@ -639,21 +621,22 @@ def test_forge_recon_multisegment_flow_only_sam3_succeeds(mocker, mock_config832 mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") mock_sam3_task.submit.return_value = _make_future( - mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out"} + mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out/sam3"} ) mock_dino_task.submit.return_value = _make_future(mocker, False) - result = nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + result = nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) assert result is True mock_combine_task.submit.assert_not_called() -def test_forge_recon_multisegment_flow_both_seg_fail(mocker, mock_config832, mock_recon_success): - from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow +def test_petiole_segment_flow_both_seg_fail(mocker, mock_config832, mock_recon_success): + """Recon succeeds but both segmentations fail → flow returns False.""" + from orchestration.flows.bl832.nersc import nersc_petiole_segment_flow mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = mock_recon_success + mock_controller.reconstruct.return_value = mock_recon_success mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) mock_transfer = mocker.MagicMock() @@ -670,20 +653,21 @@ def 
test_forge_recon_multisegment_flow_both_seg_fail(mocker, mock_config832, moc ) mock_dino_task.submit.return_value = _make_future(mocker, False) - result = nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + result = nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) assert result is False mock_combine_task.submit.assert_not_called() -def test_forge_recon_multisegment_flow_recon_failure(mocker, mock_config832): - from orchestration.flows.bl832.nersc import nersc_forge_recon_multisegment_flow +def test_petiole_segment_flow_recon_failure(mocker, mock_config832): + """Recon failure should raise ValueError immediately.""" + from orchestration.flows.bl832.nersc import nersc_petiole_segment_flow mock_controller = mocker.MagicMock() - mock_controller.reconstruct_multinode.return_value = {"success": False, "job_id": None, "timing": None} + mock_controller.reconstruct.return_value = {"success": False, "job_id": None, "timing": None} mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) mocker.patch("orchestration.flows.bl832.nersc.get_transfer_controller", return_value=mocker.MagicMock()) mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) with pytest.raises(ValueError, match="Reconstruction at NERSC Failed"): - nersc_forge_recon_multisegment_flow(file_path="folder/file.h5", num_nodes=4) + nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) diff --git a/orchestration/_tests/test_globus_flow.py b/orchestration/_tests/test_globus_flow.py index 4e424bad..c75317e1 100644 --- a/orchestration/_tests/test_globus_flow.py +++ b/orchestration/_tests/test_globus_flow.py @@ -1,5 +1,6 @@ import asyncio # import uuid +from unittest.mock import MagicMock from uuid import UUID, uuid4, uuid5 import warnings @@ -269,7 +270,8 @@ def test_alcf_recon_flow(mocker: MockFixture): ) mock_settings = mocker.MagicMock() - 
mock_settings.__getitem__ = lambda self, key: {"scicat": "mock_scicat", "ghcr_images832": "mock_ghcr"}[key] + _known_settings = {"scicat": "mock_scicat", "ghcr_images832": "mock_ghcr"} + mock_settings.__getitem__ = lambda self, key: _known_settings.get(key, MagicMock()) mocker.patch( "orchestration.config.settings", mock_settings diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index 66203d19..d6fcfb23 100644 --- a/orchestration/_tests/test_sfapi_flow.py +++ b/orchestration/_tests/test_sfapi_flow.py @@ -188,7 +188,9 @@ def test_reconstruct_success(mock_sfapi_client, mock_config832): mock_sfapi_client.compute.return_value.submit_job.return_value.complete.assert_called_once() # Assert that the method returns True - assert result is True, "reconstruct should return True on successful job completion." + assert isinstance(result, dict) + assert result["success"] is True + assert result["job_id"] == "12345" def test_reconstruct_submission_failure(mock_sfapi_client, mock_config832): diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index b153c951..e458faf4 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -491,7 +491,7 @@ def segmentation_sam3( prompts = sam3_settings["prompts"] if not isinstance(prompts, list) or not prompts: raise ValueError("nersc_segmentation_sam3.prompts must be a non-empty list") - prompts_str = " ".join(f"'{p}'" for p in prompts)\ + prompts_str = " ".join(f"'{p}'" for p in prompts) input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') @@ -510,6 +510,7 @@ def segmentation_sam3( default_account = self.config.nersc_account default_constraint = "gpu" default_checkpoint = finetuned_checkpoint + default_reservation = "" # Load options from Prefect variable try: @@ -534,6 +535,7 @@ def segmentation_sam3( account = default_account constraint = default_constraint 
finetuned_checkpoint = default_checkpoint + reservation = default_reservation else: logger.info("Using parameters from nersc-segmentation-options variable") batch_size = seg_options.get("batch_size", default_batch_size) @@ -545,20 +547,21 @@ def segmentation_sam3( constraint = seg_options.get("constraint", default_constraint) checkpoint = seg_options.get("checkpoint", default_checkpoint) finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" + reservation = seg_options.get("reservation", default_reservation) + # #SBATCH --reservation=_CAP_MarchModCon_GPU + reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" # Format confidence for command line (handles both single value and list) if isinstance(confidence, list): confidence_str = " ".join(str(c) for c in confidence) else: confidence_str = str(confidence) - walltime = "00:59:00" job_name = f"seg_{Path(recon_folder_path).name}" - job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} -#SBATCH --reservation=_CAP_MarchModCon_GPU +{reservation_line} #SBATCH -N {num_nodes} #SBATCH -C {constraint} # gpu #SBATCH --job-name={job_name} @@ -790,6 +793,7 @@ def segmentation_dino( cpus_per_task = dino_settings["cpus-per-task"] gpus_per_node = dino_settings["gpus-per-node"] ntasks_per_node = dino_settings["ntasks-per-node"] + reservation = dino_settings.get("reservation", "") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" seg_folder = recon_folder_path.replace("/rec", "/seg") @@ -806,6 +810,7 @@ def segmentation_dino( "qos": "regular", "account": self.config.nersc_account, # amsc006 "constraint": "gpu", # "gpu&hbm80g", + "reservation": reservation, # e.g. 
"_CAP_MarchModCon_GPU" "walltime": "00:59:00", } try: @@ -828,14 +833,16 @@ def segmentation_dino( constraint = opts["constraint"] walltime = opts["walltime"] - job_name = f"dino_{Path(recon_folder_path).name}" + reservation = opts.get("reservation", "") + reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" + job_name = f"dino_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} #SBATCH -N {num_nodes} #SBATCH -C {constraint} -#SBATCH --reservation=_CAP_MarchModCon_GPU +{reservation_line} #SBATCH --job-name={job_name} #SBATCH --time={walltime} #SBATCH --ntasks-per-node={ntasks_per_node} @@ -995,6 +1002,7 @@ def combine_segmentations( "constraint": "cpu", "walltime": "01:00:00", "dilate_px": 5, + "reservation": combine_settings["reservation"] } try: seg_options = Variable.get("nersc-combine-seg-options", default={"defaults": True}, _sync=True) @@ -1014,6 +1022,9 @@ def combine_segmentations( constraint = opts["constraint"] walltime = opts["walltime"] dilate_px = opts["dilate_px"] + reservation = opts["reservation"] + + reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" job_name = f"combine_{Path(recon_folder_path).name}" @@ -1021,6 +1032,7 @@ def combine_segmentations( job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} +{reservation_line} #SBATCH -N {num_nodes} #SBATCH -C {constraint} #SBATCH --job-name={job_name} @@ -1991,10 +2003,13 @@ def nersc_segmentation_sam3_task( nersc_segmentation_success = tomography_controller.segmentation_sam3( recon_folder_path=recon_folder_path, ) - if not nersc_segmentation_success: - logger.error("Segmentation Failed.") + if isinstance(nersc_segmentation_success, dict): + success = nersc_segmentation_success["success"] + logger.info(f"Segmentation success: {success}") else: - logger.info("Segmentation Successful.") + success = bool(nersc_segmentation_success) + if not success: + logger.error("Segmentation Failed.") 
return nersc_segmentation_success From 8726dda21b1a8b5882c12851e7b04fc931564c21 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 23 Mar 2026 15:23:10 -0700 Subject: [PATCH 66/72] Moving the rest of job submission variables to config.yml, created a method for using defaults from config.yml vs prefect variable overrides (if defaults=False), updated controller methods to use _load_job_options() --- config.yml | 56 +++- orchestration/flows/bl832/config.py | 2 +- orchestration/flows/bl832/nersc.py | 488 +++++++++++++++++++--------- 3 files changed, 378 insertions(+), 168 deletions(-) diff --git a/config.yml b/config.yml index ef698e08..3228aad9 100644 --- a/config.yml +++ b/config.yml @@ -162,14 +162,43 @@ scicat: hpc_submission_settings832: nersc_reconstruction: + # ── SLURM resource allocation ───────────────────────────────────────────── + qos: realtime + account: als + reservation: "" num_nodes: 4 cpus-per-task: 128 + walltime: "0:30:00" + nersc_multiresolution: + # ── SLURM resource allocation ───────────────────────────────────────────── + qos: realtime + account: als + reservation: "" + cpus-per-task: 128 + walltime: "0:15:00" nersc_segmentation_sam3: + # ── SLURM resource allocation ───────────────────────────────────────────── + qos: regular + account: als + constraint: gpu reservation: "" num_nodes: 4 ntasks-per-node: 1 gpus-per-node: 4 cpus-per-task: 128 + walltime: "00:59:00" + # ── Inference parameters ────────────────────────────────────────────────── + batch_size: 1 + patch_size: 400 + confidence: + - 0.5 + overlap: 0.25 + prompts: + - "Phloem Fibers" + - "Hydrated Xylem vessels" + - "Air-based Pith cells" + - "Dehydrated Xylem vessels" + # ── Paths ───────────────────────────────────────────────────────────────── cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/sam3-py311 seg_scripts_dir: 
/global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo/ @@ -177,27 +206,38 @@ hpc_submission_settings832: bpe_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz original_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/sam3.pt finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt - prompts: - - "Phloem Fibers" - - "Hydrated Xylem vessels" - - "Air-based Pith cells" - - "Dehydrated Xylem vessels" nersc_segmentation_dino: + # ── SLURM resource allocation ───────────────────────────────────────────── + qos: regular + account: als + constraint: gpu reservation: "" num_nodes: 4 ntasks-per-node: 1 + nproc_per_node: 4 gpus-per-node: 4 cpus-per-task: 128 + walltime: "00:59:00" + # ── Inference parameters ────────────────────────────────────────────────── + batch_size: 4 + # ── Paths ───────────────────────────────────────────────────────────────── cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 conda_env_path: /global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_v5_multiseg/forge_feb_seg_model_demo/ dino_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/dino/best.ckpt nersc_combine_segmentations: + # ── SLURM resource allocation ───────────────────────────────────────────── + qos: regular + account: als + constraint: cpu reservation: "" num_nodes: 4 - ntasks: 1 - cpus-per-task: 128 + ntasks: 128 + cpus-per-task: 1 + walltime: "01:00:00" + # ── Combination parameters ──────────────────────────────────────────────── + dilate_px: 5 + # ── Paths ───────────────────────────────────────────────────────────────── cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 conda_env_path: 
/global/cfs/cdirs/als/data_mover/8.3.2/envs/dino_demo seg_scripts_dir: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/inference_latest/forge_feb_seg_model_demo - nersc_account: als diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index 856a0a6c..b0c9a40f 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -32,7 +32,7 @@ def _beam_specific_config(self) -> None: # NERSC HPC submission settings self.ghcr_images832 = self.config["ghcr_images832"] self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] + self.nersc_multiresolution_settings = self.config["hpc_submission_settings832"]["nersc_multiresolution"] self.nersc_segment_sam3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"] self.nersc_segment_dino_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"] self.nersc_combine_segmentation_settings = self.config["hpc_submission_settings832"]["nersc_combine_segmentations"] - self.nersc_account = self.config["hpc_submission_settings832"]["nersc_account"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index e458faf4..14207007 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -12,7 +12,7 @@ from prefect.variables import Variable from sfapi_client import Client from sfapi_client.compute import Machine -from typing import Optional +from typing import Any, Optional from orchestration.flows.bl832.config import Config832 from orchestration.flows.bl832.job_controller import get_controller, HPC, TomographyHPCController @@ -28,6 +28,42 @@ load_dotenv() +def _load_job_options(variable_name: str, config_settings: dict[str, Any]) -> dict[str, Any]: + """ + Load job options, using config as defaults and a Prefect Variable as overrides. + + Resolution order: + + 1. Load the named Prefect Variable. + 2. 
If absent, malformed, or ``defaults: true`` → return ``config_settings`` unchanged. + 3. If ``defaults: false`` → return ``config_settings`` with variable values overlaid. + + The config YAML is the authoritative source for all default values. The Prefect + Variable only needs to contain the keys it wishes to override, and may introduce + keys not present in config (e.g. a bare ``checkpoint`` filename for SAM3). + + :param variable_name: Name of the Prefect Variable to load. + :param config_settings: Settings dict read directly from the Config832 object + (e.g. ``config.nersc_recon_settings``). Used as-is when defaults=True. + :return: Resolved options dict ready for use by the caller. + """ + try: + options = Variable.get(variable_name, default={"defaults": True}, _sync=True) + if isinstance(options, str): + options = json.loads(options) + except Exception as e: + logger.warning(f"Could not load '{variable_name}': {e}. Using config defaults.") + return dict(config_settings) + + if options.get("defaults", True): + logger.info(f"Using config defaults for '{variable_name}'") + return dict(config_settings) + + logger.info(f"Overriding config defaults with variable options for '{variable_name}'") + overrides = {k: v for k, v in options.items() if k != "defaults"} + return {**config_settings, **overrides} + + class NERSCTomographyHPCController(TomographyHPCController, NerscStreamingMixin): """ Implementation for a NERSC-based tomography HPC controller. 
@@ -117,15 +153,27 @@ def reconstruct( logger.info(f"Folder name: {folder_name}") logger.info(f"Number of nodes: {num_nodes}") - if num_nodes == 8: - qos = "debug" - elif num_nodes < 8: - qos = "realtime" - elif num_nodes > 8: - qos = "premium" + # if num_nodes == 8: + # qos = "debug" + # elif num_nodes < 8: + # qos = "realtime" + # elif num_nodes > 8: + # qos = "premium" + + # account = self.config.nersc_account + # cpus_per_task = self.config.nersc_recon_settings.get("cpus-per-task", 128) + + opts = _load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) + + num_nodes = opts.get("num_nodes", num_nodes) + cpus_per_task = opts["cpus-per-task"] + qos = opts["qos"] + account = opts["account"] + reservation = opts.get("reservation", "") + walltime = opts.get("walltime", "0:30:00") + + reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" - account = self.config.nersc_account - cpus_per_task = self.config.nersc_recon_settings.get("cpus-per-task", 128) # If using with a reservation: # SBATCH -q regular @@ -135,6 +183,7 @@ def reconstruct( job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} +{reservation_line} #SBATCH -C cpu #SBATCH --job-name=tomo_recon_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out @@ -142,7 +191,7 @@ def reconstruct( #SBATCH -N {num_nodes} #SBATCH --ntasks={num_nodes} #SBATCH --cpus-per-task={cpus_per_task} -#SBATCH --time=0:30:00 +#SBATCH --time={walltime} #SBATCH --exclusive #SBATCH --image={recon_image} @@ -396,20 +445,33 @@ def build_multi_resolution( raw_path = f"raw/{folder_name}/{file_name}.h5" logger.info(f"{raw_path=}") - account = self.config.nersc_account + # account = self.config.nersc_account + + opts = _load_job_options( + "nersc-multiresolution-options", self.config.nersc_multiresolution_settings + ) + + qos = opts["qos"] + account = opts["account"] + cpus_per_task = opts["cpus-per-task"] + reservation = opts.get("reservation", 
"") + walltime = opts.get("walltime", "0:15:00") + + reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" # IMPORTANT: job script must be deindented to the leftmost column or it will fail immediately job_script = f"""#!/bin/bash -#SBATCH -q realtime +#SBATCH -q {qos} #SBATCH -A {account} +{reservation_line} #SBATCH -C cpu #SBATCH --job-name=tomo_multires_{folder_name}_{file_name} #SBATCH --output={pscratch_path}/tomo_recon_logs/%x_%j.out #SBATCH --error={pscratch_path}/tomo_recon_logs/%x_%j.err #SBATCH -N 1 #SBATCH --ntasks-per-node 1 -#SBATCH --cpus-per-task 128 -#SBATCH --time=0:15:00 +#SBATCH --cpus-per-task {cpus_per_task} +#SBATCH --time={walltime} #SBATCH --exclusive date @@ -475,24 +537,42 @@ def segmentation_sam3( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - sam3_settings = self.config.nersc_segment_sam3_settings - cfs_path = sam3_settings["cfs_path"] - conda_env_path = sam3_settings["conda_env_path"] - seg_scripts_dir = sam3_settings["seg_scripts_dir"] - checkpoints_dir = sam3_settings["checkpoints_dir"] - bpe_path = sam3_settings["bpe_path"] - original_checkpoint = sam3_settings["original_checkpoint_path"] - finetuned_checkpoint = sam3_settings["finetuned_checkpoint_path"] - - ntasks_per_node = sam3_settings["ntasks-per-node"] - gpus_per_node = sam3_settings["gpus-per-node"] - cpus_per_task = sam3_settings["cpus-per-task"] + opts = _load_job_options("nersc-segmentation-options", self.config.nersc_segment_sam3_settings) + + cfs_path = opts["cfs_path"] + conda_env_path = opts["conda_env_path"] + seg_scripts_dir = opts["seg_scripts_dir"] + checkpoints_dir = opts["checkpoints_dir"] + bpe_path = opts["bpe_path"] + original_checkpoint = opts["original_checkpoint_path"] + ntasks_per_node = opts["ntasks-per-node"] + gpus_per_node = opts["gpus-per-node"] + cpus_per_task = opts["cpus-per-task"] + num_nodes = opts.get("num_nodes", num_nodes) + batch_size = opts["batch_size"] + patch_size = 
opts["patch_size"] + confidence = opts["confidence"] + overlap = opts["overlap"] + qos = opts["qos"] + account = opts["account"] + constraint = opts["constraint"] + walltime = opts.get("walltime", "00:59:00") + reservation = opts.get("reservation", "") - prompts = sam3_settings["prompts"] + prompts = opts["prompts"] if not isinstance(prompts, list) or not prompts: raise ValueError("nersc_segmentation_sam3.prompts must be a non-empty list") prompts_str = " ".join(f"'{p}'" for p in prompts) + # "checkpoint" in the Prefect Variable is a bare filename that overrides + # the config's finetuned_checkpoint_path. Config supplies the full path + # as the default, so path construction is only needed when the variable + # explicitly provides a different checkpoint filename. + if "checkpoint" in opts and opts["checkpoint"] != Path(opts["finetuned_checkpoint_path"]).name: + finetuned_checkpoint = f"{checkpoints_dir}/{opts['checkpoint']}" + else: + finetuned_checkpoint = opts["finetuned_checkpoint_path"] + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" output_folder = recon_folder_path.replace('/rec', '/seg') output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}/sam3" @@ -501,63 +581,96 @@ def segmentation_sam3( logger.info(f"Output directory: {output_dir}") logger.info(f"Conda environment: {conda_env_path}") - # Default values (used when defaults=True or variable not found) - default_batch_size = 1 - default_patch_size = 400 - default_confidence = [0.5] - default_overlap = 0.25 - default_qos = "regular" - default_account = self.config.nersc_account - default_constraint = "gpu" - default_checkpoint = finetuned_checkpoint - default_reservation = "" - - # Load options from Prefect variable - try: - seg_options = Variable.get("nersc-segmentation-options", default={}) - if isinstance(seg_options, str): - import json - seg_options = json.loads(seg_options) - except Exception as e: - logger.warning(f"Could not load nersc-segmentation-options variable: {e}. 
Using defaults.") - seg_options = {"defaults": True} - - # Determine which values to use - use_defaults = seg_options.get("defaults", True) - - if use_defaults: - logger.info("Using hardcoded default segmentation parameters") - batch_size = default_batch_size - patch_size = default_patch_size - confidence = default_confidence - overlap = default_overlap - qos = default_qos - account = default_account - constraint = default_constraint - finetuned_checkpoint = default_checkpoint - reservation = default_reservation - else: - logger.info("Using parameters from nersc-segmentation-options variable") - batch_size = seg_options.get("batch_size", default_batch_size) - patch_size = seg_options.get("patch_size", default_patch_size) - confidence = seg_options.get("confidence", default_confidence) - overlap = seg_options.get("overlap", default_overlap) - qos = seg_options.get("qos", default_qos) - account = seg_options.get("account", default_account) - constraint = seg_options.get("constraint", default_constraint) - checkpoint = seg_options.get("checkpoint", default_checkpoint) - finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" - reservation = seg_options.get("reservation", default_reservation) - - # #SBATCH --reservation=_CAP_MarchModCon_GPU + confidence_str = ( + " ".join(str(c) for c in confidence) + if isinstance(confidence, list) else str(confidence) + ) reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" - # Format confidence for command line (handles both single value and list) - if isinstance(confidence, list): - confidence_str = " ".join(str(c) for c in confidence) - else: - confidence_str = str(confidence) - walltime = "00:59:00" job_name = f"seg_{Path(recon_folder_path).name}" + + # sam3_settings = self.config.nersc_segment_sam3_settings + # cfs_path = sam3_settings["cfs_path"] + # conda_env_path = sam3_settings["conda_env_path"] + # seg_scripts_dir = sam3_settings["seg_scripts_dir"] + # checkpoints_dir = 
sam3_settings["checkpoints_dir"] + # bpe_path = sam3_settings["bpe_path"] + # original_checkpoint = sam3_settings["original_checkpoint_path"] + # finetuned_checkpoint = sam3_settings["finetuned_checkpoint_path"] + + # ntasks_per_node = sam3_settings["ntasks-per-node"] + # gpus_per_node = sam3_settings["gpus-per-node"] + # cpus_per_task = sam3_settings["cpus-per-task"] + + # prompts = sam3_settings["prompts"] + # if not isinstance(prompts, list) or not prompts: + # raise ValueError("nersc_segmentation_sam3.prompts must be a non-empty list") + # prompts_str = " ".join(f"'{p}'" for p in prompts) + + # input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + # output_folder = recon_folder_path.replace('/rec', '/seg') + # output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}/sam3" + + # logger.info(f"Input directory: {input_dir}") + # logger.info(f"Output directory: {output_dir}") + # logger.info(f"Conda environment: {conda_env_path}") + + # # Default values (used when defaults=True or variable not found) + # default_batch_size = 1 + # default_patch_size = 400 + # default_confidence = [0.5] + # default_overlap = 0.25 + # default_qos = "regular" + # default_account = self.config.nersc_account + # default_constraint = "gpu" + # default_checkpoint = finetuned_checkpoint + # default_reservation = "" + + # # Load options from Prefect variable + # try: + # seg_options = Variable.get("nersc-segmentation-options", default={}) + # if isinstance(seg_options, str): + # import json + # seg_options = json.loads(seg_options) + # except Exception as e: + # logger.warning(f"Could not load nersc-segmentation-options variable: {e}. 
Using defaults.") + # seg_options = {"defaults": True} + + # # Determine which values to use + # use_defaults = seg_options.get("defaults", True) + + # if use_defaults: + # logger.info("Using hardcoded default segmentation parameters") + # batch_size = default_batch_size + # patch_size = default_patch_size + # confidence = default_confidence + # overlap = default_overlap + # qos = default_qos + # account = default_account + # constraint = default_constraint + # finetuned_checkpoint = default_checkpoint + # reservation = default_reservation + # else: + # logger.info("Using parameters from nersc-segmentation-options variable") + # batch_size = seg_options.get("batch_size", default_batch_size) + # patch_size = seg_options.get("patch_size", default_patch_size) + # confidence = seg_options.get("confidence", default_confidence) + # overlap = seg_options.get("overlap", default_overlap) + # qos = seg_options.get("qos", default_qos) + # account = seg_options.get("account", default_account) + # constraint = seg_options.get("constraint", default_constraint) + # checkpoint = seg_options.get("checkpoint", default_checkpoint) + # finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" + # reservation = seg_options.get("reservation", default_reservation) + + # # #SBATCH --reservation=_CAP_MarchModCon_GPU + # reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" + # # Format confidence for command line (handles both single value and list) + # if isinstance(confidence, list): + # confidence_str = " ".join(str(c) for c in confidence) + # else: + # confidence_str = str(confidence) + # walltime = "00:59:00" + # job_name = f"seg_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} @@ -785,58 +898,87 @@ def segmentation_dino( pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" # Load from config - dino_settings = self.config.nersc_segment_dino_settings - cfs_path = dino_settings["cfs_path"] - conda_env_path 
= dino_settings["conda_env_path"] - seg_scripts_dir = dino_settings["seg_scripts_dir"] - dino_checkpoint = dino_settings["dino_checkpoint_path"] - cpus_per_task = dino_settings["cpus-per-task"] - gpus_per_node = dino_settings["gpus-per-node"] - ntasks_per_node = dino_settings["ntasks-per-node"] - reservation = dino_settings.get("reservation", "") - - input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" - seg_folder = recon_folder_path.replace("/rec", "/seg") - output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/dino" - - logger.info(f"DINO input dir: {input_dir}") - logger.info(f"DINO output dir: {output_dir}") - - DINO_DEFAULTS = { - "defaults": True, - "batch_size": 4, - "num_nodes": 4, - "nproc_per_node": 4, - "qos": "regular", - "account": self.config.nersc_account, # amsc006 - "constraint": "gpu", # "gpu&hbm80g", - "reservation": reservation, # e.g. "_CAP_MarchModCon_GPU" - "walltime": "00:59:00", - } - try: - seg_options = Variable.get("nersc-dino-seg-options", default={"defaults": True}, _sync=True) - if isinstance(seg_options, str): - import json - seg_options = json.loads(seg_options) - except Exception as e: - logger.warning(f"Could not load nersc-dino-seg-options: {e}. 
Using defaults.") - seg_options = {"defaults": True} - use_defaults = seg_options.get("defaults", True) - opts = DINO_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in DINO_DEFAULTS.items()} + opts = _load_job_options("nersc-dino-seg-options", self.config.nersc_segment_dino_settings) + cfs_path = opts["cfs_path"] + conda_env_path = opts["conda_env_path"] + seg_scripts_dir = opts["seg_scripts_dir"] + dino_checkpoint = opts["dino_checkpoint_path"] + ntasks_per_node = opts["ntasks-per-node"] + gpus_per_node = opts["gpus-per-node"] + cpus_per_task = opts["cpus-per-task"] batch_size = opts["batch_size"] num_nodes = opts["num_nodes"] nproc_per_node = opts["nproc_per_node"] qos = opts["qos"] account = opts["account"] constraint = opts["constraint"] - walltime = opts["walltime"] - + walltime = opts.get("walltime", "00:59:00") reservation = opts.get("reservation", "") - reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" + input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + seg_folder = recon_folder_path.replace("/rec", "/seg") + output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/dino" + + logger.info(f"DINO input dir: {input_dir}") + logger.info(f"DINO output dir: {output_dir}") + + reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" job_name = f"dino_{Path(recon_folder_path).name}" + + # dino_settings = self.config.nersc_segment_dino_settings + # cfs_path = dino_settings["cfs_path"] + # conda_env_path = dino_settings["conda_env_path"] + # seg_scripts_dir = dino_settings["seg_scripts_dir"] + # dino_checkpoint = dino_settings["dino_checkpoint_path"] + # cpus_per_task = dino_settings["cpus-per-task"] + # gpus_per_node = dino_settings["gpus-per-node"] + # ntasks_per_node = dino_settings["ntasks-per-node"] + # reservation = dino_settings.get("reservation", "") + + # input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + # seg_folder = recon_folder_path.replace("/rec", 
"/seg") + # output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/dino" + + # logger.info(f"DINO input dir: {input_dir}") + # logger.info(f"DINO output dir: {output_dir}") + + # DINO_DEFAULTS = { + # "defaults": True, + # "batch_size": 4, + # "num_nodes": 4, + # "nproc_per_node": 4, + # "qos": "regular", + # "account": self.config.nersc_account, # amsc006 + # "constraint": "gpu", # "gpu&hbm80g", + # "reservation": reservation, # e.g. "_CAP_MarchModCon_GPU" + # "walltime": "00:59:00", + # } + # try: + # seg_options = Variable.get("nersc-dino-seg-options", default={"defaults": True}, _sync=True) + # if isinstance(seg_options, str): + # import json + # seg_options = json.loads(seg_options) + # except Exception as e: + # logger.warning(f"Could not load nersc-dino-seg-options: {e}. Using defaults.") + # seg_options = {"defaults": True} + + # use_defaults = seg_options.get("defaults", True) + # opts = DINO_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in DINO_DEFAULTS.items()} + + # batch_size = opts["batch_size"] + # num_nodes = opts["num_nodes"] + # nproc_per_node = opts["nproc_per_node"] + # qos = opts["qos"] + # account = opts["account"] + # constraint = opts["constraint"] + # walltime = opts["walltime"] + + # reservation = opts.get("reservation", "") + # reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" + + # job_name = f"dino_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} @@ -979,9 +1121,19 @@ def combine_segmentations( user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" - combine_settings = self.config.nersc_combine_segmentation_settings - conda_env_path = combine_settings["conda_env_path"] - seg_scripts_dir = combine_settings["seg_scripts_dir"] + opts = _load_job_options( + "nersc-combine-seg-options", self.config.nersc_combine_segmentation_settings + ) + + conda_env_path = opts["conda_env_path"] + seg_scripts_dir = 
opts["seg_scripts_dir"] + num_nodes = opts["num_nodes"] + qos = opts["qos"] + account = opts["account"] + constraint = opts["constraint"] + walltime = opts.get("walltime", "01:00:00") + dilate_px = opts["dilate_px"] + reservation = opts.get("reservation", "") seg_folder = recon_folder_path.replace("/rec", "/seg") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" @@ -994,40 +1146,58 @@ def combine_segmentations( logger.info(f"Combine input dir: {input_dir}") logger.info(f"Combine output dir: {combined_output}") - COMBINE_DEFAULTS = { - "defaults": True, - "num_nodes": combine_settings["num_nodes"], - "qos": "regular", - "account": self.config.nersc_account, # "amsc006", - "constraint": "cpu", - "walltime": "01:00:00", - "dilate_px": 5, - "reservation": combine_settings["reservation"] - } - try: - seg_options = Variable.get("nersc-combine-seg-options", default={"defaults": True}, _sync=True) - if isinstance(seg_options, str): - import json - seg_options = json.loads(seg_options) - except Exception as e: - logger.warning(f"Could not load nersc-combine-seg-options: {e}. 
Using defaults.") - seg_options = {"defaults": True} - - use_defaults = seg_options.get("defaults", True) - opts = COMBINE_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in COMBINE_DEFAULTS.items()} - - num_nodes = opts["num_nodes"] - qos = opts["qos"] - account = opts["account"] - constraint = opts["constraint"] - walltime = opts["walltime"] - dilate_px = opts["dilate_px"] - reservation = opts["reservation"] - reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" - job_name = f"combine_{Path(recon_folder_path).name}" + # combine_settings = self.config.nersc_combine_segmentation_settings + # conda_env_path = combine_settings["conda_env_path"] + # seg_scripts_dir = combine_settings["seg_scripts_dir"] + + # seg_folder = recon_folder_path.replace("/rec", "/seg") + # input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" + # seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" + + # sam3_results = f"{seg_base}/sam3" + # dino_results = f"{seg_base}/dino" + # combined_output = f"{seg_base}/combined" + + # logger.info(f"Combine input dir: {input_dir}") + # logger.info(f"Combine output dir: {combined_output}") + + # COMBINE_DEFAULTS = { + # "defaults": True, + # "num_nodes": combine_settings["num_nodes"], + # "qos": "regular", + # "account": self.config.nersc_account, # "amsc006", + # "constraint": "cpu", + # "walltime": "01:00:00", + # "dilate_px": 5, + # "reservation": combine_settings["reservation"] + # } + # try: + # seg_options = Variable.get("nersc-combine-seg-options", default={"defaults": True}, _sync=True) + # if isinstance(seg_options, str): + # import json + # seg_options = json.loads(seg_options) + # except Exception as e: + # logger.warning(f"Could not load nersc-combine-seg-options: {e}. 
Using defaults.") + # seg_options = {"defaults": True} + + # use_defaults = seg_options.get("defaults", True) + # opts = COMBINE_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in COMBINE_DEFAULTS.items()} + + # num_nodes = opts["num_nodes"] + # qos = opts["qos"] + # account = opts["account"] + # constraint = opts["constraint"] + # walltime = opts["walltime"] + # dilate_px = opts["dilate_px"] + # reservation = opts["reservation"] + + # reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" + + # job_name = f"combine_{Path(recon_folder_path).name}" + # #SBATCH --reservation=_CAP_MarchModCon_CPU job_script = f"""#!/bin/bash #SBATCH -q {qos} @@ -1037,8 +1207,8 @@ def combine_segmentations( #SBATCH -C {constraint} #SBATCH --job-name={job_name} #SBATCH --time={walltime} -#SBATCH --ntasks={combine_settings["ntasks"]} -#SBATCH --cpus-per-task={combine_settings["cpus-per-task"]} +#SBATCH --ntasks={opts["ntasks"]} +#SBATCH --cpus-per-task={opts["cpus-per-task"]} #SBATCH --output={pscratch_path}/tomo_seg_logs/%x_%j.out #SBATCH --error={pscratch_path}/tomo_seg_logs/%x_%j.err From b0844b363a7634f93e47e66428b5b8994f1c52db Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 23 Mar 2026 15:23:30 -0700 Subject: [PATCH 67/72] updating pytests --- orchestration/_tests/test_bl832/test_nersc.py | 63 ++++++++++++++----- orchestration/_tests/test_sfapi_flow.py | 33 +++++++++- 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index de3dff8d..597a57de 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -53,6 +53,10 @@ def mock_config832(mocker): Tests that call flows must pass config=None so Prefect's type validation is never given a MagicMock — the flow will call Config832() internally and get our mock back. 
+ + All settings dicts must be fully populated to match the config YAML schema, + because _load_job_options() passes config_settings directly as the defaults + dict and then accesses keys by name. """ mock_config = mocker.MagicMock() @@ -73,42 +77,73 @@ def mock_config832(mocker): ep.root_path = f"/mock/{attr}" setattr(mock_config, attr, ep) - mock_config.nersc_account = "mock_account" - mock_config.nersc_recon_num_nodes = 4 mock_config.nersc_recon_settings = { - "cpus-per-task": 128, + "qos": "realtime", + "account": "mock_account", + "reservation": "", "num_nodes": 4, + "cpus-per-task": 128, + "walltime": "0:30:00", + } + mock_config.nersc_multiresolution_settings = { + "qos": "realtime", + "account": "mock_account", + "reservation": "", + "cpus-per-task": 128, + "walltime": "0:15:00", } mock_config.nersc_segment_sam3_settings = { + "qos": "regular", + "account": "mock_account", + "constraint": "gpu", + "reservation": "", + "num_nodes": 4, + "ntasks-per-node": 1, + "gpus-per-node": 4, + "cpus-per-task": 32, + "walltime": "00:59:00", + "batch_size": 1, + "patch_size": 400, + "confidence": [0.5], + "overlap": 0.25, + "prompts": ["cell wall", "lumen"], "cfs_path": "/mock/cfs", "conda_env_path": "/mock/conda/sam3", "seg_scripts_dir": "/mock/seg_scripts/sam3", "checkpoints_dir": "/mock/checkpoints", "bpe_path": "/mock/bpe.model", "original_checkpoint_path": "/mock/original.pt", - "finetuned_checkpoint_path": "/mock/finetuned.pt", + "finetuned_checkpoint_path": "/mock/checkpoints/finetuned.pt", + } + mock_config.nersc_segment_dino_settings = { + "qos": "regular", + "account": "mock_account", + "constraint": "gpu", + "reservation": "", + "num_nodes": 4, "ntasks-per-node": 1, + "nproc_per_node": 4, "gpus-per-node": 4, "cpus-per-task": 32, - "prompts": ["cell wall", "lumen"], - } - mock_config.nersc_segment_dino_settings = { + "walltime": "00:59:00", + "batch_size": 4, "cfs_path": "/mock/cfs", "conda_env_path": "/mock/conda/dino", "seg_scripts_dir": 
"/mock/seg_scripts/dino", "dino_checkpoint_path": "/mock/dino.pt", - "cpus-per-task": 32, - "gpus-per-node": 4, - "ntasks-per-node": 1, - "reservation": "", } mock_config.nersc_combine_segmentation_settings = { - "conda_env_path": "/mock/conda/combine", - "seg_scripts_dir": "/mock/seg_scripts/combine", + "qos": "regular", + "account": "mock_account", + "constraint": "cpu", + "reservation": "", "num_nodes": 1, "ntasks": 128, "cpus-per-task": 1, - "reservation": "", + "walltime": "01:00:00", + "dilate_px": 5, + "conda_env_path": "/mock/conda/combine", + "seg_scripts_dir": "/mock/seg_scripts/combine", } mocker.patch("orchestration.flows.bl832.nersc.Config832", return_value=mock_config) diff --git a/orchestration/_tests/test_sfapi_flow.py b/orchestration/_tests/test_sfapi_flow.py index d6fcfb23..6e9bf225 100644 --- a/orchestration/_tests/test_sfapi_flow.py +++ b/orchestration/_tests/test_sfapi_flow.py @@ -150,13 +150,32 @@ def mock_sfapi_client(): def mock_config832(): """ Mock the Config832 class to provide necessary configurations. + + All settings dicts must be fully populated to match the config YAML schema, + because _load_job_options() passes config_settings directly as the defaults + dict and then accesses keys by name. 
""" with patch("orchestration.flows.bl832.nersc.Config832") as MockConfig: mock_config = MockConfig.return_value - mock_config.harbor_images832 = { + mock_config.ghcr_images832 = { "recon_image": "mock_recon_image", "multires_image": "mock_multires_image", } + mock_config.nersc_recon_settings = { + "qos": "realtime", + "account": "mock_account", + "reservation": "", + "num_nodes": 4, + "cpus-per-task": 128, + "walltime": "0:30:00", + } + mock_config.nersc_multiresolution_settings = { + "qos": "realtime", + "account": "mock_account", + "reservation": "", + "cpus-per-task": 128, + "walltime": "0:15:00", + } mock_config.apps = {"als_transfer": "some_config"} yield mock_config @@ -264,7 +283,17 @@ def test_job_submission(mock_sfapi_client): from orchestration.flows.bl832.nersc import NERSCTomographyHPCController from sfapi_client.compute import Machine - controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=MagicMock()) + mock_config = MagicMock() + mock_config.nersc_recon_settings = { + "qos": "realtime", + "account": "mock_account", + "reservation": "", + "num_nodes": 4, + "cpus-per-task": 128, + "walltime": "0:30:00", + } + + controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config) file_path = "path/to/file.h5" # Mock Path to extract file and folder names From e5be7d5c8721a8bd92549172fc8f1e3b417b7c1b Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 23 Mar 2026 15:24:44 -0700 Subject: [PATCH 68/72] removing commented code --- orchestration/flows/bl832/nersc.py | 195 ----------------------------- 1 file changed, 195 deletions(-) diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 14207007..57b1d84d 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -153,16 +153,6 @@ def reconstruct( logger.info(f"Folder name: {folder_name}") logger.info(f"Number of nodes: {num_nodes}") - # if num_nodes == 8: - # qos = "debug" - # elif num_nodes < 
8: - # qos = "realtime" - # elif num_nodes > 8: - # qos = "premium" - - # account = self.config.nersc_account - # cpus_per_task = self.config.nersc_recon_settings.get("cpus-per-task", 128) - opts = _load_job_options("nersc-reconstruction-options", self.config.nersc_recon_settings) num_nodes = opts.get("num_nodes", num_nodes) @@ -588,89 +578,6 @@ def segmentation_sam3( reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" job_name = f"seg_{Path(recon_folder_path).name}" - # sam3_settings = self.config.nersc_segment_sam3_settings - # cfs_path = sam3_settings["cfs_path"] - # conda_env_path = sam3_settings["conda_env_path"] - # seg_scripts_dir = sam3_settings["seg_scripts_dir"] - # checkpoints_dir = sam3_settings["checkpoints_dir"] - # bpe_path = sam3_settings["bpe_path"] - # original_checkpoint = sam3_settings["original_checkpoint_path"] - # finetuned_checkpoint = sam3_settings["finetuned_checkpoint_path"] - - # ntasks_per_node = sam3_settings["ntasks-per-node"] - # gpus_per_node = sam3_settings["gpus-per-node"] - # cpus_per_task = sam3_settings["cpus-per-task"] - - # prompts = sam3_settings["prompts"] - # if not isinstance(prompts, list) or not prompts: - # raise ValueError("nersc_segmentation_sam3.prompts must be a non-empty list") - # prompts_str = " ".join(f"'{p}'" for p in prompts) - - # input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" - # output_folder = recon_folder_path.replace('/rec', '/seg') - # output_dir = f"{pscratch_path}/8.3.2/scratch/{output_folder}/sam3" - - # logger.info(f"Input directory: {input_dir}") - # logger.info(f"Output directory: {output_dir}") - # logger.info(f"Conda environment: {conda_env_path}") - - # # Default values (used when defaults=True or variable not found) - # default_batch_size = 1 - # default_patch_size = 400 - # default_confidence = [0.5] - # default_overlap = 0.25 - # default_qos = "regular" - # default_account = self.config.nersc_account - # default_constraint = "gpu" - # 
default_checkpoint = finetuned_checkpoint - # default_reservation = "" - - # # Load options from Prefect variable - # try: - # seg_options = Variable.get("nersc-segmentation-options", default={}) - # if isinstance(seg_options, str): - # import json - # seg_options = json.loads(seg_options) - # except Exception as e: - # logger.warning(f"Could not load nersc-segmentation-options variable: {e}. Using defaults.") - # seg_options = {"defaults": True} - - # # Determine which values to use - # use_defaults = seg_options.get("defaults", True) - - # if use_defaults: - # logger.info("Using hardcoded default segmentation parameters") - # batch_size = default_batch_size - # patch_size = default_patch_size - # confidence = default_confidence - # overlap = default_overlap - # qos = default_qos - # account = default_account - # constraint = default_constraint - # finetuned_checkpoint = default_checkpoint - # reservation = default_reservation - # else: - # logger.info("Using parameters from nersc-segmentation-options variable") - # batch_size = seg_options.get("batch_size", default_batch_size) - # patch_size = seg_options.get("patch_size", default_patch_size) - # confidence = seg_options.get("confidence", default_confidence) - # overlap = seg_options.get("overlap", default_overlap) - # qos = seg_options.get("qos", default_qos) - # account = seg_options.get("account", default_account) - # constraint = seg_options.get("constraint", default_constraint) - # checkpoint = seg_options.get("checkpoint", default_checkpoint) - # finetuned_checkpoint = f"{checkpoints_dir}/{checkpoint}" - # reservation = seg_options.get("reservation", default_reservation) - - # # #SBATCH --reservation=_CAP_MarchModCon_GPU - # reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" - # # Format confidence for command line (handles both single value and list) - # if isinstance(confidence, list): - # confidence_str = " ".join(str(c) for c in confidence) - # else: - # confidence_str = 
str(confidence) - # walltime = "00:59:00" - # job_name = f"seg_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} @@ -927,58 +834,6 @@ def segmentation_dino( reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" job_name = f"dino_{Path(recon_folder_path).name}" - # dino_settings = self.config.nersc_segment_dino_settings - # cfs_path = dino_settings["cfs_path"] - # conda_env_path = dino_settings["conda_env_path"] - # seg_scripts_dir = dino_settings["seg_scripts_dir"] - # dino_checkpoint = dino_settings["dino_checkpoint_path"] - # cpus_per_task = dino_settings["cpus-per-task"] - # gpus_per_node = dino_settings["gpus-per-node"] - # ntasks_per_node = dino_settings["ntasks-per-node"] - # reservation = dino_settings.get("reservation", "") - - # input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" - # seg_folder = recon_folder_path.replace("/rec", "/seg") - # output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/dino" - - # logger.info(f"DINO input dir: {input_dir}") - # logger.info(f"DINO output dir: {output_dir}") - - # DINO_DEFAULTS = { - # "defaults": True, - # "batch_size": 4, - # "num_nodes": 4, - # "nproc_per_node": 4, - # "qos": "regular", - # "account": self.config.nersc_account, # amsc006 - # "constraint": "gpu", # "gpu&hbm80g", - # "reservation": reservation, # e.g. "_CAP_MarchModCon_GPU" - # "walltime": "00:59:00", - # } - # try: - # seg_options = Variable.get("nersc-dino-seg-options", default={"defaults": True}, _sync=True) - # if isinstance(seg_options, str): - # import json - # seg_options = json.loads(seg_options) - # except Exception as e: - # logger.warning(f"Could not load nersc-dino-seg-options: {e}. 
Using defaults.") - # seg_options = {"defaults": True} - - # use_defaults = seg_options.get("defaults", True) - # opts = DINO_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in DINO_DEFAULTS.items()} - - # batch_size = opts["batch_size"] - # num_nodes = opts["num_nodes"] - # nproc_per_node = opts["nproc_per_node"] - # qos = opts["qos"] - # account = opts["account"] - # constraint = opts["constraint"] - # walltime = opts["walltime"] - - # reservation = opts.get("reservation", "") - # reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" - - # job_name = f"dino_{Path(recon_folder_path).name}" job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} @@ -1149,56 +1004,6 @@ def combine_segmentations( reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" job_name = f"combine_{Path(recon_folder_path).name}" - # combine_settings = self.config.nersc_combine_segmentation_settings - # conda_env_path = combine_settings["conda_env_path"] - # seg_scripts_dir = combine_settings["seg_scripts_dir"] - - # seg_folder = recon_folder_path.replace("/rec", "/seg") - # input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" - # seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" - - # sam3_results = f"{seg_base}/sam3" - # dino_results = f"{seg_base}/dino" - # combined_output = f"{seg_base}/combined" - - # logger.info(f"Combine input dir: {input_dir}") - # logger.info(f"Combine output dir: {combined_output}") - - # COMBINE_DEFAULTS = { - # "defaults": True, - # "num_nodes": combine_settings["num_nodes"], - # "qos": "regular", - # "account": self.config.nersc_account, # "amsc006", - # "constraint": "cpu", - # "walltime": "01:00:00", - # "dilate_px": 5, - # "reservation": combine_settings["reservation"] - # } - # try: - # seg_options = Variable.get("nersc-combine-seg-options", default={"defaults": True}, _sync=True) - # if isinstance(seg_options, str): - # import json - # seg_options = 
json.loads(seg_options) - # except Exception as e: - # logger.warning(f"Could not load nersc-combine-seg-options: {e}. Using defaults.") - # seg_options = {"defaults": True} - - # use_defaults = seg_options.get("defaults", True) - # opts = COMBINE_DEFAULTS if use_defaults else {k: seg_options.get(k, v) for k, v in COMBINE_DEFAULTS.items()} - - # num_nodes = opts["num_nodes"] - # qos = opts["qos"] - # account = opts["account"] - # constraint = opts["constraint"] - # walltime = opts["walltime"] - # dilate_px = opts["dilate_px"] - # reservation = opts["reservation"] - - # reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" - - # job_name = f"combine_{Path(recon_folder_path).name}" - -# #SBATCH --reservation=_CAP_MarchModCon_CPU job_script = f"""#!/bin/bash #SBATCH -q {qos} #SBATCH -A {account} From b7b785a85c88e8bb3e13ead481e3940639e01eea Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 23 Mar 2026 15:29:21 -0700 Subject: [PATCH 69/72] updating prefect.yaml --- orchestration/flows/bl832/prefect.yaml | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/orchestration/flows/bl832/prefect.yaml b/orchestration/flows/bl832/prefect.yaml index 8514cbdb..b827045e 100644 --- a/orchestration/flows/bl832/prefect.yaml +++ b/orchestration/flows/bl832/prefect.yaml @@ -43,23 +43,11 @@ deployments: name: nersc_recon_flow_pool work_queue_name: nersc_recon_flow_queue -- name: nersc_recon_multinode_flow - entrypoint: orchestration/flows/bl832/nersc.py:nersc_recon_multinode_flow +- name: nersc_petiole_segment_flow + entrypoint: orchestration/flows/bl832/nersc.py:nersc_petiole_segment_flow work_pool: name: nersc_recon_flow_pool - work_queue_name: nersc_recon_multinode_flow_queue - -- name: nersc_forge_recon_segment_flow - entrypoint: orchestration/flows/bl832/nersc.py:nersc_forge_recon_segment_flow - work_pool: - name: nersc_recon_flow_pool - work_queue_name: nersc_forge_recon_segment_flow_queue - -- name: 
nersc_forge_recon_multisegment_flow - entrypoint: orchestration/flows/bl832/nersc.py:nersc_forge_recon_multisegment_flow - work_pool: - name: nersc_recon_flow_pool - work_queue_name: nersc_forge_recon_multisegment_flow_queue + work_queue_name: nersc_petiole_segment_flow_queue - name: nersc_streaming_flow entrypoint: orchestration/flows/bl832/nersc.py:nersc_streaming_flow From 836c9e8c1dd8bb0af030d07d8ce1adaecaaf8652 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Mon, 23 Mar 2026 15:34:05 -0700 Subject: [PATCH 70/72] including script_name as part of config for sam3/dino/combine --- config.yml | 3 +++ orchestration/flows/bl832/nersc.py | 11 +++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/config.yml b/config.yml index 3228aad9..cc7d3f94 100644 --- a/config.yml +++ b/config.yml @@ -188,6 +188,7 @@ hpc_submission_settings832: cpus-per-task: 128 walltime: "00:59:00" # ── Inference parameters ────────────────────────────────────────────────── + script_name: "src/inference_v6.py" batch_size: 1 patch_size: 400 confidence: @@ -219,6 +220,7 @@ hpc_submission_settings832: cpus-per-task: 128 walltime: "00:59:00" # ── Inference parameters ────────────────────────────────────────────────── + script_name: "src.inference_dino_v1" batch_size: 4 # ── Paths ───────────────────────────────────────────────────────────────── cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 @@ -236,6 +238,7 @@ hpc_submission_settings832: cpus-per-task: 1 walltime: "01:00:00" # ── Combination parameters ──────────────────────────────────────────────── + script_name: "src.combine_sam_dino_v3" dilate_px: 5 # ── Paths ───────────────────────────────────────────────────────────────── cfs_path: /global/cfs/cdirs/als/data_mover/8.3.2 diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 57b1d84d..7264f7c7 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -548,6 +548,7 @@ def segmentation_sam3( constraint 
= opts["constraint"] walltime = opts.get("walltime", "00:59:00") reservation = opts.get("reservation", "") + script_name = opts.get("script_name", "src/inference_v6.py") prompts = opts["prompts"] if not isinstance(prompts, list) or not prompts: @@ -646,7 +647,7 @@ def segmentation_sam3( # Change to script directory cd {seg_scripts_dir} -# Run inference with v6 +# Run inference export TORCH_DISTRIBUTED_DEBUG=DETAIL export NCCL_DEBUG=INFO export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 @@ -661,7 +662,7 @@ def segmentation_sam3( --rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \ - src/inference_v6.py \ + {script_name} \ --input-dir "${{INPUT_DIR}}" \ --output-dir "${{OUTPUT_DIR}}" \ --patch-size {patch_size} \ @@ -823,6 +824,7 @@ def segmentation_dino( constraint = opts["constraint"] walltime = opts.get("walltime", "00:59:00") reservation = opts.get("reservation", "") + script_name = opts.get("script_name", "src.inference_dino_v1") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" seg_folder = recon_folder_path.replace("/rec", "/seg") @@ -896,7 +898,7 @@ def segmentation_dino( --rdzv_id=$SLURM_JOB_ID \\ --rdzv_backend=c10d \\ --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \\ - -m src.inference_dino_v1 \\ + -m {script_name} \\ --input-dir "{input_dir}" \\ --output-dir "{output_dir}" \\ --batch-size {batch_size} \\ @@ -989,6 +991,7 @@ def combine_segmentations( walltime = opts.get("walltime", "01:00:00") dilate_px = opts["dilate_px"] reservation = opts.get("reservation", "") + script_name = opts.get("script_name", "src.combine_sam_dino_v3") seg_folder = recon_folder_path.replace("/rec", "/seg") input_dir = f"{pscratch_path}/8.3.2/scratch/{recon_folder_path}" @@ -1038,7 +1041,7 @@ def combine_segmentations( cd {seg_scripts_dir} echo "--- Running SAM3 + DINO combination (v3) ---" -python -m src.combine_sam_dino_v3 \\ +python -m {script_name} \\ --input-dir "{input_dir}" \\ --instance-masks-dir "{sam3_results}" \\ 
--semantic-masks-dir "{dino_results}/semantic_masks" \\ From 0be1eebf59c41e0655634aedd1065277ed541c86 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:45:58 -0700 Subject: [PATCH 71/72] renaming DINO references to DINOv3 --- config.yml | 2 +- orchestration/flows/bl832/config.py | 2 +- orchestration/flows/bl832/nersc.py | 88 ++++++++++++++--------------- 3 files changed, 46 insertions(+), 46 deletions(-) diff --git a/config.yml b/config.yml index cc7d3f94..4e3c6ed2 100644 --- a/config.yml +++ b/config.yml @@ -207,7 +207,7 @@ hpc_submission_settings832: bpe_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/bpe_simple_vocab_16e6.txt.gz original_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/sam3.pt finetuned_checkpoint_path: /global/cfs/cdirs/als/data_mover/8.3.2/tomography_segmentation_scripts/sam3_finetune/sam3/checkpoint_v6.pt - nersc_segmentation_dino: + nersc_segmentation_dinov3: # ── SLURM resource allocation ───────────────────────────────────────────── qos: regular account: als diff --git a/orchestration/flows/bl832/config.py b/orchestration/flows/bl832/config.py index b0c9a40f..16b03629 100644 --- a/orchestration/flows/bl832/config.py +++ b/orchestration/flows/bl832/config.py @@ -34,5 +34,5 @@ def _beam_specific_config(self) -> None: self.nersc_recon_settings = self.config["hpc_submission_settings832"]["nersc_reconstruction"] self.nersc_multiresolution_settings = self.config["hpc_submission_settings832"]["nersc_multiresolution"] self.nersc_segment_sam3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_sam3"] - self.nersc_segment_dino_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dino"] + self.nersc_segment_dinov3_settings = self.config["hpc_submission_settings832"]["nersc_segmentation_dinov3"] self.nersc_combine_segmentation_settings = 
self.config["hpc_submission_settings832"]["nersc_combine_segmentations"] diff --git a/orchestration/flows/bl832/nersc.py b/orchestration/flows/bl832/nersc.py index 7264f7c7..42378d70 100644 --- a/orchestration/flows/bl832/nersc.py +++ b/orchestration/flows/bl832/nersc.py @@ -789,25 +789,25 @@ def segmentation_sam3( "output_dir": None } - def segmentation_dino( + def segmentation_dinov3( self, recon_folder_path: str = "", ) -> bool: """ - Run DINO segmentation at NERSC Perlmutter via SFAPI Slurm job. + Run DINOv3 segmentation at NERSC Perlmutter via SFAPI Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, e.g. 'folder_name/recYYYYMMDD_hhmmss_scanname/' :return: True if the job completed successfully, False otherwise. """ - logger.info("Starting NERSC DINO segmentation process.") + logger.info("Starting NERSC DINOv3 segmentation process.") user = self.client.user() pscratch_path = f"/pscratch/sd/{user.name[0]}/{user.name}" # Load from config - opts = _load_job_options("nersc-dino-seg-options", self.config.nersc_segment_dino_settings) + opts = _load_job_options("nersc-dinov3-seg-options", self.config.nersc_segment_dinov3_settings) cfs_path = opts["cfs_path"] conda_env_path = opts["conda_env_path"] @@ -830,8 +830,8 @@ def segmentation_dino( seg_folder = recon_folder_path.replace("/rec", "/seg") output_dir = f"{pscratch_path}/8.3.2/scratch/{seg_folder}/dino" - logger.info(f"DINO input dir: {input_dir}") - logger.info(f"DINO output dir: {output_dir}") + logger.info(f"DINOv3 input dir: {input_dir}") + logger.info(f"DINOv3 output dir: {output_dir}") reservation_line = f"#SBATCH --reservation={reservation}" if reservation else "" job_name = f"dino_{Path(recon_folder_path).name}" @@ -870,7 +870,7 @@ def segmentation_dino( mkdir -p {pscratch_path}/tomo_seg_logs echo "============================================================" -echo "DINO SEGMENTATION STARTED: $(date)" +echo "DINOv3 SEGMENTATION STARTED: $(date)" echo 
"============================================================" echo "Master: $MASTER_ADDR:$MASTER_PORT" echo "Nodes: $SLURM_JOB_NODELIST" @@ -914,7 +914,7 @@ def segmentation_dino( echo "" echo "============================================================" -echo "DINO SEGMENTATION COMPLETED: $(date)" +echo "DINOv3 SEGMENTATION COMPLETED: $(date)" echo "============================================================" echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" echo "Images processed: ${{NUM_IMAGES}}" @@ -926,7 +926,7 @@ def segmentation_dino( exit $SEG_STATUS """ try: - logger.info("Submitting DINO segmentation job to Perlmutter.") + logger.info("Submitting DINOv3 segmentation job to Perlmutter.") perlmutter = self.client.compute(Machine.perlmutter) job = perlmutter.submit_job(job_script) logger.info(f"Submitted job ID: {job.jobid}") @@ -940,11 +940,11 @@ def segmentation_dino( logger.info(f"Job {job.jobid} current state: {job.state}") job.complete() - logger.info("DINO segmentation job completed successfully.") + logger.info("DINOv3 segmentation job completed successfully.") return True except Exception as e: - logger.error(f"Error during DINO segmentation job submission or completion: {e}") + logger.error(f"Error during DINOv3 segmentation job submission or completion: {e}") match = re.search(r"Job not found:\s*(\d+)", str(e)) if match: jobid = match.group(1) @@ -953,7 +953,7 @@ def segmentation_dino( job = self.client.compute(Machine.perlmutter).job(jobid=jobid) time.sleep(30) job.complete() - logger.info("DINO segmentation job completed successfully after recovery.") + logger.info("DINOv3 segmentation job completed successfully after recovery.") return True except Exception as recovery_err: logger.error(f"Failed to recover job {jobid}: {recovery_err}") @@ -966,7 +966,7 @@ def combine_segmentations( recon_folder_path: str = "", ) -> bool: """ - Run CPU-based combination of SAM3+DINO segmentation results + Run CPU-based combination of SAM3+DINOv3 
segmentation results at NERSC Perlmutter via SFAPI Slurm job. :param recon_folder_path: Relative path to the reconstructed data folder, @@ -998,7 +998,7 @@ def combine_segmentations( seg_base = f"{pscratch_path}/8.3.2/scratch/{seg_folder}" sam3_results = f"{seg_base}/sam3" - dino_results = f"{seg_base}/dino" + dinov3_results = f"{seg_base}/dino" combined_output = f"{seg_base}/combined" logger.info(f"Combine input dir: {input_dir}") @@ -1031,7 +1031,7 @@ def combine_segmentations( echo "============================================================" echo "Input: {input_dir}" echo "SAM3: {sam3_results}" -echo "DINO: {dino_results}" +echo "DINOv3: {dinov3_results}" echo "Output: {combined_output}" echo "Dilate: {dilate_px}px" echo "============================================================" @@ -1040,18 +1040,18 @@ def combine_segmentations( cd {seg_scripts_dir} -echo "--- Running SAM3 + DINO combination (v3) ---" +echo "--- Running SAM3 + DINOv3 combination ---" python -m {script_name} \\ --input-dir "{input_dir}" \\ --instance-masks-dir "{sam3_results}" \\ - --semantic-masks-dir "{dino_results}/semantic_masks" \\ + --semantic-masks-dir "{dinov3_results}/semantic_masks" \\ --output-dir "{combined_output}/sam_dino" \\ --dilate-px {dilate_px} \\ --save-extracted \\ --dino-trust Cortex Phloem_Fibers Phloem Air-based_Pith_cells Water-based_Pith_cells SAM_DINO_STATUS=$? 
-echo "SAM3+DINO exit status: $SAM_DINO_STATUS" +echo "SAM3+DINOv3 exit status: $SAM_DINO_STATUS" END_TIME=$(date +%s) DURATION=$((END_TIME - START_TIME)) @@ -1063,7 +1063,7 @@ def combine_segmentations( echo "SEGMENTATION COMBINATION COMPLETED: $(date)" echo "============================================================" echo "Total time: ${{MINUTES}}m ${{SECONDS}}s (${{DURATION}}s)" -echo "SAM3+DINO status: $SAM_DINO_STATUS" +echo "SAM3+DINOv3 status: $SAM_DINO_STATUS" echo "============================================================" chmod -R 2775 {combined_output} 2>/dev/null || true @@ -1624,10 +1624,10 @@ def nersc_petiole_segment_flow( nersc_reconstruction_success = False sam3_success = False - dino_success = False + dinov3_success = False data832_tiff_transfer_success = False data832_sam3_transfer_success = False - data832_dino_transfer_success = False + data832_dinov3_transfer_success = False data832_combined_transfer_success = False # ── STEP 1: Multinode Reconstruction ───────────────────────────────────── @@ -1689,7 +1689,7 @@ def nersc_petiole_segment_flow( sam3_future = nersc_segmentation_sam3_task.submit( recon_folder_path=scratch_path_tiff, config=config ) - dino_future = nersc_segmentation_dino_task.submit( + dinov3_future = nersc_segmentation_dinov3_task.submit( recon_folder_path=scratch_path_tiff, config=config ) @@ -1710,27 +1710,27 @@ def nersc_petiole_segment_flow( except Exception as e: logger.error(f"Failed to transfer SAM3 outputs to data832: {e}") - dino_success = dino_future.result() - logger.info(f"DINO segmentation result: {dino_success}") - if dino_success: - logger.info("Transferring DINO segmentation outputs to data832") - dino_segment_path = f"{folder_name}/seg{file_name}/dino" + dinov3_success = dinov3_future.result() + logger.info(f"DINOv3 segmentation result: {dinov3_success}") + if dinov3_success: + logger.info("Transferring DINOv3 segmentation outputs to data832") + dinov3_segment_path = f"{folder_name}/seg{file_name}/dino" 
try: - data832_dino_transfer_success = transfer_controller.copy( - file_path=dino_segment_path, + data832_dinov3_transfer_success = transfer_controller.copy( + file_path=dinov3_segment_path, source=config.nersc832_alsdev_pscratch_scratch, destination=config.data832_scratch ) - logger.info(f"DINO transfer to data832 success: {data832_dino_transfer_success}") + logger.info(f"DINOv3 transfer to data832 success: {data832_dinov3_transfer_success}") except Exception as e: - logger.error(f"Failed to transfer DINO outputs to data832: {e}") + logger.error(f"Failed to transfer DINOv3 outputs to data832: {e}") - any_seg_success = any([sam3_success, dino_success]) + any_seg_success = any([sam3_success, dinov3_success]) - logger.info(f"Segmentation results — SAM3: {sam3_success}, DINO: {dino_success}") + logger.info(f"Segmentation results — SAM3: {sam3_success}, DINOv3: {dinov3_success}") - # ── STEP 5: Combine Segmentations (after SAM3+DINO complete) ── - if dino_success and sam3_success: + # ── STEP 5: Combine Segmentations (after SAM3+DINOv3 complete) ── + if dinov3_success and sam3_success: logger.info("Running segmentation combination.") combine_future = nersc_combine_segmentations_task.submit( @@ -1799,7 +1799,7 @@ def nersc_petiole_segment_flow( source_endpoint=config.nersc832_alsdev_pscratch_scratch, check_endpoint=config.data832_scratch if any([ data832_sam3_transfer_success, - data832_dino_transfer_success, + data832_dinov3_transfer_success, ]) else None, days_from_now=1.0 ) @@ -1818,7 +1818,7 @@ def nersc_petiole_segment_flow( logger.warning(f"Failed to schedule data832 tiff pruning: {e}") if any([data832_sam3_transfer_success, - data832_dino_transfer_success, + data832_dinov3_transfer_success, data832_combined_transfer_success]): try: prune_controller.prune( @@ -1836,7 +1836,7 @@ def nersc_petiole_segment_flow( else: logger.warning( f"Flow completed with issues: recon={nersc_reconstruction_success}, " - f"sam3={sam3_success}, dino={dino_success}" + 
f"sam3={sam3_success}, dinov3={dinov3_success}" ) return False @@ -1991,8 +1991,8 @@ def nersc_segmentation_sam3_task( return nersc_segmentation_success -@task(name="nersc_segmentation_dino_task") -def nersc_segmentation_dino_task( +@task(name="nersc_segmentation_dinov3_task") +def nersc_segmentation_dinov3_task( recon_folder_path: str, config: Optional[Config832] = None, ) -> bool: @@ -2001,12 +2001,12 @@ def nersc_segmentation_dino_task( logger.info("No config provided, using default Config832.") config = Config832() tomography_controller = get_controller(hpc_type=HPC.NERSC, config=config) - logger.info(f"Starting NERSC DINO segmentation task for {recon_folder_path=}") - success = tomography_controller.segmentation_dino(recon_folder_path=recon_folder_path) + logger.info(f"Starting NERSC DINOv3 segmentation task for {recon_folder_path=}") + success = tomography_controller.segmentation_dinov3(recon_folder_path=recon_folder_path) if not success: - logger.error("DINO segmentation failed.") + logger.error("DINOv3 segmentation failed.") else: - logger.info("DINO segmentation successful.") + logger.info("DINOv3 segmentation successful.") return success From d696931d4ff16fa253f8742fddd7b2bc4fd83d87 Mon Sep 17 00:00:00 2001 From: David Abramov Date: Wed, 1 Apr 2026 16:55:41 -0700 Subject: [PATCH 72/72] Updating pytest for nersc bc of the DINO -> DINOv3 naming changes --- orchestration/_tests/test_bl832/test_nersc.py | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/orchestration/_tests/test_bl832/test_nersc.py b/orchestration/_tests/test_bl832/test_nersc.py index 597a57de..bfe0bad0 100644 --- a/orchestration/_tests/test_bl832/test_nersc.py +++ b/orchestration/_tests/test_bl832/test_nersc.py @@ -115,7 +115,7 @@ def mock_config832(mocker): "original_checkpoint_path": "/mock/original.pt", "finetuned_checkpoint_path": "/mock/checkpoints/finetuned.pt", } - mock_config.nersc_segment_dino_settings = { + 
mock_config.nersc_segment_dinov3_settings = { "qos": "regular", "account": "mock_account", "constraint": "gpu", @@ -358,10 +358,10 @@ def capture_script(script): # ────────────────────────────────────────────────────────────────────────────── -# segmentation_dino +# segmentation_dinov3 # ────────────────────────────────────────────────────────────────────────────── -def test_segmentation_dino_success(mocker, mock_sfapi_client, mock_config832): +def test_segmentation_dinov3_success(mocker, mock_sfapi_client, mock_config832): from orchestration.flows.bl832.nersc import NERSCTomographyHPCController from sfapi_client.compute import Machine @@ -369,7 +369,7 @@ def test_segmentation_dino_success(mocker, mock_sfapi_client, mock_config832): mocker.patch("orchestration.flows.bl832.nersc.Variable.get", return_value={"defaults": True}) controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - result = controller.segmentation_dino(recon_folder_path="folder/recfile") + result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") mock_sfapi_client.compute.assert_called_with(Machine.perlmutter) mock_sfapi_client.compute.return_value.submit_job.assert_called_once() @@ -377,7 +377,7 @@ def test_segmentation_dino_success(mocker, mock_sfapi_client, mock_config832): assert result is True -def test_segmentation_dino_submission_failure(mocker, mock_sfapi_client, mock_config832): +def test_segmentation_dinov3_submission_failure(mocker, mock_sfapi_client, mock_config832): from orchestration.flows.bl832.nersc import NERSCTomographyHPCController mocker.patch("orchestration.flows.bl832.nersc.time.sleep") @@ -385,12 +385,12 @@ def test_segmentation_dino_submission_failure(mocker, mock_sfapi_client, mock_co mock_sfapi_client.compute.return_value.submit_job.side_effect = Exception("No GPU nodes") controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - result = 
controller.segmentation_dino(recon_folder_path="folder/recfile") + result = controller.segmentation_dinov3(recon_folder_path="folder/recfile") assert result is False -def test_segmentation_dino_output_paths(mocker, mock_sfapi_client, mock_config832): +def test_segmentation_dinov3_output_paths(mocker, mock_sfapi_client, mock_config832): """ Output dir should swap /rec for /seg in the folder name and route to /dino. @@ -413,7 +413,7 @@ def capture(script): mock_sfapi_client.compute.return_value.submit_job.side_effect = capture controller = NERSCTomographyHPCController(client=mock_sfapi_client, config=mock_config832) - controller.segmentation_dino(recon_folder_path="folder/recfile") + controller.segmentation_dinov3(recon_folder_path="folder/recfile") script = captured_scripts[0] assert "segfile" in script @@ -527,32 +527,32 @@ def test_nersc_segmentation_sam3_task_failure(mocker, mock_config832): assert result["success"] is False -def test_nersc_segmentation_dino_task_success(mocker, mock_config832): - from orchestration.flows.bl832.nersc import nersc_segmentation_dino_task +def test_nersc_segmentation_dinov3_task_success(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_segmentation_dinov3_task mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) mock_controller = mocker.MagicMock() - mock_controller.segmentation_dino.return_value = True + mock_controller.segmentation_dinov3.return_value = True mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - result = nersc_segmentation_dino_task.fn( + result = nersc_segmentation_dinov3_task.fn( recon_folder_path="folder/recfile", config=mock_config832 ) - mock_controller.segmentation_dino.assert_called_once_with(recon_folder_path="folder/recfile") + mock_controller.segmentation_dinov3.assert_called_once_with(recon_folder_path="folder/recfile") assert result is True -def test_nersc_segmentation_dino_task_failure(mocker, 
mock_config832): - from orchestration.flows.bl832.nersc import nersc_segmentation_dino_task +def test_nersc_segmentation_dinov3_task_failure(mocker, mock_config832): + from orchestration.flows.bl832.nersc import nersc_segmentation_dinov3_task mocker.patch("orchestration.flows.bl832.nersc.get_run_logger", return_value=mocker.MagicMock()) mock_controller = mocker.MagicMock() - mock_controller.segmentation_dino.return_value = False + mock_controller.segmentation_dinov3.return_value = False mocker.patch("orchestration.flows.bl832.nersc.get_controller", return_value=mock_controller) - result = nersc_segmentation_dino_task.fn( + result = nersc_segmentation_dinov3_task.fn( recon_folder_path="folder/recfile", config=mock_config832 ) @@ -594,7 +594,7 @@ def test_nersc_combine_segmentations_task_failure(mocker, mock_config832): # ────────────────────────────────────────────────────────────────────────────── -# nersc_petiole_segment_flow (recon + SAM3 + DINO + combine) +# nersc_petiole_segment_flow (recon + SAM3 + DINOv3 + combine) # # Replaces the former nersc_forge_recon_multisegment_flow tests. 
# The cleaned nersc.py exposes nersc_petiole_segment_flow as the canonical @@ -607,7 +607,7 @@ def test_nersc_combine_segmentations_task_failure(mocker, mock_config832): # ────────────────────────────────────────────────────────────────────────────── def test_petiole_segment_flow_both_succeed(mocker, mock_config832, mock_recon_success): - """Recon + SAM3 + DINO all succeed → combine is called → flow returns True.""" + """Recon + SAM3 + DINOv3 all succeed → combine is called → flow returns True.""" from orchestration.flows.bl832.nersc import nersc_petiole_segment_flow mock_controller = mocker.MagicMock() @@ -620,13 +620,13 @@ def test_petiole_segment_flow_both_succeed(mocker, mock_config832, mock_recon_su mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") - mock_dino_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dino_task") + mock_dinov3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dinov3_task") mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") mock_sam3_task.submit.return_value = _make_future( mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out/sam3"} ) - mock_dino_task.submit.return_value = _make_future(mocker, True) + mock_dinov3_task.submit.return_value = _make_future(mocker, True) mock_combine_task.submit.return_value = _make_future(mocker, True) result = nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) @@ -634,7 +634,7 @@ def test_petiole_segment_flow_both_succeed(mocker, mock_config832, mock_recon_su assert result is True mock_controller.reconstruct.assert_called_once() mock_sam3_task.submit.assert_called_once() - mock_dino_task.submit.assert_called_once() + mock_dinov3_task.submit.assert_called_once() mock_combine_task.submit.assert_called_once() @@ -652,13 
+652,13 @@ def test_petiole_segment_flow_only_sam3_succeeds(mocker, mock_config832, mock_re mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") - mock_dino_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dino_task") + mock_dinov3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dinov3_task") mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") mock_sam3_task.submit.return_value = _make_future( mocker, {"success": True, "job_id": "1", "timing": None, "output_dir": "/out/sam3"} ) - mock_dino_task.submit.return_value = _make_future(mocker, False) + mock_dinov3_task.submit.return_value = _make_future(mocker, False) result = nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None) @@ -680,13 +680,13 @@ def test_petiole_segment_flow_both_seg_fail(mocker, mock_config832, mock_recon_s mocker.patch("orchestration.flows.bl832.nersc.get_prune_controller", return_value=mocker.MagicMock()) mock_sam3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_sam3_task") - mock_dino_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dino_task") + mock_dinov3_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_segmentation_dinov3_task") mock_combine_task = mocker.patch("orchestration.flows.bl832.nersc.nersc_combine_segmentations_task") mock_sam3_task.submit.return_value = _make_future( mocker, {"success": False, "job_id": None, "timing": None, "output_dir": None} ) - mock_dino_task.submit.return_value = _make_future(mocker, False) + mock_dinov3_task.submit.return_value = _make_future(mocker, False) result = nersc_petiole_segment_flow(file_path="folder/file.h5", num_nodes=4, config=None)