From 4f1f0a0eb1b2b685371154811873bfcf610b497b Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 24 Nov 2025 20:06:47 +0100 Subject: [PATCH 01/33] [feature] update mesh visualizer using pyvista, and convert from global to camera coords --- src/poseforge/neuromechfly/constants.py | 12 +++ .../neuromechfly/scripts/visualize_meshes.py | 78 ++++++++++++------- 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/poseforge/neuromechfly/constants.py b/src/poseforge/neuromechfly/constants.py index 12dac23..9e3a8b6 100644 --- a/src/poseforge/neuromechfly/constants.py +++ b/src/poseforge/neuromechfly/constants.py @@ -38,6 +38,18 @@ f"{leg}{link}" for leg in legs for link in leg_keypoints_nmf ] + ["LPedicel", "RPedicel"] +all_segment_names_per_leg = [ + "Coxa", + "Femur", + "Tibia", + "Tarsus1", + "Tarsus2", + "Tarsus3", + "Tarsus4", + "Tarsus5", +] + + kchain_plotting_colors = { "LF": np.array([15, 115, 153]) / 255, "LM": np.array([26, 141, 175]) / 255, diff --git a/src/poseforge/neuromechfly/scripts/visualize_meshes.py b/src/poseforge/neuromechfly/scripts/visualize_meshes.py index a965440..55fe421 100644 --- a/src/poseforge/neuromechfly/scripts/visualize_meshes.py +++ b/src/poseforge/neuromechfly/scripts/visualize_meshes.py @@ -1,54 +1,58 @@ import numpy as np -import pandas as pd import pyvista as pv -import flygym +import h5py from xml.etree import ElementTree from pathlib import Path from scipy.spatial.transform import Rotation +from scipy.linalg import rq +from poseforge.neuromechfly.constants import legs, all_segment_names_per_leg -df = pd.read_pickle( - "bulk_data/nmf_rendering_enhanced/BO_Gal4_fly1_trial001/segment_000/subsegment_000/processed_kinematic_states.pkl" + +# Define paths +subsegment_dir = Path( + "bulk_data/nmf_rendering/BO_Gal4_fly1_trial001/segment_000/subsegment_000" ) -flygym_data_dir = Path(flygym.__file__).parent / "data" +sim_data_path = subsegment_dir / "processed_simulation_data.h5" +flygym_data_dir = Path("~/projects/flygym/flygym").expanduser() / "data" nmf_mesh_dir = flygym_data_dir / "mesh" mjcf_path = flygym_data_dir / "mjcf/neuromechfly_seqik_kinorder_ypr.xml" -sides = "LR" -positions = "FMH" -links = [ - "Coxa", - "Femur", - "Tibia", - "Tarsus1", - "Tarsus2", - "Tarsus3", - "Tarsus4", - "Tarsus5", -] - +# Load NeuroMechFly model mjcf_tree = ElementTree.parse(mjcf_path) worldbody = mjcf_tree.find("worldbody") body_attributes = {body.attrib["name"]: body.attrib for body in worldbody.iter("body")} frame_idx = 10 -entry = df.loc[frame_idx] +# Load simulation data +with h5py.File(sim_data_path, "r") as f: + # all_seg_pos_global: (n_frames, n_segments, 3) + all_seg_pos_global = f["raw/body_segment_states/pos_global"][:] + # all_seg_quat_global: (n_frames, n_segments, 4) + all_seg_quat_global = f["raw/body_segment_states/quat_global"][:] + # all_cam_matrices: (n_frames, 3, 4) + all_cam_matrices = f["raw/camera_matrix"][:] + # all_seg_names: list of length n_segments + all_seg_names = list(f["raw/body_segment_states"].attrs["keys"]) plotter = pv.Plotter() segments_to_include = [ - f"{side}{pos}{link}" for side in sides for pos in positions for link in links + f"{leg}{seg}" for leg in legs for seg in all_segment_names_per_leg ] segments_to_include += ["Thorax"] meshes = [] -for key in segments_to_include: - translation = entry[f"body_seg_pos_global_{key}"] - quaternion = entry[f"body_seg_quat_global_{key}"] +for seg_name in segments_to_include: + seg_idx = all_seg_names.index(seg_name) + translation = all_seg_pos_global[frame_idx, seg_idx, :] + quaternion = all_seg_quat_global[frame_idx, seg_idx, :] + cam_matrix = all_cam_matrices[frame_idx, :, :] + placement_transform = np.eye(4) placement_transform[:3, :3] = Rotation.from_quat(quaternion).as_matrix() placement_transform[:3, 3] = translation - mesh_file = nmf_mesh_dir / f"{key}.stl" + mesh_file = nmf_mesh_dir / f"{seg_name}.stl" mesh = pv.read(mesh_file) # Scale @@ -63,10 +67,30 @@ placement_transform[:3, 3] = translation mesh.transform(placement_transform) - meshes.append(mesh) - plotter.add_mesh(mesh, show_edges=False, name=key, smooth_shading=True) + # Convert to coordinates relative to camera + # Given camera projection matrix P = cam_matrix (3x4), solve for: + # R = camera rotation matrix (3x3) + # K = camera intrinsic matrix (3x3) + # where P = K[R|t] + # First, run RQ decomposition to get K and R + cam_intrinsics, cam_rotation = rq(cam_matrix[:, :3]) + # Make camera intrinsic mastrix have positive diagonal (just a convention) + _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) + cam_intrinsics = cam_intrinsics @ _sign_multiplier + cam_rotation = _sign_multiplier @ cam_rotation + # Get camera translation in world coords + cam_translation = np.linalg.inv(cam_intrinsics) @ cam_matrix[:, 3] + + transform_world2cam = np.eye(4) + transform_world2cam[:3, :3] = cam_rotation + transform_world2cam[:3, 3] = cam_translation + # print(transform_world2cam) + # Cam matrix describes cam state in world coords: invert it for cam-to-world mapping + transform_cam2world = np.linalg.inv(transform_world2cam) + mesh.transform(transform_cam2world) - # print(f"Loading {key}:\n\ttranslation={translation}\n\tquaternion={quaternion}") + meshes.append(mesh) + plotter.add_mesh(mesh, show_edges=False, name=seg_name, smooth_shading=True) plotter.set_background("black") From 5d54f99ff1daac06989c182674e6352e57462322 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 27 Nov 2025 14:00:52 +0100 Subject: [PATCH 02/33] [feature] finish script to visualize 6d mesh states in pyvista --- .../neuromechfly/scripts/visualize_meshes.py | 124 +++++++++++------- 1 file changed, 75 insertions(+), 49 deletions(-) diff --git a/src/poseforge/neuromechfly/scripts/visualize_meshes.py b/src/poseforge/neuromechfly/scripts/visualize_meshes.py index 55fe421..cb6554e 100644 --- a/src/poseforge/neuromechfly/scripts/visualize_meshes.py +++ b/src/poseforge/neuromechfly/scripts/visualize_meshes.py @@ -1,4 +1,5 @@ import numpy as np +import pandas as pd import pyvista as pv import h5py from xml.etree import ElementTree @@ -22,78 +23,103 @@ mjcf_tree = ElementTree.parse(mjcf_path) worldbody = mjcf_tree.find("worldbody") body_attributes = {body.attrib["name"]: body.attrib for body in worldbody.iter("body")} -frame_idx = 10 # Load simulation data with h5py.File(sim_data_path, "r") as f: - # all_seg_pos_global: (n_frames, n_segments, 3) all_seg_pos_global = f["raw/body_segment_states/pos_global"][:] - # all_seg_quat_global: (n_frames, n_segments, 4) all_seg_quat_global = f["raw/body_segment_states/quat_global"][:] - # all_cam_matrices: (n_frames, 3, 4) all_cam_matrices = f["raw/camera_matrix"][:] - # all_seg_names: list of length n_segments all_seg_names = list(f["raw/body_segment_states"].attrs["keys"]) -plotter = pv.Plotter() - +n_frames = all_seg_pos_global.shape[0] segments_to_include = [ f"{leg}{seg}" for leg in legs for seg in all_segment_names_per_leg ] segments_to_include += ["Thorax"] -meshes = [] + +# Load original meshes once (before any transformations) +original_meshes = {} for seg_name in segments_to_include: - seg_idx = all_seg_names.index(seg_name) - translation = all_seg_pos_global[frame_idx, seg_idx, :] - quaternion = all_seg_quat_global[frame_idx, seg_idx, :] - cam_matrix = all_cam_matrices[frame_idx, :, :] + mesh_file = nmf_mesh_dir / f"{seg_name}.stl" + original_meshes[seg_name] = pv.read(mesh_file) - placement_transform = np.eye(4) - placement_transform[:3, :3] = Rotation.from_quat(quaternion).as_matrix() - placement_transform[:3, 3] = translation +# Create plotter +plotter = pv.Plotter() +plotter.set_background("black") +plotter.show_axes() - mesh_file = nmf_mesh_dir / f"{seg_name}.stl" - mesh = pv.read(mesh_file) - - # Scale - scale_transform = np.eye(4) - np.fill_diagonal(scale_transform, [1000, 1000, 1000, 1]) - mesh.transform(scale_transform) - - # Apply transformation based on MuJoCo state - placement_transform = np.eye(4) - rotation_object = Rotation.from_quat(quaternion, scalar_first=True) - placement_transform[:3, :3] = rotation_object.as_matrix() - placement_transform[:3, 3] = translation - mesh.transform(placement_transform) - - # Convert to coordinates relative to camera - # Given camera projection matrix P = cam_matrix (3x4), solve for: - # R = camera rotation matrix (3x3) - # K = camera intrinsic matrix (3x3) - # where P = K[R|t] - # First, run RQ decomposition to get K and R +# Add all meshes to plotter initially +current_meshes = {} +for seg_name in segments_to_include: + current_meshes[seg_name] = original_meshes[seg_name].copy() + plotter.add_mesh( + current_meshes[seg_name], + show_edges=False, + name=seg_name, + smooth_shading=True + ) + +# Current frame tracker +current_frame = [0] + + +def update_frame(): + """Update all meshes to the current frame""" + frame_idx = current_frame[0] + cam_matrix = all_cam_matrices[frame_idx, :, :] + + # Compute camera transformation once per frame cam_intrinsics, cam_rotation = rq(cam_matrix[:, :3]) - # Make camera intrinsic mastrix have positive diagonal (just a convention) _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) cam_intrinsics = cam_intrinsics @ _sign_multiplier cam_rotation = _sign_multiplier @ cam_rotation - # Get camera translation in world coords cam_translation = np.linalg.inv(cam_intrinsics) @ cam_matrix[:, 3] - + transform_world2cam = np.eye(4) transform_world2cam[:3, :3] = cam_rotation transform_world2cam[:3, 3] = cam_translation - # print(transform_world2cam) - # Cam matrix describes cam state in world coords: invert it for cam-to-world mapping transform_cam2world = np.linalg.inv(transform_world2cam) - mesh.transform(transform_cam2world) - - meshes.append(mesh) - plotter.add_mesh(mesh, show_edges=False, name=seg_name, smooth_shading=True) + + # Update each segment mesh + for seg_name in segments_to_include: + seg_idx = all_seg_names.index(seg_name) + translation = all_seg_pos_global[frame_idx, seg_idx, :] + quaternion = all_seg_quat_global[frame_idx, seg_idx, :] + + # Start with original mesh + mesh = original_meshes[seg_name].copy() + + # Scale + scale_transform = np.eye(4) + np.fill_diagonal(scale_transform, [1000, 1000, 1000, 1]) + mesh = mesh.transform(scale_transform, inplace=False) + + # Apply MuJoCo state transformation + placement_transform = np.eye(4) + rotation_object = Rotation.from_quat(quaternion, scalar_first=True) + placement_transform[:3, :3] = rotation_object.as_matrix() + placement_transform[:3, 3] = translation + mesh = mesh.transform(placement_transform, inplace=False) + + # Transform to camera coordinates + mesh = mesh.transform(transform_cam2world, inplace=False) + + # Update the mesh points in place + current_meshes[seg_name].points[:] = mesh.points + + # Update frame counter + current_frame[0] = (current_frame[0] + 1) % n_frames + + # Update title to show current frame + plotter.add_text(f"Frame: {frame_idx}/{n_frames}", name="frame_counter", position="upper_left") + + +# Initialize first frame +update_frame() +plotter.reset_camera() +# Add timer callback for animation (30 fps) +# The callback receives a step argument, so we need to accept it +plotter.add_timer_event(max_steps=n_frames, duration=int(1000/30), callback=lambda step: update_frame()) -plotter.set_background("black") -plotter.reset_camera() -plotter.show_axes() -plotter.show() # or plotter.show(screenshot='screenshot.png') +plotter.show() \ No newline at end of file From 92c2009cd8a9899c2207a39f9344fd883dd868f6 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 27 Nov 2025 17:47:58 +0100 Subject: [PATCH 03/33] [feature] rerun pose nmf postprocessing with bodyseg states --- src/poseforge/neuromechfly/constants.py | 63 +++++++++++++++++- src/poseforge/neuromechfly/data.py | 3 +- src/poseforge/neuromechfly/postprocessing.py | 64 ++++++++++--------- .../neuromechfly/scripts/run_simulation.py | 48 +++++++------- src/poseforge/neuromechfly/simulate.py | 39 ++--------- 5 files changed, 130 insertions(+), 87 deletions(-) diff --git a/src/poseforge/neuromechfly/constants.py b/src/poseforge/neuromechfly/constants.py index 9e3a8b6..dfccbfc 100644 --- a/src/poseforge/neuromechfly/constants.py +++ b/src/poseforge/neuromechfly/constants.py @@ -1,6 +1,10 @@ import numpy as np +########################################################################### +## NEUROMECHFLY BODY CONFIGURATION BELOW ## +########################################################################### + dof_name_lookup_nmf_to_canonical = { "Coxa": "ThC_pitch", "Coxa_roll": "ThC_roll", @@ -49,8 +53,22 @@ "Tarsus5", ] +all_leg_dofs = [ + f"joint_{side}{pos}{dof}" + for side in "LR" + for pos in "FMH" + for dof in [ + "Coxa", + "Coxa_roll", + "Coxa_yaw", + "Femur", + "Femur_roll", + "Tibia", + "Tarsus1", + ] +] -kchain_plotting_colors = { +kchain_plotting_colors = { # these are only for plotting aesthetics "LF": np.array([15, 115, 153]) / 255, "LM": np.array([26, 141, 175]) / 255, "LH": np.array([117, 190, 203]) / 255, @@ -61,6 +79,49 @@ "RAntenna": np.array([50, 120, 32]) / 255, } + +########################################################################### +## COLORS FOR BODY SEGMENT RENDERING BELOW ## +## These are set to artificially boost contrast between body segments ## +## -- they are NOT just for aesthetics! ## +########################################################################### + +# Define color combo by body segment +color_by_link = { + "Coxa": "cyan", + "Femur": "yellow", + "Tibia": "blue", + "Tarsus": "green", + "Antenna": "magenta", + "Thorax": "gray", +} +color_by_kinematic_chain = { + "LF": "red", # left front leg + "LM": "green", # left mid leg + "LH": "blue", # left hind leg + "RF": "cyan", # right front leg + "RM": "magenta", # right mid leg + "RH": "yellow", # right hind leg + "L": "red", # left antenna + "R": "green", # right antenna + "Thorax": "white", # thorax +} +color_palette = { + "red": (1, 0, 0, 1), + "green": (0, 1, 0, 1), + "blue": (0, 0, 1, 1), + "yellow": (1, 1, 0, 1), + "magenta": (1, 0, 1, 1), + "cyan": (0, 1, 1, 1), + "gray": (0.4, 0.4, 0.4, 1), + "white": (1, 1, 1, 1), +} + + +########################################################################### +## PARAMETERS FOR INVERSE KINEMATICS WITH SEQIKPY BELOW ## +########################################################################### + # SeqIKPy considers the anchor point of every DoF a "joint" keypoint. However, some # anatomical joints have multiple DoFs (e.g., ThC has yaw, pitch, roll). This results in # some "virtual" keypoints in the inverse kinematics output. This mask filters them out. diff --git a/src/poseforge/neuromechfly/data.py b/src/poseforge/neuromechfly/data.py index 8901e7f..4cf3b8e 100644 --- a/src/poseforge/neuromechfly/data.py +++ b/src/poseforge/neuromechfly/data.py @@ -1,9 +1,8 @@ import numpy as np import pandas as pd from pathlib import Path -from flygym.preprogrammed import all_leg_dofs -from poseforge.neuromechfly.constants import parse_nmf_joint_name +from poseforge.neuromechfly.constants import parse_nmf_joint_name, all_leg_dofs def extract_joint_angles_trajectory( diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index ce62626..fa33870 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -11,14 +11,7 @@ from tqdm import tqdm from joblib import Parallel, delayed -import poseforge.neuromechfly.simulate as simulate -from poseforge.neuromechfly.constants import ( - keypoint_name_lookup_canonical_to_nmf, - kchain_plotting_colors, - keypoint_segments_nmf, - legs, - leg_keypoints_canonical, -) +import poseforge.neuromechfly.constants as constants from poseforge.util.plot import ( configure_matplotlib_style, get_segmentation_color_palette, @@ -65,8 +58,10 @@ def __init__(self): for pos in "FMH": for link in leg_segments: leg = f"{side}{pos}" - color0 = nmf_rendered_colors[simulate.color_by_link[link]] - color1 = nmf_rendered_colors[simulate.color_by_kinematic_chain[leg]] + color0 = nmf_rendered_colors[constants.color_by_link[link]] + color1 = nmf_rendered_colors[ + constants.color_by_kinematic_chain[leg] + ] color_6d = np.array(list(color0) + list(color1)) label = f"{leg}{link}" self.label_keys.append(label) @@ -74,8 +69,8 @@ def __init__(self): # Antennas for side in "LR": - color0 = nmf_rendered_colors[simulate.color_by_link["Antenna"]] - color1 = nmf_rendered_colors[simulate.color_by_kinematic_chain[side]] + color0 = nmf_rendered_colors[constants.color_by_link["Antenna"]] + color1 = nmf_rendered_colors[constants.color_by_kinematic_chain[side]] color_6d = np.array(list(color0) + list(color1)) label = f"{side}Antenna" self.label_keys.append(label) @@ -288,7 +283,7 @@ def process_single_frame( # Gather keypoint positions in coordinates and rotate/center-crop accordingly keypoints_pos_dict_world_raw, keypoints_pos_dict_camera_raw = ( extract_body_segment_positions( - h5_file, frame_idx, "pos_atparent", keypoint_segments_nmf + h5_file, frame_idx, "pos_atparent", constants.keypoint_segments_nmf ) ) keypoints_pos_dict_world_rotated = rotate_keypoint_positions_world( @@ -352,8 +347,7 @@ def process_subsegment( segment_label_parser, ) ) - - # Process frames in parallel + # Parallel execution with joblib # Use 'loky' backend for CPU-intensive image processing operations parallel_executor = Parallel(n_jobs=n_jobs, backend="loky") effective_n_jobs = parallel_executor._effective_n_jobs() @@ -446,9 +440,10 @@ def process_subsegment( keypoint_pos_group = postprocessed_group.create_group("keypoint_pos") for ref_frame in ["camera", "world"]: data_block = np.empty( - (num_frames, len(keypoint_segments_nmf), 3), dtype="float32" + (num_frames, len(constants.keypoint_segments_nmf), 3), + dtype="float32", ) - for seg_id, body_segment in enumerate(keypoint_segments_nmf): + for seg_id, body_segment in enumerate(constants.keypoint_segments_nmf): key = f"keypoint_pos_{ref_frame}_{body_segment}" values = np.array(derived_variables_by_key[key]) data_block[:, seg_id, :] = values @@ -456,13 +451,13 @@ def process_subsegment( pos_ds = keypoint_pos_group.create_dataset( f"{ref_frame}_coords", data=data_block, dtype="float32" ) - pos_ds.attrs["keys"] = keypoint_segments_nmf + pos_ds.attrs["keys"] = constants.keypoint_segments_nmf pos_ds.attrs["description"] = ( f"Keypoint positions in {ref_frame} coordinates. Shape is " "(num_frames, num_keypoints, 3). See the `.attrs['keys']` for the " "order of keypoints." ) - keypoint_pos_group.attrs["keys"] = keypoint_segments_nmf + keypoint_pos_group.attrs["keys"] = constants.keypoint_segments_nmf keypoint_pos_group.attrs["description"] = ( "This group contains positions of joint keypoints in the rotated image " "centered around the fly, cropped, and rotated so that the fly faces " @@ -492,6 +487,17 @@ def process_subsegment( "for the mapping from label IDs (pixel values) to body segment names." ) + # Add mesh state labels + seg_states_grp = postprocessed_group.create_group("body_segment_states") + seg_states_grp.attrs.update(source_h5_file["body_segment_states"].attrs) + for sensor_type in source_h5_file["body_segment_states"].keys(): + source_ds = source_h5_file["body_segment_states"][sensor_type] + seg_states_grp.create_dataset( + sensor_type, + data=source_ds[frame_idx_start:frame_idx_end, :, :], + dtype="float32", + ) + def _draw_pose_2d_and_3d( ax_pose2d: plt.Axes, @@ -518,11 +524,11 @@ def _draw_pose_2d_and_3d( # Legs keypoint_pos_cam_ds = h5_file["postprocessed/keypoint_pos/camera_coords"] keypoints = keypoint_pos_cam_ds.attrs["keys"].tolist() - for leg in legs: - color = kchain_plotting_colors[leg] + for leg in constants.legs: + color = constants.kchain_plotting_colors[leg] all_positions = [] - for kpt in leg_keypoints_canonical: - segment_name = keypoint_name_lookup_canonical_to_nmf[kpt] + for kpt in constants.leg_keypoints_canonical: + segment_name = constants.keypoint_name_lookup_canonical_to_nmf[kpt] keypoint_idx = keypoints.index(f"{leg}{segment_name}") pos = keypoint_pos_cam_ds[frame_index, keypoint_idx, :] all_positions.append(pos) @@ -535,17 +541,17 @@ def _draw_pose_2d_and_3d( segment_name = f"{side}Pedicel" keypoint_idx = keypoints.index(segment_name) pos = keypoint_pos_cam_ds[frame_index, keypoint_idx, :] - color = kchain_plotting_colors[f"{side}Antenna"] + color = constants.kchain_plotting_colors[f"{side}Antenna"] ax_pose2d.plot(pos[0], pos[1], marker="o", color=color, markersize=5) # Plot 3D keypoints keypoint_pos_world_ds = h5_file["postprocessed/keypoint_pos/world_coords"] # Legs - for leg in legs: - color = kchain_plotting_colors[leg] + for leg in constants.legs: + color = constants.kchain_plotting_colors[leg] all_positions = [] - for kpt in leg_keypoints_canonical: - segment_name = keypoint_name_lookup_canonical_to_nmf[kpt] + for kpt in constants.leg_keypoints_canonical: + segment_name = constants.keypoint_name_lookup_canonical_to_nmf[kpt] keypoint_idx = keypoints.index(f"{leg}{segment_name}") pos = keypoint_pos_world_ds[frame_index, keypoint_idx, :] all_positions.append(pos) @@ -563,7 +569,7 @@ def _draw_pose_2d_and_3d( segment_name = f"{side}Pedicel" keypoint_idx = keypoints.index(segment_name) pos = keypoint_pos_world_ds[frame_index, keypoint_idx, :] - color = kchain_plotting_colors[f"{side}Antenna"] + color = constants.kchain_plotting_colors[f"{side}Antenna"] ax_pose3d.plot( pos[0], pos[1], pos[2], marker="o", color=color, markersize=5 ) diff --git a/src/poseforge/neuromechfly/scripts/run_simulation.py b/src/poseforge/neuromechfly/scripts/run_simulation.py index 8af55bf..a9e33d3 100644 --- a/src/poseforge/neuromechfly/scripts/run_simulation.py +++ b/src/poseforge/neuromechfly/scripts/run_simulation.py @@ -53,7 +53,7 @@ from pathlib import Path from poseforge.neuromechfly.data import load_kinematic_recording -from poseforge.neuromechfly.simulate import simulate_one_segment +# from poseforge.neuromechfly.simulate import simulate_one_segment # TODO: revert from poseforge.neuromechfly.postprocessing import postprocess_segment from poseforge.util import get_hardware_availability @@ -134,19 +134,20 @@ def simulate_using_kinematic_prior( print(f"=== Simulating segment #{segment_id} ({num_segments} total) ===") segment = kinematic_recording_segments[segment_id] output_subdir = trial_output_dir / f"segment_{segment_id:03d}" - is_success = simulate_one_segment( - kinematic_recording_segment=segment, - output_dir=output_subdir, - input_timestep=input_timestep, - sim_timestep=sim_timestep, - output_data_freq=output_data_freq, - render_play_speed=render_play_speed, - min_sim_duration_sec=0.2, - max_sim_steps=max_sim_steps_per_segment, - ) + # is_success = simulate_one_segment( # TODO: revert + # kinematic_recording_segment=segment, + # output_dir=output_subdir, + # input_timestep=input_timestep, + # sim_timestep=sim_timestep, + # output_data_freq=output_data_freq, + # render_play_speed=render_play_speed, + # min_sim_duration_sec=0.2, + # max_sim_steps=max_sim_steps_per_segment, + # ) + is_success = True if is_success: postprocess_segment( - output_subdir, visualize=True, min_subsegment_duration_sec=0.1 + output_subdir, visualize=False, min_subsegment_duration_sec=0.1 # TODO: enable visualization ) print(f"### Done processing trial: {trial_name} ###") @@ -154,19 +155,22 @@ def simulate_using_kinematic_prior( def run_sequentially_for_testing(): """Run everything sequentially (for debugging)""" # Configs - output_basedir = Path("bulk_data/nmf_rendering_test/") + output_basedir = Path("bulk_data/nmf_rendering_new/") # TODO: change back to *_test input_timestep = 0.01 sim_timestep = 0.0001 - trial_paths = [ - # For testing: change this list to limit the scope - Path("bulk_data/kinematic_prior/aymanns2022/trials/BO_Gal4_fly1_trial001.pkl") - ] + # trial_paths = [ + # # For testing: change this list to limit the scope + # Path("bulk_data/kinematic_prior/aymanns2022/trials/BO_Gal4_fly1_trial001.pkl") + # ] + trial_paths = sorted( # TODO: revert + Path("bulk_data/kinematic_prior/aymanns2022/trials/").glob("*.pkl") + ) # Limit scope of simulation as this is only for testing # Don't make `max_sim_steps_per_segment` too small; otherwise no subsegment-level # postprocessing will be performed - max_segments_per_trial = 2 - max_sim_steps_per_segment = 3000 + max_segments_per_trial = None # 2 # TODO: revert + max_sim_steps_per_segment = None # 3000 # TODO: revert # Process each trial for trial_path in trial_paths: @@ -187,7 +191,7 @@ def run_sequentially_for_testing(): get_hardware_availability(check_gpu=False, print_results=True) # Run the CLI - tyro.cli(simulate_using_kinematic_prior) + # tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI - # # Run everything sequentially (for debugging) - # run_sequentially_for_testing() + # # Run everything sequentially (for debugging) # TODO: disable testing + run_sequentially_for_testing() diff --git a/src/poseforge/neuromechfly/simulate.py b/src/poseforge/neuromechfly/simulate.py index 75642e0..144b631 100644 --- a/src/poseforge/neuromechfly/simulate.py +++ b/src/poseforge/neuromechfly/simulate.py @@ -13,39 +13,12 @@ from flygym.preprogrammed import all_leg_dofs from poseforge.neuromechfly.data import interpolate_trajectories -from poseforge.neuromechfly.constants import parse_nmf_joint_name - - -# Define color combo by body segment -color_by_link = { - "Coxa": "cyan", - "Femur": "yellow", - "Tibia": "blue", - "Tarsus": "green", - "Antenna": "magenta", - "Thorax": "gray", -} -color_by_kinematic_chain = { - "LF": "red", # left front leg - "LM": "green", # left mid leg - "LH": "blue", # left hind leg - "RF": "cyan", # right front leg - "RM": "magenta", # right mid leg - "RH": "yellow", # right hind leg - "L": "red", # left antenna - "R": "green", # right antenna - "Thorax": "white", # thorax -} -color_palette = { - "red": (1, 0, 0, 1), - "green": (0, 1, 0, 1), - "blue": (0, 0, 1, 1), - "yellow": (1, 1, 0, 1), - "magenta": (1, 0, 1, 1), - "cyan": (0, 1, 1, 1), - "gray": (0.4, 0.4, 0.4, 1), - "white": (1, 1, 1, 1), -} +from poseforge.neuromechfly.constants import ( + parse_nmf_joint_name, + color_by_link, + color_by_kinematic_chain, + color_palette, +) class SpotlightArena(FlatTerrain): From 3ff696cd79c956b4ab95e38cd4e90515dc6de198 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Fri, 28 Nov 2025 16:40:07 +0100 Subject: [PATCH 04/33] [fix] fix handling of failed simulations --- src/poseforge/neuromechfly/scripts/run_simulation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/poseforge/neuromechfly/scripts/run_simulation.py b/src/poseforge/neuromechfly/scripts/run_simulation.py index a9e33d3..c875731 100644 --- a/src/poseforge/neuromechfly/scripts/run_simulation.py +++ b/src/poseforge/neuromechfly/scripts/run_simulation.py @@ -144,7 +144,7 @@ def simulate_using_kinematic_prior( # min_sim_duration_sec=0.2, # max_sim_steps=max_sim_steps_per_segment, # ) - is_success = True + is_success = output_subdir.exists() and len(list(output_subdir.iterdir())) > 0 if is_success: postprocess_segment( output_subdir, visualize=False, min_subsegment_duration_sec=0.1 # TODO: enable visualization @@ -191,7 +191,7 @@ def run_sequentially_for_testing(): get_hardware_availability(check_gpu=False, print_results=True) # Run the CLI - # tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI + tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI - # # Run everything sequentially (for debugging) # TODO: disable testing - run_sequentially_for_testing() + # Run everything sequentially (for debugging) # TODO: disable testing + # run_sequentially_for_testing() From 40acfca693921967babfc4e439fbfa2e600edcf1 Mon Sep 17 00:00:00 2001 From: Sibo Wang-Chen Date: Sat, 29 Nov 2025 13:03:30 +0100 Subject: [PATCH 05/33] [run] retrain feature extractor with low LR and with fewer variants per batch (#45) --- scripts_on_cluster/bodyseg_training/job.run | 55 ++++++++------- .../{job.run => 2variants.run} | 8 +-- .../low_lr.run | 70 +++++++++++++++++++ .../keypoints3d_training/job.run | 53 +++++++------- 4 files changed, 131 insertions(+), 55 deletions(-) rename scripts_on_cluster/contrastive_pretraining_training/{job.run => 2variants.run} (94%) create mode 100644 scripts_on_cluster/contrastive_pretraining_training/low_lr.run diff --git a/scripts_on_cluster/bodyseg_training/job.run b/scripts_on_cluster/bodyseg_training/job.run index 73e663b..7a1154c 100644 --- a/scripts_on_cluster/bodyseg_training/job.run +++ b/scripts_on_cluster/bodyseg_training/job.run @@ -5,11 +5,11 @@ #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 #SBATCH --mem 92GB -#SBATCH --time 48:00:00 +#SBATCH --time 72:00:00 #SBATCH --partition=h100 #SBATCH --qos=normal #SBATCH --gres=gpu:1 -#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/bodyseg_training/output_20251118a.log +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/bodyseg_training/output_20251127a.log echo "Hello from $(hostname)" @@ -19,7 +19,10 @@ conda activate poseforge cd $HOME/poseforge training_cli_path="src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py" -training_trial_name="trial_20251118a" +training_trial_name="trial_20251127a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" echo "Training starting at $(date)" @@ -30,32 +33,32 @@ python -u $training_cli_path \ --model-architecture-config.final-upsampler-n-hidden-channels 32 \ --model-architecture-config.confidence-method entropy \ --model-weights-config.feature-extractor-weights \ - "bulk_data/pose_estimation/contrastive_pretraining/trial_20251117a/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth" \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ --loss-config.weight-dice 1.0 \ --loss-config.weight-ce 1.0 \ --training-data-config.train-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ --training-data-config.val-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ --training-data-config.input-image-size 256 256 \ --training-data-config.atomic-batch-n-samples 32 \ --training-data-config.atomic-batch-n-variants 4 \ @@ -67,7 +70,7 @@ python -u $training_cli_path \ --optimizer-config.learning-rate-segmentation-head 3e-4 \ --optimizer-config.weight-decay 1e-5 \ --training-artifacts-config.output-basedir \ - "bulk_data/pose_estimation/bodyseg/trial_20251118a/" \ + "bulk_data/pose_estimation/bodyseg/$training_trial_name/" \ --training-artifacts-config.logging-interval 10 \ --training-artifacts-config.checkpoint-interval 1000 \ --training-artifacts-config.validation-interval 1000 \ diff --git a/scripts_on_cluster/contrastive_pretraining_training/job.run b/scripts_on_cluster/contrastive_pretraining_training/2variants.run similarity index 94% rename from scripts_on_cluster/contrastive_pretraining_training/job.run rename to scripts_on_cluster/contrastive_pretraining_training/2variants.run index befbd3c..8ba3069 100644 --- a/scripts_on_cluster/contrastive_pretraining_training/job.run +++ b/scripts_on_cluster/contrastive_pretraining_training/2variants.run @@ -5,11 +5,11 @@ #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 #SBATCH --mem 90GB -#SBATCH --time 48:00:00 +#SBATCH --time 72:00:00 #SBATCH --partition=h100 #SBATCH --qos=normal #SBATCH --gres=gpu:1 -#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251117a.log +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251125b_2variants.log echo "Hello from $(hostname)" @@ -20,7 +20,7 @@ conda activate poseforge cd $HOME/poseforge echo "Training starting at $(date)" -trial_name="trial_20251117a" +trial_name="trial_20251125b_2variants" python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \ --n-epochs 10 \ @@ -53,7 +53,7 @@ python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \ --training-data-config.val-data-dirs \ "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ --training-data-config.atomic-batch-n-samples 32 \ - --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.atomic-batch-n-variants 2 \ --training-data-config.train-batch-size 960 \ --training-data-config.val-batch-size 256 \ --training-data-config.image-size 256 256 \ diff --git a/scripts_on_cluster/contrastive_pretraining_training/low_lr.run b/scripts_on_cluster/contrastive_pretraining_training/low_lr.run new file mode 100644 index 0000000..c519def --- /dev/null +++ b/scripts_on_cluster/contrastive_pretraining_training/low_lr.run @@ -0,0 +1,70 @@ +#!/bin/bash -l + +#SBATCH --job-name contr_pretrain_lr3e-5 +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 16 +#SBATCH --mem 90GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/contrastive_pretraining_training/output_20251125a_lowlr.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge + +cd $HOME/poseforge +echo "Training starting at $(date)" + +trial_name="trial_20251125a_lowlr" + +python -u src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py \ + --n-epochs 10 \ + --seed 42 \ + --model-architecture-config.projection-head-hidden-dim 512 \ + --model-architecture-config.projection-head-output-dim 256 \ + --model-weights-config.feature-extractor-weights "IMAGENET1K_V1" \ + --loss-config.info-nce-temperature 0.1 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches_4variants/BO_Gal4_fly1_trial001" \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 960 \ + --training-data-config.val-batch-size 256 \ + --training-data-config.image-size 256 256 \ + --training-data-config.n-workers 4 \ + --optimizer-config.adam-lr 3e-5 \ + --optimizer-config.adam-weight-decay 1e-4 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/contrastive_pretraining/$trial_name" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 500 \ + --training-artifacts-config.validation-interval 200 \ + --training-artifacts-config.n-batches-per-validation 100 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/keypoints3d_training/job.run b/scripts_on_cluster/keypoints3d_training/job.run index 678ebcb..5e623d4 100644 --- a/scripts_on_cluster/keypoints3d_training/job.run +++ b/scripts_on_cluster/keypoints3d_training/job.run @@ -1,6 +1,6 @@ #!/bin/bash -l -#SBATCH --job-name keypoints3d-20251118a +#SBATCH --job-name keypoints3d #SBATCH --nodes 1 #SBATCH --ntasks 1 #SBATCH --cpus-per-task 16 @@ -9,7 +9,7 @@ #SBATCH --partition=h100 #SBATCH --qos=normal #SBATCH --gres=gpu:1 -#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/keypoints3d_training/output_20251118a.log +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/keypoints3d_training/output_20251127a.log echo "Hello from $(hostname)" @@ -19,37 +19,40 @@ conda activate poseforge cd $HOME/poseforge training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py" -training_trial_name="trial_20251118a" +training_trial_name="trial_20251127a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" echo "Training starting at $(date)" python -u $training_cli_path \ --n-epochs 30 \ --model-weights-config.feature-extractor-weights \ - "bulk_data/pose_estimation/contrastive_pretraining/trial_20251117a/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth" \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ --training-data-config.train-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly2_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly3_trial005" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial001" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial002" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial003" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial004" \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly4_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ --training-data-config.val-data-dirs \ - "bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly5_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ --training-data-config.input-image-size 256 256 \ --training-data-config.atomic-batch-n-samples 32 \ --training-data-config.atomic-batch-n-variants 4 \ From 76fb3f066462821117363405004eed3e5ec41656 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 16:42:22 +0100 Subject: [PATCH 06/33] [feature] save predicted probabilities in bodyseg pipeline, also remove IO optimization to reduce RAM usage --- .../bodyseg/scripts/run_bodyseg_inference.py | 179 +++++++++--------- .../scripts/visualize_bodyseg_predictions.py | 10 +- 2 files changed, 90 insertions(+), 99 deletions(-) diff --git a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py b/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py index 0e5ca59..71f9b4e 100644 --- a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py +++ b/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py @@ -1,5 +1,6 @@ import torch import h5py +from loguru import logger from pathlib import Path from torchvision.transforms import Resize from torchsummary import summary @@ -9,18 +10,18 @@ import poseforge.pose.bodyseg.config as config from poseforge.pose.bodyseg import BodySegmentationModel, BodySegmentationPipeline from poseforge.util.sys import get_hardware_availability -from poseforge.util.data import OutputBuffer def test_bodyseg_model( - input_basedir: Path, + *, + input_dir: Path, model_dir: Path, model_checkpoint_path: Path, + save_prob: bool = False, output_basedir: Path | None = None, batch_size: int = 512, n_workers: int = 16, inference_image_size: tuple[int, int] = (256, 256), - output_buffer_log_interval: int = 10, ): # System setup hardware_avail = get_hardware_availability(check_gpu=True, print_results=True) @@ -28,19 +29,21 @@ def test_bodyseg_model( raise RuntimeError("No GPU available for testing") torch.backends.cudnn.benchmark = True - # Find all trials to process - input_trials = list(input_basedir.glob("*/model_prediction/not_flipped/")) - print(f"Found {len(list(input_trials))} trials to process") - input_trials = [trial for trial in input_trials if len(list(trial.iterdir())) > 0] - print(f"{len(input_trials)} trials have images to process") - # Create dataset and dataloader + input_img_dir = input_dir / "model_prediction/not_flipped/" + if len(list(input_img_dir.iterdir())) == 0: + logger.warning(f"Trial {input_img_dir} is empty - skipping") + return transform = Resize(inference_image_size) dataloader = SimpleVideoCollectionLoader( - input_trials, transform=transform, batch_size=batch_size, num_workers=n_workers + [input_img_dir], + transform=transform, + batch_size=batch_size, + num_workers=n_workers, ) - print(f"Found {len(dataloader.dataset)} frames to process") - print( + n_frames = len(dataloader.dataset) + logger.info(f"Found {n_frames} frames to process") + logger.info( f"Using batch size {dataloader.batch_size} with {dataloader.num_workers} " f"workers. This will generate {len(dataloader)} batches." ) @@ -55,105 +58,93 @@ def test_bodyseg_model( summary(model, (3, *inference_image_size)) pipeline = BodySegmentationPipeline(model, device="cuda", use_float16=True) - # Make an output buffer - output data for multiple videos will arrive out of sync - def save_predictions(input_video_idx, data_items): - video_obj = dataloader.dataset.videos[input_video_idx] - input_video_path = video_obj.path - exp_trial_name = "_".join(input_video_path.parts[-3:]) - out_dir = output_basedir / exp_trial_name - out_dir.mkdir(parents=True, exist_ok=True) - with h5py.File(out_dir / f"bodyseg_pred.h5", "w") as f: - pred_segmaps = torch.stack([x[0] for x in data_items], dim=0).cpu().numpy() - ds = f.create_dataset( - "pred_segmap", - data=pred_segmaps, - dtype="uint8", - compression="gzip", - shuffle=True, - ) - ds.attrs["class_labels"] = pipeline.class_labels - confs = torch.stack([x[1] for x in data_items], dim=0).cpu().numpy() - ds = f.create_dataset( - "pred_confidence", - data=confs, - dtype="uint8", - compression="gzip", - shuffle=True, - ) - # Confidence is predicted in 0-1, but we store it in 0-100 as uint8 - ds.attrs["scale"] = 100 - ds.attrs["method"] = model.confidence_method - frame_ids = [ - int(p.stem.split("_")[1]) - for p in video_obj.phy_frame_id_to_path.values() - ] - # These are the actual, raw frame IDs from the original video assigned by - # the Spotlight recording software. They may not be contiguous because - # frames where the fly is upside down or too close to the edge, etc. are - # already removed. - f.create_dataset( - "frame_ids", - data=frame_ids, - dtype="int", - compression="gzip", - shuffle=True, - ) - - buckets_and_sizes = { - i: n_frames for i, n_frames in enumerate(dataloader.dataset.n_frames_by_video) - } - output_buffer = OutputBuffer( - buckets_and_expected_sizes=buckets_and_sizes, - closing_func=save_predictions, + # Set up output H5 files + exp_trial_name = input_img_dir.parts[-3] + out_dir = output_basedir / exp_trial_name + out_dir.mkdir(parents=True, exist_ok=True) + f_pred = h5py.File(out_dir / f"bodyseg_pred.h5", "w") + f_prob = h5py.File(out_dir / f"bodyseg_prob.h5", "w") if save_prob else None + ds_segmap = f_pred.create_dataset( + "pred_segmap", + shape=(n_frames, inference_image_size[0], inference_image_size[1]), + dtype="uint8", + compression="gzip", + shuffle=True, ) + ds_segmap.attrs["class_labels"] = pipeline.class_labels + ds_conf = f_pred.create_dataset( + "pred_confidence", + shape=(n_frames, inference_image_size[0], inference_image_size[1]), + dtype="uint8", + compression="gzip", + shuffle=True, + ) + ds_conf.attrs["scale"] = 100 # transform from 0-1 to 0-100 for uint8 storage + ds_conf.attrs["method"] = model.confidence_method + if save_prob: + ds_probs = f_prob.create_dataset( + "pred_probabilities", + shape=( + n_frames, + model.n_classes, + inference_image_size[0], + inference_image_size[1], + ), + dtype="uint8", + compression="gzip", + shuffle=True, + ) + ds_probs.attrs["scale"] = 100 # transform from 0-1 to 0-100 for uint8 storage - # Run inference + # Inference loop + log_interval = max(len(dataloader) // 10, 1) for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): - # No need to move data to and from the GPU, pipeline will do that + # Forward pass (no need to move data to and from the GPU; pipeline will do that) pred_dict = pipeline.inference(batch["frames"]) logits = pred_dict["logits"] + + # Save outputs + frame_ids = batch["frame_indices"] pred_seg = torch.argmax(logits, dim=1).to(torch.uint8).detach().cpu() - confidence = (pred_dict["confidence"] * 100).to(torch.uint8).detach().cpu() + conf = (pred_dict["confidence"] * 100).to(torch.uint8).detach().cpu() + ds_segmap[frame_ids, :, :] = pred_seg.numpy() + ds_conf[frame_ids, :, :] = conf.numpy() + if save_prob: + prob = (torch.softmax(logits, dim=1) * 100).to(torch.uint8).detach().cpu() + ds_probs[frame_ids, :, :, :] = prob.numpy() - for i in range(logits.shape[0]): - data_item = (pred_seg[i, :, :], confidence[i, :, :]) - output_buffer.add_data( - bucket=batch["video_indices"][i], - index=batch["frame_indices"][i], - data=data_item, - ) - if (batch_idx + 1) % output_buffer_log_interval == 0: - print( - f"{batch_idx + 1}/{len(dataloader)} batches - " - f"{output_buffer.n_open_buckets} partially processed videos, " - f"{output_buffer.n_data_total} total frames in buffer" - ) + if (batch_idx + 1) % log_interval == 0 or (batch_idx + 1) == len(dataloader): + logger.info(f"Processed batch {batch_idx + 1} / {len(dataloader)}") - assert output_buffer.n_data_total == 0 - assert output_buffer.n_open_buckets == 0 - print("Inference complete") + logger.info("Inference complete") if __name__ == "__main__": + from poseforge.util.sys import set_loguru_level + + set_loguru_level("INFO") + input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") - model_dir = Path("bulk_data/pose_estimation/bodyseg/trial_20251118a") + model_dir = Path("bulk_data/pose_estimation/bodyseg/trial_20251127a") batch_size = 192 n_workers = 16 inference_image_size = (256, 256) output_buffer_log_interval = 10 - epoch = 14 # chosen by validation performance and visual inspection - step = 12000 # last step of each epoch + epoch = 8 # chosen by validation performance and visual inspection + step = 18335 # last step of each epoch model_checkpoint_path = model_dir / f"checkpoints/epoch{epoch}_step{step}.model.pth" output_basedir = model_dir / f"production/epoch{epoch}_step{step}/" output_basedir.mkdir(parents=True, exist_ok=True) - test_bodyseg_model( - input_basedir=input_basedir, - model_dir=model_dir, - model_checkpoint_path=model_checkpoint_path, - output_basedir=output_basedir, - batch_size=batch_size, - n_workers=n_workers, - inference_image_size=inference_image_size, - output_buffer_log_interval=output_buffer_log_interval, - ) + for input_dir in sorted(input_basedir.iterdir()): + print(f"Processing {input_dir}") + test_bodyseg_model( + input_dir=input_dir, + model_dir=model_dir, + model_checkpoint_path=model_checkpoint_path, + output_basedir=output_basedir, + batch_size=batch_size, + n_workers=n_workers, + inference_image_size=inference_image_size, + save_prob=True, + ) diff --git a/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py b/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py index 02b1cb7..8c14d7a 100644 --- a/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py +++ b/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py @@ -223,13 +223,13 @@ def visualize_bodyseg_prediction( label_alpha = 0.3 n_workers = -1 - epoch = 14 # chosen by validation performance and visual inspection - step = 12000 # last step of each epoch + epoch = 8 # chosen by validation performance and visual inspection + step = 18335 # last step of each epoch pred_basedir = Path( - f"bulk_data/pose_estimation/bodyseg/trial_20251118a/production/epoch{epoch}_step{step}/" + f"bulk_data/pose_estimation/bodyseg/trial_20251127a/production/epoch{epoch}_step{step}/" ) - pred_path = pred_basedir / f"{trial}_model_prediction_not_flipped/bodyseg_pred.h5" - output_path = pred_basedir / f"{trial}_model_prediction_not_flipped/viz.mp4" + pred_path = pred_basedir / f"{trial}/bodyseg_pred.h5" + output_path = pred_basedir / f"{trial}/viz.mp4" visualize_bodyseg_prediction( recording_dir=recording_dir, From 10a16c716132b73b4be0ad172a966846f32ab8db Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 16:43:29 +0100 Subject: [PATCH 07/33] [feature] add mesh state in atomic batch extraction pipeline --- src/poseforge/neuromechfly/constants.py | 5 +++++ .../pose/data/synthetic/sim_data_seq.py | 22 +++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/poseforge/neuromechfly/constants.py b/src/poseforge/neuromechfly/constants.py index dfccbfc..755b626 100644 --- a/src/poseforge/neuromechfly/constants.py +++ b/src/poseforge/neuromechfly/constants.py @@ -79,6 +79,11 @@ "RAntenna": np.array([50, 120, 32]) / 255, } +segments_for_6dpose_per_leg = ["Coxa", "Femur", "Tibia", "Tarsus1"] +segments_for_6dpose = [ + f"{leg}{seg}" for leg in legs for seg in segments_for_6dpose_per_leg +] + ["Thorax"] + ########################################################################### ## COLORS FOR BODY SEGMENT RENDERING BELOW ## diff --git a/src/poseforge/pose/data/synthetic/sim_data_seq.py b/src/poseforge/pose/data/synthetic/sim_data_seq.py index 837cc90..6d583a1 100644 --- a/src/poseforge/pose/data/synthetic/sim_data_seq.py +++ b/src/poseforge/pose/data/synthetic/sim_data_seq.py @@ -6,6 +6,8 @@ from typing import Iterator from pvio.io import get_video_metadata, read_frames_from_video +from poseforge.neuromechfly.constants import segments_for_6dpose + class SimulatedDataSequence: def __init__( @@ -119,7 +121,7 @@ def read_simulated_labels( *, load_dof_angles: bool = True, load_keypoint_pos: bool = True, - load_mesh_states: bool = False, + load_mesh_states: bool = True, load_body_seg_maps: bool = True, ) -> dict[str, np.ndarray]: self._check_frame_indices_validity(frame_indices) @@ -143,10 +145,22 @@ def read_simulated_labels( labels["keypoint_pos"] = keypoint_pos if load_mesh_states: - raise NotImplementedError( - "Mesh states tracking (xyz + quat for 3D rotation) has not been " - "implemented yet" + pose6d_grp = ds["body_segment_states"] + all_avail_segments = pose6d_grp.attrs["keys"] + # segment_indices_lookup = { + # name: i for i, name in enumerate(all_avail_segments) + # } + # segment_indices = [ + # segment_indices_lookup[seg_name] for seg_name in segments_for_6dpose + # ] + seg_mask = np.array( + [name in segments_for_6dpose for name in all_avail_segments] ) + # h5py only supports fancy indexing along one axis + mesh_pos = pose6d_grp["pos_global"][frame_indices, :, :] + labels["mesh_pos"] = mesh_pos[:, seg_mask, :] + mesh_quat = pose6d_grp["quat_global"][frame_indices, :, :] + labels["mesh_quat"] = mesh_quat[:, seg_mask, :] if load_body_seg_maps: seg_labels_ds = ds["segmentation_labels"] From 9e528df42895647d7f94f89f0f8c50389a6a93aa Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 16:44:28 +0100 Subject: [PATCH 08/33] [run] update paths for 20251127a training (based on pretrained feature extractor at lower learning rate) --- .../pose/keypoints3d/scripts/run_inverse_kinematics.py | 5 +++-- .../pose/keypoints3d/scripts/run_keypoints3d_inference.py | 4 ++-- .../keypoints3d/scripts/visualize_production_keypoints3d.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py b/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py index 418557a..65354b1 100644 --- a/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py +++ b/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py @@ -613,6 +613,7 @@ def process_all( max_n_frames=max_n_frames, n_workers=n_workers_per_dataset, debug_plots_dir=keypoints3d_output_file.parent / "ik_debug_plots/", + parallel_over_time=False, ) _save_seqikpy_output( output_path, joint_angles, forward_kinematics, frame_ids=frame_ids @@ -828,10 +829,10 @@ def _save_seqikpy_output( # --input-images-basedir bulk_data/behavior_images/spotlight_aligned_and_cropped/ # * Processing from this script directly - epoch = 19 # these must be consistent with run_keypoints3d_inference.py + epoch = 15 # these must be consistent with run_keypoints3d_inference.py step = 9167 # same as above production_model_basedir = Path( - f"bulk_data/pose_estimation/keypoints3d/trial_20251118a/production/epoch{epoch}_step{step}/" + f"bulk_data/pose_estimation/keypoints3d/trial_20251127a/production/epoch{epoch}_step{step}/" ) input_images_basedir = Path( "bulk_data/behavior_images/spotlight_aligned_and_cropped/" diff --git a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py b/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py index 84d68a4..42a260c 100644 --- a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py +++ b/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py @@ -200,8 +200,8 @@ def run_keypoints3d_inference( if __name__ == "__main__": input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") - model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251118a") - epoch = 19 # chosen based on validation performance and visual inspection + model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251127a") + epoch = 15 # chosen based on validation performance and visual inspection step = 9167 # last step print(f"Running inference for epoch {epoch}") diff --git a/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py b/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py index 9fdda06..9421c47 100644 --- a/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py +++ b/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py @@ -507,9 +507,9 @@ def visualize_predictions( if __name__ == "__main__": input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") - model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251118a") + model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251127a") recordings = ["20250613-fly1b-002"] - epoch = 19 # these must be consistent with run_keypoints3d_inference.py + epoch = 15 # these must be consistent with run_keypoints3d_inference.py step = 9167 # same as above for recording in recordings: From ab25d33c6a08aa4c54dcfcd7c0398218a3b129eb Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 16:45:13 +0100 Subject: [PATCH 09/33] [wip] start pose6d model implementation --- src/poseforge/pose/pose6d/___init__.py | 0 src/poseforge/pose/pose6d/config.py | 83 +++++++++ src/poseforge/pose/pose6d/model.py | 243 +++++++++++++++++++++++++ src/poseforge/pose/pose6d/pipeline.py | 221 ++++++++++++++++++++++ 4 files changed, 547 insertions(+) create mode 100644 src/poseforge/pose/pose6d/___init__.py create mode 100644 src/poseforge/pose/pose6d/config.py create mode 100644 src/poseforge/pose/pose6d/model.py create mode 100644 src/poseforge/pose/pose6d/pipeline.py diff --git a/src/poseforge/pose/pose6d/___init__.py b/src/poseforge/pose/pose6d/___init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/poseforge/pose/pose6d/config.py b/src/poseforge/pose/pose6d/config.py new file mode 100644 index 0000000..a761cbb --- /dev/null +++ b/src/poseforge/pose/pose6d/config.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass + +from poseforge.util import SerializableDataClass + + +@dataclass(frozen=True) +class ModelArchitectureConfig(SerializableDataClass): + # Number of object segments. + # Default: 6 legs * (coxa, femur, tibia, tarsus1) + thorax = 25 + n_segments: int = 25 + # Feature extractor configuration + feature_extractor: dict = None + # Number of hidden channels in the final upsampling layer before pose regression + final_upsampler_n_hidden_channels: int = 64 + # Hidden sizes of intermediate layers of the MLP mesh heads + # (comma-separated list of ints) + pose6d_head_hidden_sizes: str = "512,256" + + +@dataclass(frozen=True) +class ModelWeightsConfig(SerializableDataClass): + # Feature extractor weights. Can be a path to the (contrastively) pretrained weights + # or "IMAGENET1K_V1" + feature_extractor_weights: str | None = None + # Model weights, optional. If provided, the model will be initialized from these + # weights (in which case feature_extractor_weights is ignored). + model_weights: str | None = None + + +@dataclass(frozen=True) +class LossConfig(SerializableDataClass): + # Weight for the translation loss term + translation_weight: float = 1.0 + # Weight for the rotation loss term + rotation_weight: float = 1.0 + + +@dataclass(frozen=True) +class TrainingDataConfig(SerializableDataClass): + # Paths to training data (recursively containing atomic batches) + train_data_dirs: list[str] + # Paths to validation data (recursively containing atomic batches) + val_data_dirs: list[str] + # Frame size (height, width) + input_image_size: tuple[int, int] + # Numbers of samples (frames) in each pre-extracted atomic batch + atomic_batch_n_samples: int + # Number of variants (synthetic images made by different style transfer models) + atomic_batch_n_variants: int + # Number of different frames to include in each batch. Note that n_variants variants + # of each frame will be included, so effective batch size = + # train_batch_size * n_variants. + # This must be a multiple of `atomic_batch_n_samples` in `AtomicBatchDataset`. + train_batch_size: int + # Validation batch size. Can be much smaller than train_batch_size. Must be + # a multiple of `atomic_batch_n_samples` in `AtomicBatchDataset` + val_batch_size: int + # Number of workers for data loading. Use number of CPU cores if None. + n_workers: int | None = None + # Optional kernel size for dilating bodyseg masks + mask_dilation_kernel: int | None = None + + +@dataclass(frozen=True) +class OptimizerConfig(SerializableDataClass): + learning_rate_encoder: float = 3e-5 + learning_rate_deconv: float = 3e-4 + learning_rate_pose6d_heads: float = 3e-4 + weight_decay: float = 1e-5 + + +@dataclass(frozen=True) +class TrainingArtifactsConfig(SerializableDataClass): + # Base directory to save logs and model checkpoints + output_basedir: str + # Log training metrics every N steps + logging_interval: int = 10 + # Save model checkpoint every N steps (NOT EPOCHS!) + checkpoint_interval: int = 500 + # Run validation every N steps (NOT EPOCHS!) + validation_interval: int = 500 + # Number of batches to use for each validation (useful if validation set is large) + n_batches_per_validation: int = 300 diff --git a/src/poseforge/pose/pose6d/model.py b/src/poseforge/pose/pose6d/model.py new file mode 100644 index 0000000..75307ea --- /dev/null +++ b/src/poseforge/pose/pose6d/model.py @@ -0,0 +1,243 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger +from pathlib import Path + +import poseforge.pose.pose6d.config as config +from poseforge.pose.common import ResNetFeatureExtractor, DecoderBlock + + +class Pose6DModel(nn.Module): + def __init__( + self, + n_segments: int, + feature_extractor: ResNetFeatureExtractor, + final_upsampler_n_hidden_channels: int, + pose6d_head_hidden_sizes: list[int], + ): + super(Pose6DModel, self).__init__() + self.n_segments = n_segments + self.feature_extractor = feature_extractor + self.final_upsampler_n_hidden_channels = final_upsampler_n_hidden_channels + self.pose6d_head_hidden_sizes = pose6d_head_hidden_sizes + + # Create decoder (decoder1/2/3/4 mirror encoder layers 1/2/3/4) + # Note that when upsampling, we actually run decoder4 first, decoder1 last + # Also note that the number of channels along the decoding path is higher than + # in keypoints3d and bodyseg models. This because here we will eventually pool + # the features globally for FC pose6d heads, so we want more channels to retain + # information. + self.dec_layer4 = DecoderBlock(512, 256, 512) + self.dec_layer3 = DecoderBlock(512, 128, 256) + self.dec_layer2 = DecoderBlock(256, 64, 256) + self.dec_layer1 = DecoderBlock(256, 64, final_upsampler_n_hidden_channels) + + # 6D pose prediction head (one per segment) + self.pose6d_heads = nn.ModuleList( + [self._make_pose6d_head() for _ in range(n_segments)] + ) + + @classmethod + def from_config( + cls, architecture_config: config.ModelArchitectureConfig | Path | str + ) -> "Pose6DModel": + # Load from file if config is a path + if isinstance(architecture_config, (str, Path)): + architecture_config = config.ModelArchitectureConfig.load( + architecture_config + ) + logger.info(f"Loaded model architecture config from {architecture_config}") + + # Initialize feature extractor (WITHOUT WEIGHTS at this step!) + feature_extractor = ResNetFeatureExtractor() + + # Initialize Pose6DModel (self) from config (WITHOUT WEIGHTS at this step!) + try: + # Parse pose6d_head_hidden_sizes from string + # (don't specify directly as list[int] in YAML - mutability issues) + pose6d_head_hidden_sizes = [ + int(x.strip()) + for x in architecture_config.pose6d_head_hidden_sizes.split(",") + ] + except ValueError as e: + logger.critical( + f"Invalid pose6d_head_hidden_sizes in ModelArchitectureConfig: {e}. " + f"Expected a comma-separated list of integers." + ) + raise e + obj = cls( + n_segments=architecture_config.n_segments, + feature_extractor=feature_extractor, + final_upsampler_n_hidden_channels=architecture_config.final_upsampler_n_hidden_channels, + pose6d_head_hidden_sizes=pose6d_head_hidden_sizes, + ) + + logger.info("Initialized Pose6DModel from architecture config") + return obj + + def load_weights_from_config( + self, weights_config: config.ModelWeightsConfig | Path | str + ): + # Load from file if config is given as a path + if isinstance(weights_config, (str, Path)): + weights_config = config.ModelWeightsConfig.load(weights_config) + logger.info(f"Loaded model weights config from {weights_config}") + + # Check if config has either feature extractor weights or model weights + if ( + weights_config.feature_extractor_weights is None + and weights_config.model_weights is None + ): + logger.warning("weights_config contains nothing useful. No action taken") + return + + # If full model weights are provided, load them directly + if weights_config.model_weights is not None: + checkpoint_path = Path(weights_config.model_weights) + if not checkpoint_path.exists(): + logger.critical(f"Model weights path {checkpoint_path} does not exist") + raise FileNotFoundError(f"Model weights file does not exist") + weights = torch.load(checkpoint_path, map_location="cpu") + self.load_state_dict(weights) + logger.info( + f"Loaded Pose6DModel weights (inc. feature extractor) " + f"from {checkpoint_path}" + ) + return + + # Otherwise, init feature extractor weights if provided + self.feature_extractor = ResNetFeatureExtractor( + # Path, str, or "IMAGENET1K_V1" + weights=weights_config.feature_extractor_weights + ) + logger.info( + f"Initialized feature extractor weights from " + f"{weights_config.feature_extractor_weights}" + ) + + def _make_pose6d_head(self) -> nn.Module: + all_layers = [] + n_channels_in = self.final_upsampler_n_hidden_channels + for hidden_size in self.pose6d_head_hidden_sizes: + layers_within_block = [ + nn.Linear(n_channels_in, hidden_size), + nn.ReLU(inplace=True), + nn.Dropout(p=0.3), + ] + all_layers.extend(layers_within_block) + n_channels_in = hidden_size + # 3 translation + 4 rotation (quaternion) + all_layers.append(nn.Linear(n_channels_in, 7)) + return nn.Sequential(*all_layers) + + def forward(self, input_img: torch.Tensor, bodyseg_probs: torch.Tensor) -> dict: + # Extract features with ResNet backbone and upsample to 128x128 + features = self.feature_extractor(input_img) + x = self.dec_layer4(features) + x = self.dec_layer3(x) + x = self.dec_layer2(x) + x = self.dec_layer1(x) + # Now x has shape (B, final_upsampler_n_hidden_channels, 128, 128) + + # Process each segment separately + translation_pred_list = [] + quaternion_pred_list = [] + + for seg_idx in range(self.n_segments): + # Confidence-weighted global average pooling + mask_probs = bodyseg_probs[:, seg_idx, :, :].unsqueeze(1) # (B, 1, H, W) + weighted_features = x * mask_probs # (B, C, H, W) + feature_sums = weighted_features.sum(dim=(2, 3)) # (B, C) + confidence_sums = mask_probs.sum(dim=(2, 3)) + 1e-6 # (B, 1) + pooled_features = feature_sums / confidence_sums.clamp_min(1e-6) # (B, C) + + # Predict 6D pose + pose_pred = self.pose6d_heads[seg_idx](pooled_features) # (B, 7) + translation_pred = pose_pred[:, 0:3] # (B, 3) + quaternion_pred = pose_pred[:, 3:7] # (B, 4) + # Normalize quaternion to unit length + quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) + + translation_pred_list.append(translation_pred) + quaternion_pred_list.append(quaternion_pred) + + # Stack predictions into single tensors of shape (B, n_segments, 3 or 4) + translation_pred_all = torch.stack(translation_pred_list, dim=1) + quaternion_pred_all = torch.stack(quaternion_pred_list, dim=1) + + return translation_pred_all, quaternion_pred_all + + +class Pose6DLoss(nn.Module): + def __init__(self, translation_weight: float = 1.0, rotation_weight: float = 1.0): + super(Pose6DLoss, self).__init__() + self.translation_weight = translation_weight + self.rotation_weight = rotation_weight + + def forward( + self, + translation_pred: torch.Tensor, + quaternion_pred: torch.Tensor, + translation_label: torch.Tensor, + quaternion_label: torch.Tensor, + valid_mask: torch.Tensor, + ) -> torch.Tensor: + # If the segment is too small, don't include it in loss computation + translation_pred = translation_pred[valid_mask, ...] + quaternion_pred = quaternion_pred[valid_mask, ...] + translation_label = translation_label[valid_mask, ...] + quaternion_label = quaternion_label[valid_mask, ...] + + # Compute losses + translation_loss = self.translation_loss( + translation_pred.view(-1, 3), translation_label.view(-1, 3) + ) + rotation_loss = self.rotation_loss( + quaternion_pred.view(-1, 4), quaternion_label.view(-1, 4) + ) + return ( + self.translation_weight * translation_loss + + self.rotation_weight * rotation_loss + ) + + @classmethod + def create_from_config( + cls, loss_config: config.LossConfig | Path | str + ) -> "Pose6DLoss": + # Load from file if config is a path + if isinstance(loss_config, (str, Path)): + loss_config = config.LossConfig.load(loss_config) + logger.info(f"Loaded loss config from {loss_config}") + + obj = cls( + translation_weight=loss_config.translation_weight, + rotation_weight=loss_config.rotation_weight, + ) + logger.info("Initialized Pose6DLoss from loss config") + return obj + + @staticmethod + def translation_loss( + translation_pred: torch.Tensor, + translation_label: torch.Tensor, + ) -> torch.Tensor: + """Computes L1 translation loss. Both predictions and labels are should have + shape (N, 3), where N is the number of valid samples.""" + return F.l1_loss(translation_pred, translation_label) + + @staticmethod + def rotation_loss( + quaternion_pred: torch.Tensor, + quaternion_label: torch.Tensor, + ) -> torch.Tensor: + """Computes rotation loss based on quaternion dot product. Both predictions and + labels are should have shape (N, 4), where N is the number of valid samples.""" + # Normalize both quaternions to unit length (note: dim 1 is quat components) + quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) + quaternion_label = F.normalize(quaternion_label, p=2, dim=1) + + # Compute absolute quaternion dot product (note: dim 1 is quat components) + dot_prod = torch.abs(torch.sum(quaternion_pred * quaternion_label, dim=1)) + + return torch.mean(1.0 - dot_prod) diff --git a/src/poseforge/pose/pose6d/pipeline.py b/src/poseforge/pose/pose6d/pipeline.py new file mode 100644 index 0000000..4beac02 --- /dev/null +++ b/src/poseforge/pose/pose6d/pipeline.py @@ -0,0 +1,221 @@ +import torch +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from collections import defaultdict +from time import time +from datetime import datetime +from pathlib import Path +from itertools import chain +from tqdm import tqdm +from loguru import logger + + +import poseforge.pose.pose6d.config as config +from poseforge.pose.pose6d.model import Pose6DModel, Pose6DLoss +from poseforge.pose.data.synthetic import ( + init_atomic_dataset_and_dataloader, + atomic_batches_to_simple_batch, +) +from poseforge.util import ( + set_random_seed, + check_mixed_precision_status, + count_optimizer_parameters, + count_module_parameters, + clear_memory_cache, +) + + +class Pose6DPipeline: + def __init__( + self, + model: Pose6DModel, + loss_func: Pose6DLoss | None = None, + device: torch.device | str = "cuda", + use_float16: bool = True, + ): + self.model = model.to(device) + self.device = device + self.loss_func = loss_func + self.use_float16 = use_float16 + if torch.cuda.is_available() and "cuda" in str(device): + self.device_type = "cuda" + else: + self.device_type = "cpu" + + def train( + self, + n_epochs: int, + data_config: config.TrainingDataConfig, + optimizer_config: config.OptimizerConfig, + artifacts_config: config.TrainingArtifactsConfig, + seed: int = 42, + ): + # Set random seed for reproducibility + set_random_seed(seed) + + # Set up training and validation data + train_ds, train_loader = self._init_training_dataset_and_dataloader(data_config) + val_ds, val_loader = self._init_validation_dataset_and_dataloader(data_config) + + # Set up logging dir and logger + log_dir = Path(artifacts_config.output_basedir) / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + writer = SummaryWriter(log_dir=log_dir) + + # Set up checkpoint dir + checkpoint_dir = Path(artifacts_config.output_basedir) / "checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Set up optimizer + optimizer = self._create_optimizer(optimizer_config) + + # Set up mixed-precision training + grad_scaler = torch.amp.GradScaler(self.device_type, enabled=self.use_float16) + self._check_amp_status_for_model_params( + grad_scaler, subtitle="Model parameters before training" + ) + + # Check if loss function is provided + if self.loss_func is None: + logger.critical("Loss function must be provided for training.") + raise ValueError("Loss function must be provided for training.") + + # Training loop + self.model.train() + for epoch_idx in range(n_epochs): + logger.info( + f"Starting epoch {epoch_idx} out of {n_epochs} at {datetime.now()}..." + ) + running_loss_dict = defaultdict(lambda: 0.0) + epoch_start_time = time() + running_start_time = time() + + for step_idx, atomic_batches in enumerate(train_loader): + # Format data + atomic_batches_frames, atomic_batches_sim_data = atomic_batches + frames, sim_data = atomic_batches_to_simple_batch( + atomic_batches_frames, atomic_batches_sim_data, device=self.device + ) + ... + + def _init_training_dataset_and_dataloader( + self, data_config: config.TrainingDataConfig + ): + return init_atomic_dataset_and_dataloader( + data_dirs=data_config.train_data_dirs, + atomic_batch_n_samples=data_config.atomic_batch_n_samples, + atomic_batch_n_variants=data_config.atomic_batch_n_variants, + input_image_size=data_config.input_image_size, + batch_size=data_config.train_batch_size, + load_dof_angles=False, + load_keypoint_positions=False, + load_body_segment_maps=False, + shuffle=True, + n_workers=data_config.n_workers, + n_channels=3, + pin_memory=True, + drop_last=True, + ) + + def _init_validation_dataset_and_dataloader( + self, data_config: config.TrainingDataConfig + ): + return init_atomic_dataset_and_dataloader( + data_dirs=data_config.val_data_dirs, + atomic_batch_n_samples=data_config.atomic_batch_n_samples, + atomic_batch_n_variants=data_config.atomic_batch_n_variants, + input_image_size=data_config.input_image_size, + batch_size=data_config.val_batch_size, + load_dof_angles=False, + load_keypoint_positions=False, + load_body_segment_maps=False, + shuffle=False, + n_workers=data_config.n_workers, + n_channels=3, + pin_memory=True, + drop_last=True, + ) + + def _create_optimizer(self, optimizer_config: config.OptimizerConfig): + params = [ + { + "params": self.model.feature_extractor.parameters(), + "lr": optimizer_config.learning_rate_encoder, + }, + { + "params": list( + chain( + self.model.dec_layer1.parameters(), + self.model.dec_layer2.parameters(), + self.model.dec_layer3.parameters(), + self.model.dec_layer4.parameters(), + ) + ), + "lr": optimizer_config.learning_rate_deconv, + }, + { + "params": list( + chain( + self.model.pose6d_heads.parameters(), + ) + ), + "lr": optimizer_config.learning_rate_pose6d_heads, + }, + ] + + optimizer = torch.optim.AdamW( + params, weight_decay=optimizer_config.weight_decay + ) + + # Check if all parameters are covered + n_params_optimizer = count_optimizer_parameters(optimizer) + n_params_model = count_module_parameters(self.model) + assert n_params_optimizer == n_params_model, ( + f"Number of parameters in optimizer ({n_params_optimizer}) does not match " + f"number of parameters in model ({n_params_model})." + ) + + return optimizer + + def _check_amp_status_for_model_params( + self, grad_scaler: torch.amp.GradScaler, subtitle: str = "Model parameters" + ): + return check_mixed_precision_status( + self.use_float16, + self.device, + print_results=True, + tensors={ + "feature_extractor_params": self.model.feature_extractor.parameters(), + "decoder_params": chain( + self.model.dec_layer1.parameters(), + self.model.dec_layer2.parameters(), + self.model.dec_layer3.parameters(), + self.model.dec_layer4.parameters(), + ), + "pose6d_heads_params": self.model.pose6d_heads.parameters(), + }, + grad_scaler=grad_scaler, + subtitle=subtitle, + ) + + # def _check_amp_status_during_training( + # self, + # input_images: torch.Tensor, + # target: torch.Tensor, + # pred_dict: torch.Tensor, + # grad_scaler: torch.amp.GradScaler, + # subtitle: str = "Variables during training", + # ): + # return check_mixed_precision_status( + # self.use_float16, + # self.device, + # print_results=True, + # tensors={ + # "input_images": input_images, + # "target": target, + # "pred": pred_dict["logits"], + # "pred_conf": pred_dict["confidence"], + # }, + # grad_scaler=grad_scaler, + # subtitle=subtitle, + # ) From 514975afa0a0f7d9076ebdd1e74763a383cb0f7f Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 16:45:51 +0100 Subject: [PATCH 10/33] [refactor] minor enhancements for keypoints3d and bodyseg models --- src/poseforge/pose/bodyseg/model.py | 3 ++- src/poseforge/pose/bodyseg/pipeline.py | 7 +++---- src/poseforge/pose/keypoints3d/pipeline.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/poseforge/pose/bodyseg/model.py b/src/poseforge/pose/bodyseg/model.py index b8cf7af..e850a5c 100644 --- a/src/poseforge/pose/bodyseg/model.py +++ b/src/poseforge/pose/bodyseg/model.py @@ -103,11 +103,12 @@ def load_weights_from_config( and weights_config.model_weights is None ): logging.warning("weights_config contains nothing useful. No action taken.") + return # If full model weights are provided, load them directly if weights_config.model_weights is not None: checkpoint_path = Path(weights_config.model_weights) - if not checkpoint_path.is_file(): + if not checkpoint_path.exists(): raise ValueError(f"Model weights path {checkpoint_path} is not a file") weights = torch.load(checkpoint_path, map_location="cpu") self.load_state_dict(weights) diff --git a/src/poseforge/pose/bodyseg/pipeline.py b/src/poseforge/pose/bodyseg/pipeline.py index c185595..6934d29 100644 --- a/src/poseforge/pose/bodyseg/pipeline.py +++ b/src/poseforge/pose/bodyseg/pipeline.py @@ -98,7 +98,7 @@ def train( # Set up optimizer optimizer = self._create_optimizer(optimizer_config) - # Set up mixed-point training + # Set up mixed-precision training grad_scaler = torch.amp.GradScaler(self.device_type, enabled=self.use_float16) self._check_amp_status_for_model_params( grad_scaler, subtitle="Model parameters before training" @@ -118,10 +118,9 @@ def train( epoch_start_time = time() running_start_time = time() - for step_idx, (atomic_batches_frames, atomic_batches_sim_data) in enumerate( - train_loader - ): + for step_idx, atomic_batches in enumerate(train_loader): # Format data + atomic_batches_frames, atomic_batches_sim_data = atomic_batches frames, sim_data = atomic_batches_to_simple_batch( atomic_batches_frames, atomic_batches_sim_data, device=self.device ) diff --git a/src/poseforge/pose/keypoints3d/pipeline.py b/src/poseforge/pose/keypoints3d/pipeline.py index 12ef8b9..3826ef7 100644 --- a/src/poseforge/pose/keypoints3d/pipeline.py +++ b/src/poseforge/pose/keypoints3d/pipeline.py @@ -76,7 +76,7 @@ def train( # Set up optimizer optimizer = self._create_optimizer(optimizer_config) - # Set up mixed-point training + # Set up mixed-precision training grad_scaler = torch.amp.GradScaler(self.device_type, enabled=self.use_float16) self._check_amp_status_for_model_params( grad_scaler, subtitle="Model parameters before training" From 3581d289565775f40fd02cf2ed01c6a30992f694 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 16:58:43 +0100 Subject: [PATCH 11/33] [fix] add missing class labels metadata --- src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py b/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py index 71f9b4e..fa593c5 100644 --- a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py +++ b/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py @@ -95,6 +95,7 @@ def test_bodyseg_model( shuffle=True, ) ds_probs.attrs["scale"] = 100 # transform from 0-1 to 0-100 for uint8 storage + ds_probs.attrs["class_labels"] = pipeline.class_labels # Inference loop log_interval = max(len(dataloader) // 10, 1) From c2a962bd592fbc9d838f036dccade5622e6d392e Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 17:13:43 +0100 Subject: [PATCH 12/33] [refactor] update feature extractor calls to use forward_with_intermediates method --- src/poseforge/pose/bodyseg/model.py | 4 +- src/poseforge/pose/common.py | 70 +++++++++++++++---------- src/poseforge/pose/keypoints3d/model.py | 4 +- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/src/poseforge/pose/bodyseg/model.py b/src/poseforge/pose/bodyseg/model.py index e850a5c..a45df5c 100644 --- a/src/poseforge/pose/bodyseg/model.py +++ b/src/poseforge/pose/bodyseg/model.py @@ -158,9 +158,7 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: └──(bottleneck/identity)──┘ """ # Run feature extractor - e0, e1, e2, e3, e4 = self.feature_extractor.forward( - x, return_intermediates=True - ) + e0, e1, e2, e3, e4 = self.feature_extractor.forward_with_intermediates(x) d4 = e4 # this is just the bottleneck diff --git a/src/poseforge/pose/common.py b/src/poseforge/pose/common.py index 5663f6b..15c5172 100644 --- a/src/poseforge/pose/common.py +++ b/src/poseforge/pose/common.py @@ -87,36 +87,34 @@ def _apply_imagenet_normalization( std = torch.tensor(std, device=x.device).view(1, 3, 1, 1) x_normalized = (x - mean) / std return x_normalized + + def forward_with_intermediates(self, x: torch.Tensor): + """Run forward pass, returning intermediate feature maps as well as + the final features. - def forward(self, x, return_intermediates: bool = False): - """ Args: x (torch.Tensor): Input image tensor of shape (batch_size, 3, height, width), with pixel values in [0, 1]. - return_intermediates (bool): Whether to return intermediate - feature maps from various layers. Default False. Returns: - If return_intermediates is False: - features (torch.Tensor): Extracted features. The shape is - (batch_size, out_channels, *output_feature_map_size) - where output_feature_map_size depends on the input - image size. - If return_intermediates is True: - A tuple of 5 torch.Tensors: - - Features after initial Conv-BN-ReLU but before maxpool: - tensor of shape (batch_size, 64, 128, 128) - - Features after layer1: - tensor of shape (batch_size, 64, 64, 64) - - Features after layer2: - tensor of shape (batch_size, 128, 32, 32) - - Features after layer3: - tensor of shape (batch_size, 256, 16, 16) - - Features after layer4: - tensor of shape (batch_size, 512, 8, 8) - This is the same as the single output returned if - return_intermediates is False. + A tuple of 5 torch.Tensors: + - Features after initial Conv-BN-ReLU but before maxpool: + tensor of shape (batch_size, 64, 128, 128) + - Features after layer1: + tensor of shape (batch_size, 64, 64, 64) + - Features after layer2: + tensor of shape (batch_size, 128, 32, 32) + - Features after layer3: + tensor of shape (batch_size, 256, 16, 16) + - Features after layer4: + tensor of shape (batch_size, 512, 8, 8) """ + if x.shape[-2:] != self.input_size: + raise NotImplementedError( + f"Input image size {x.shape[-2:]} not supported; " + f"expected {self.input_size}" + ) + x_norm = self._apply_imagenet_normalization(x) # Remove the final classification head (avgpool + fc) @@ -149,10 +147,28 @@ def forward(self, x, return_intermediates: bool = False): assert x4.shape == (batch_size, 512, 8, 8) self._first_time_forward = False - if return_intermediates: - return x0, x1, x2, x3, x4 - else: - return x4 + return x0, x1, x2, x3, x4 + + def forward(self, x: torch.Tensor): + """Run forward pass through the ResNet-18 backbone and return the + final extracted features. + + See also `.forward_with_intermediates`, which returns intermediate + feature maps as well (useful for identity connections in U-Net-like + architectures). + + Args: + x (torch.Tensor): Input image tensor of shape (batch_size, 3, + height, width), with pixel values in [0, 1]. + + Returns: + features (torch.Tensor): Extracted features. The shape is + (batch_size, out_channels, *output_feature_map_size) + where output_feature_map_size depends on the input + image size. + """ + x0, x1, x2, x3, x4 = self.forward_with_intermediates(x) + return x4 class DecoderBlock(nn.Module): diff --git a/src/poseforge/pose/keypoints3d/model.py b/src/poseforge/pose/keypoints3d/model.py index 4a414ff..bddfa83 100644 --- a/src/poseforge/pose/keypoints3d/model.py +++ b/src/poseforge/pose/keypoints3d/model.py @@ -416,9 +416,7 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: └──(bottleneck/identity)──┘ """ # Run feature extractor - e0, e1, e2, e3, e4 = self.feature_extractor.forward( - x, return_intermediates=True - ) + e0, e1, e2, e3, e4 = self.feature_extractor.forward_with_intermediates(x) d4 = e4 # this is just the bottleneck From 9b7587ca2d7c132727a937aff276c789fba3c594 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 21:46:06 +0100 Subject: [PATCH 13/33] [refactor] rename scripts for clarity + other refactors --- scripts_on_cluster/bodyseg_training/job.run | 2 +- scripts_on_cluster/keypoints3d_training/job.run | 2 +- src/poseforge/pose/bodyseg/pipeline.py | 8 ++++---- ...inference.py => run_inference_on_spotlight_data.py} | 0 .../{run_bodyseg_training.py => run_training.py} | 0 src/poseforge/pose/common.py | 10 ++++++---- src/poseforge/pose/keypoints3d/pipeline.py | 3 ++- ...inference.py => run_inference_on_spotlight_data.py} | 0 ...3d_models.py => run_inference_on_synthetic_data.py} | 0 .../pose/keypoints3d/scripts/run_inverse_kinematics.py | 2 +- .../{run_keypoints3d_training.py => run_training.py} | 0 .../scripts/visualize_production_keypoints3d.py | 5 +++-- 12 files changed, 18 insertions(+), 14 deletions(-) rename src/poseforge/pose/bodyseg/scripts/{run_bodyseg_inference.py => run_inference_on_spotlight_data.py} (100%) rename src/poseforge/pose/bodyseg/scripts/{run_bodyseg_training.py => run_training.py} (100%) rename src/poseforge/pose/keypoints3d/scripts/{run_keypoints3d_inference.py => run_inference_on_spotlight_data.py} (100%) rename src/poseforge/pose/keypoints3d/scripts/{test_keypoints3d_models.py => run_inference_on_synthetic_data.py} (100%) rename src/poseforge/pose/keypoints3d/scripts/{run_keypoints3d_training.py => run_training.py} (100%) diff --git a/scripts_on_cluster/bodyseg_training/job.run b/scripts_on_cluster/bodyseg_training/job.run index 7a1154c..8838990 100644 --- a/scripts_on_cluster/bodyseg_training/job.run +++ b/scripts_on_cluster/bodyseg_training/job.run @@ -18,7 +18,7 @@ spack load ffmpeg conda activate poseforge cd $HOME/poseforge -training_cli_path="src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py" +training_cli_path="src/poseforge/pose/bodyseg/scripts/run_training.py" training_trial_name="trial_20251127a" contrastive_pretraining_trial_name="trial_20251125a_lowlr" contrastive_pretraining_epoch="epoch009" diff --git a/scripts_on_cluster/keypoints3d_training/job.run b/scripts_on_cluster/keypoints3d_training/job.run index 5e623d4..7b057f0 100644 --- a/scripts_on_cluster/keypoints3d_training/job.run +++ b/scripts_on_cluster/keypoints3d_training/job.run @@ -18,7 +18,7 @@ spack load ffmpeg conda activate poseforge cd $HOME/poseforge -training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py" +training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_training.py" training_trial_name="trial_20251127a" contrastive_pretraining_trial_name="trial_20251125a_lowlr" contrastive_pretraining_epoch="epoch009" diff --git a/src/poseforge/pose/bodyseg/pipeline.py b/src/poseforge/pose/bodyseg/pipeline.py index 6934d29..485d620 100644 --- a/src/poseforge/pose/bodyseg/pipeline.py +++ b/src/poseforge/pose/bodyseg/pipeline.py @@ -168,9 +168,8 @@ def train( for k, x in running_loss_dict.items() } time_now = time() - throughput = artifacts_config.logging_interval / ( - time_now - running_start_time - ) + time_elapsed = time_now - running_start_time + throughput = artifacts_config.logging_interval / time_elapsed running_loss_dict = defaultdict(lambda: 0.0) running_start_time = time_now @@ -244,6 +243,7 @@ def validate( raise ValueError("Loss function must be provided for validation") total_loss_dict = defaultdict(lambda: 0.0) + n_steps_iterated = 0 self.model.eval() with torch.no_grad(): for step_idx, (atomic_batches_frames, atomic_batches_sim_data) in enumerate( @@ -273,6 +273,7 @@ def validate( # Accumulate losses for key, loss in loss_dict.items(): total_loss_dict[key] += loss.item() + n_steps_iterated += 1 del ( atomic_batches_frames, @@ -283,7 +284,6 @@ def validate( ) clear_memory_cache() self.model.train() - n_steps_iterated = step_idx + 1 return {k: v / n_steps_iterated for k, v in total_loss_dict.items()} def inference(self, frames: torch.Tensor) -> dict[str, torch.Tensor]: diff --git a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py b/src/poseforge/pose/bodyseg/scripts/run_inference_on_spotlight_data.py similarity index 100% rename from src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py rename to src/poseforge/pose/bodyseg/scripts/run_inference_on_spotlight_data.py diff --git a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py b/src/poseforge/pose/bodyseg/scripts/run_training.py similarity index 100% rename from src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py rename to src/poseforge/pose/bodyseg/scripts/run_training.py diff --git a/src/poseforge/pose/common.py b/src/poseforge/pose/common.py index 15c5172..d3a1a21 100644 --- a/src/poseforge/pose/common.py +++ b/src/poseforge/pose/common.py @@ -87,8 +87,10 @@ def _apply_imagenet_normalization( std = torch.tensor(std, device=x.device).view(1, 3, 1, 1) x_normalized = (x - mean) / std return x_normalized - - def forward_with_intermediates(self, x: torch.Tensor): + + def forward_with_intermediates( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Run forward pass, returning intermediate feature maps as well as the final features. @@ -148,8 +150,8 @@ def forward_with_intermediates(self, x: torch.Tensor): self._first_time_forward = False return x0, x1, x2, x3, x4 - - def forward(self, x: torch.Tensor): + + def forward(self, x: torch.Tensor) -> torch.Tensor: """Run forward pass through the ResNet-18 backbone and return the final extracted features. diff --git a/src/poseforge/pose/keypoints3d/pipeline.py b/src/poseforge/pose/keypoints3d/pipeline.py index 3826ef7..cc16d14 100644 --- a/src/poseforge/pose/keypoints3d/pipeline.py +++ b/src/poseforge/pose/keypoints3d/pipeline.py @@ -232,6 +232,7 @@ def validate( if max_batches <= 0: raise ValueError("max_batches must be positive or None") total_loss_dict = defaultdict(lambda: 0.0) + n_steps_iterated = 0 if self.loss_func is None: raise ValueError("Loss function must be provided for validation") @@ -263,10 +264,10 @@ def validate( # Accumulate losses for key, loss in loss_dict.items(): total_loss_dict[key] += loss.item() + n_steps_iterated += 1 clear_memory_cache() self.model.train() - n_steps_iterated = step_idx + 1 return {k: v / n_steps_iterated for k, v in total_loss_dict.items()} def inference(self, frames: torch.Tensor) -> dict[str, torch.Tensor]: diff --git a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py b/src/poseforge/pose/keypoints3d/scripts/run_inference_on_spotlight_data.py similarity index 100% rename from src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py rename to src/poseforge/pose/keypoints3d/scripts/run_inference_on_spotlight_data.py diff --git a/src/poseforge/pose/keypoints3d/scripts/test_keypoints3d_models.py b/src/poseforge/pose/keypoints3d/scripts/run_inference_on_synthetic_data.py similarity index 100% rename from src/poseforge/pose/keypoints3d/scripts/test_keypoints3d_models.py rename to src/poseforge/pose/keypoints3d/scripts/run_inference_on_synthetic_data.py diff --git a/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py b/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py index 65354b1..b046680 100644 --- a/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py +++ b/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py @@ -829,7 +829,7 @@ def _save_seqikpy_output( # --input-images-basedir bulk_data/behavior_images/spotlight_aligned_and_cropped/ # * Processing from this script directly - epoch = 15 # these must be consistent with run_keypoints3d_inference.py + epoch = 15 # these must be consistent with run_inference_on_spotlight_data.py step = 9167 # same as above production_model_basedir = Path( f"bulk_data/pose_estimation/keypoints3d/trial_20251127a/production/epoch{epoch}_step{step}/" diff --git a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py b/src/poseforge/pose/keypoints3d/scripts/run_training.py similarity index 100% rename from src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py rename to src/poseforge/pose/keypoints3d/scripts/run_training.py diff --git a/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py b/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py index 9421c47..c577609 100644 --- a/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py +++ b/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py @@ -374,7 +374,8 @@ def visualize_predictions( with h5py.File(inference_output_path, "r") as f: frame_ids = f["frame_ids"][:] # 3D world coordinates (n_frames, n_kp, 3) - # Upstream inference currently writes this as 'keypoints_world_xyz' in run_keypoints3d_inference + # Upstream inference currently writes this as 'keypoints_world_xyz' in + # run_inference_on_spotlight_data.py if "keypoints_pos" in f: keypoints_pos = f["keypoints_pos"][:] keypoints_order = list(f["keypoints_pos"].attrs["keypoints"]) @@ -509,7 +510,7 @@ def visualize_predictions( input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251127a") recordings = ["20250613-fly1b-002"] - epoch = 15 # these must be consistent with run_keypoints3d_inference.py + epoch = 15 # these must be consistent with run_inference_on_spotlight_data.py step = 9167 # same as above for recording in recordings: From 6cd42fbd07ca7098e79ee2e0c29e6a1fdc4d8a1e Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 30 Nov 2025 21:47:00 +0100 Subject: [PATCH 14/33] [feature] initial training attempt on workstation --- src/poseforge/neuromechfly/constants.py | 15 + .../pose/data/synthetic/atomic_batch.py | 21 + src/poseforge/pose/pose6d/___init__.py | 0 src/poseforge/pose/pose6d/__init__.py | 12 + src/poseforge/pose/pose6d/config.py | 7 +- src/poseforge/pose/pose6d/model.py | 116 +++--- src/poseforge/pose/pose6d/pipeline.py | 384 ++++++++++++++++-- .../pose/pose6d/scripts/train_mesh6d_model.py | 169 ++++++++ 8 files changed, 605 insertions(+), 119 deletions(-) delete mode 100644 src/poseforge/pose/pose6d/___init__.py create mode 100644 src/poseforge/pose/pose6d/__init__.py create mode 100644 src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py diff --git a/src/poseforge/neuromechfly/constants.py b/src/poseforge/neuromechfly/constants.py index 755b626..c2b246f 100644 --- a/src/poseforge/neuromechfly/constants.py +++ b/src/poseforge/neuromechfly/constants.py @@ -84,6 +84,21 @@ f"{leg}{seg}" for leg in legs for seg in segments_for_6dpose_per_leg ] + ["Thorax"] +# fmt: off +# all_body_segments = [ +# "Thorax", "A1A2", "A3", "A4", "A5", "A6", +# "LHaltere", "LWing", "RHaltere", "RWing", +# "LFCoxa", "LFFemur", "LFTibia", "LFTarsus1", "LFTarsus2", "LFTarsus3", "LFTarsus4", "LFTarsus5", +# "LHCoxa", "LHFemur", "LHTibia", "LHTarsus1", "LHTarsus2", "LHTarsus3", "LHTarsus4", "LHTarsus5", +# "LMCoxa", "LMFemur", "LMTibia", "LMTarsus1", "LMTarsus2", "LMTarsus3", "LMTarsus4", "LMTarsus5", +# "RHCoxa", "RHFemur", "RHTibia", "RHTarsus1", "RHTarsus2", "RHTarsus3", "RHTarsus4", "RHTarsus5", +# "RMCoxa", "RMFemur", "RMTibia", "RMTarsus1", "RMTarsus2", "RMTarsus3", "RMTarsus4", "RMTarsus5", +# "RFCoxa", "RFFemur", "RFTibia", "RFTarsus1", "RFTarsus2", "RFTarsus3", "RFTarsus4", "RFTarsus5", +# "Head", "LEye", "REye", "Rostrum", "Haustellum", +# "LPedicel", "LFuniculus", "LArista", "RPedicel", "RFuniculus", "RArista", +# ] +# fmt: on + ########################################################################### ## COLORS FOR BODY SEGMENT RENDERING BELOW ## diff --git a/src/poseforge/pose/data/synthetic/atomic_batch.py b/src/poseforge/pose/data/synthetic/atomic_batch.py index d447a3f..20be9de 100644 --- a/src/poseforge/pose/data/synthetic/atomic_batch.py +++ b/src/poseforge/pose/data/synthetic/atomic_batch.py @@ -2,6 +2,7 @@ import numpy as np import h5py import logging +from typing import Callable from torch.utils.data import Dataset, DataLoader from pathlib import Path from pvio.io import read_frames_from_video, write_frames_to_video @@ -20,7 +21,11 @@ def __init__( load_dof_angles: bool = False, load_keypoint_positions: bool = False, load_body_segment_maps: bool = False, + load_mesh_pose6d: bool = False, + transform: Callable | None = None, ): + self.transform = transform + # Find all .h5 and .mp4 files in the provided directories all_h5_files = set() all_mp4_files = set() @@ -70,6 +75,9 @@ def __init__( self.label_keys.append("keypoint_pos") if load_body_segment_maps: self.label_keys.append("body_seg_maps") + if load_mesh_pose6d: + self.label_keys.append("mesh_pos") + self.label_keys.append("mesh_quat") def __len__(self): return len(self.atomic_batches) @@ -91,6 +99,10 @@ def __getitem__(self, idx): # Load labels data sim_data = self.load_atomic_batch_sim_data(h5_path, self.label_keys) + # Apply transform if provided + if self.transform is not None: + frames, sim_data = self.transform(frames, sim_data) + return frames, sim_data @staticmethod @@ -358,12 +370,15 @@ def init_atomic_dataset_and_dataloader( load_dof_angles: bool = False, load_keypoint_positions: bool = False, load_body_segment_maps: bool = False, + load_mesh_pose6d: bool = False, shuffle: bool = False, n_workers: int | None = None, n_channels: int = 3, pin_memory: bool = True, drop_last: bool = True, prefetch_factor: int | None = None, + transform: Callable | None = None, + return_index: bool = False, ): """ Initializes an AtomicBatchDataset and a corresponding DataLoader for @@ -385,6 +400,8 @@ def init_atomic_dataset_and_dataloader( positions. Defaults to False. load_body_segment_maps (bool, optional): Whether to load body segment maps. Defaults to False. + load_mesh_pose6d (bool, optional): Whether to load mesh 6D pose. + Defaults to False. shuffle (bool, optional): Whether to shuffle the data. Defaults to False. n_workers (int | None, optional): Number of worker threads for @@ -396,6 +413,8 @@ def init_atomic_dataset_and_dataloader( batch. Defaults to True. prefetch_factor (int | None, optional): Number of samples to load in advance by each worker. If None, uses PyTorch default. + transform (Callable | None, optional): Optional transform to apply + to each atomic batch when loading. Defaults to None. Returns: dataset (AtomicBatchDataset): @@ -412,6 +431,8 @@ def init_atomic_dataset_and_dataloader( load_dof_angles=load_dof_angles, load_keypoint_positions=load_keypoint_positions, load_body_segment_maps=load_body_segment_maps, + load_mesh_pose6d=load_mesh_pose6d, + transform=transform, ) # Check if batch size is valid diff --git a/src/poseforge/pose/pose6d/___init__.py b/src/poseforge/pose/pose6d/___init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/poseforge/pose/pose6d/__init__.py b/src/poseforge/pose/pose6d/__init__.py new file mode 100644 index 0000000..8b734d9 --- /dev/null +++ b/src/poseforge/pose/pose6d/__init__.py @@ -0,0 +1,12 @@ +from .model import Pose6DModel, Pose6DLoss +from .pipeline import Pose6DPipeline, ComputeBodysegProbs +from . import config + + +__all__ = [ + "Pose6DModel", + "Pose6DLoss", + "Pose6DPipeline", + "ComputeBodysegProbs", + "config", +] diff --git a/src/poseforge/pose/pose6d/config.py b/src/poseforge/pose/pose6d/config.py index a761cbb..76b4938 100644 --- a/src/poseforge/pose/pose6d/config.py +++ b/src/poseforge/pose/pose6d/config.py @@ -8,13 +8,8 @@ class ModelArchitectureConfig(SerializableDataClass): # Number of object segments. # Default: 6 legs * (coxa, femur, tibia, tarsus1) + thorax = 25 n_segments: int = 25 - # Feature extractor configuration - feature_extractor: dict = None # Number of hidden channels in the final upsampling layer before pose regression - final_upsampler_n_hidden_channels: int = 64 - # Hidden sizes of intermediate layers of the MLP mesh heads - # (comma-separated list of ints) - pose6d_head_hidden_sizes: str = "512,256" + final_upsampler_n_hidden_channels: int = 256 @dataclass(frozen=True) diff --git a/src/poseforge/pose/pose6d/model.py b/src/poseforge/pose/pose6d/model.py index 75307ea..ceca4f6 100644 --- a/src/poseforge/pose/pose6d/model.py +++ b/src/poseforge/pose/pose6d/model.py @@ -14,13 +14,11 @@ def __init__( n_segments: int, feature_extractor: ResNetFeatureExtractor, final_upsampler_n_hidden_channels: int, - pose6d_head_hidden_sizes: list[int], ): super(Pose6DModel, self).__init__() self.n_segments = n_segments self.feature_extractor = feature_extractor self.final_upsampler_n_hidden_channels = final_upsampler_n_hidden_channels - self.pose6d_head_hidden_sizes = pose6d_head_hidden_sizes # Create decoder (decoder1/2/3/4 mirror encoder layers 1/2/3/4) # Note that when upsampling, we actually run decoder4 first, decoder1 last @@ -30,16 +28,30 @@ def __init__( # information. self.dec_layer4 = DecoderBlock(512, 256, 512) self.dec_layer3 = DecoderBlock(512, 128, 256) - self.dec_layer2 = DecoderBlock(256, 64, 256) - self.dec_layer1 = DecoderBlock(256, 64, final_upsampler_n_hidden_channels) + self.dec_layer2 = DecoderBlock(256, 64, final_upsampler_n_hidden_channels) # 6D pose prediction head (one per segment) - self.pose6d_heads = nn.ModuleList( - [self._make_pose6d_head() for _ in range(n_segments)] - ) + _pose6d_heads = [] + for _ in range(n_segments): + attention_conv = nn.Conv2d( + in_channels=self.final_upsampler_n_hidden_channels + n_segments, + out_channels=self.final_upsampler_n_hidden_channels, + kernel_size=3, + padding=1, + ) + pool = nn.AdaptiveAvgPool2d((1, 1)) + linear1 = nn.Linear(self.final_upsampler_n_hidden_channels, 256) + relu = nn.ReLU(inplace=True) + dropout = nn.Dropout(p=0.3) + linear2 = nn.Linear(256, 7) # 3 translation + 4 rotation (quaternion) + pose6d_head = nn.Sequential( + attention_conv, pool, nn.Flatten(), linear1, relu, dropout, linear2 + ) + _pose6d_heads.append(pose6d_head) + self.pose6d_heads = nn.ModuleList(_pose6d_heads) @classmethod - def from_config( + def create_from_config( cls, architecture_config: config.ModelArchitectureConfig | Path | str ) -> "Pose6DModel": # Load from file if config is a path @@ -53,24 +65,10 @@ def from_config( feature_extractor = ResNetFeatureExtractor() # Initialize Pose6DModel (self) from config (WITHOUT WEIGHTS at this step!) - try: - # Parse pose6d_head_hidden_sizes from string - # (don't specify directly as list[int] in YAML - mutability issues) - pose6d_head_hidden_sizes = [ - int(x.strip()) - for x in architecture_config.pose6d_head_hidden_sizes.split(",") - ] - except ValueError as e: - logger.critical( - f"Invalid pose6d_head_hidden_sizes in ModelArchitectureConfig: {e}. " - f"Expected a comma-separated list of integers." - ) - raise e obj = cls( n_segments=architecture_config.n_segments, feature_extractor=feature_extractor, final_upsampler_n_hidden_channels=architecture_config.final_upsampler_n_hidden_channels, - pose6d_head_hidden_sizes=pose6d_head_hidden_sizes, ) logger.info("Initialized Pose6DModel from architecture config") @@ -116,49 +114,32 @@ def load_weights_from_config( f"{weights_config.feature_extractor_weights}" ) - def _make_pose6d_head(self) -> nn.Module: - all_layers = [] - n_channels_in = self.final_upsampler_n_hidden_channels - for hidden_size in self.pose6d_head_hidden_sizes: - layers_within_block = [ - nn.Linear(n_channels_in, hidden_size), - nn.ReLU(inplace=True), - nn.Dropout(p=0.3), - ] - all_layers.extend(layers_within_block) - n_channels_in = hidden_size - # 3 translation + 4 rotation (quaternion) - all_layers.append(nn.Linear(n_channels_in, 7)) - return nn.Sequential(*all_layers) - - def forward(self, input_img: torch.Tensor, bodyseg_probs: torch.Tensor) -> dict: - # Extract features with ResNet backbone and upsample to 128x128 - features = self.feature_extractor(input_img) - x = self.dec_layer4(features) - x = self.dec_layer3(x) - x = self.dec_layer2(x) - x = self.dec_layer1(x) - # Now x has shape (B, final_upsampler_n_hidden_channels, 128, 128) + def forward(self, image: torch.Tensor, bodyseg_prob: torch.Tensor) -> dict: + bat_size = image.shape[0] + assert image.shape == (bat_size, 3, 256, 256) + + # Run feature extractor and upsample to 64x64 + e0, e1, e2, e3, e4 = self.feature_extractor.forward_with_intermediates(image) + d4 = e4 # this is just the bottleneck + d3 = self.dec_layer4(d4, e3) + d2 = self.dec_layer3(d3, e2) + d1 = self.dec_layer2(d2, e1) + + # Attention using bodyseg prediction happens at 64x64 resolution (stride 4) + assert d1.shape == (bat_size, self.final_upsampler_n_hidden_channels, 64, 64) + assert bodyseg_prob.shape == (bat_size, self.n_segments, 64, 64) + # (B, n_channels_feature_map + n_segments, H, W) + feature_map_with_seg_prob = torch.cat([d1, bodyseg_prob], dim=1) # Process each segment separately translation_pred_list = [] quaternion_pred_list = [] - for seg_idx in range(self.n_segments): - # Confidence-weighted global average pooling - mask_probs = bodyseg_probs[:, seg_idx, :, :].unsqueeze(1) # (B, 1, H, W) - weighted_features = x * mask_probs # (B, C, H, W) - feature_sums = weighted_features.sum(dim=(2, 3)) # (B, C) - confidence_sums = mask_probs.sum(dim=(2, 3)) + 1e-6 # (B, 1) - pooled_features = feature_sums / confidence_sums.clamp_min(1e-6) # (B, C) - - # Predict 6D pose - pose_pred = self.pose6d_heads[seg_idx](pooled_features) # (B, 7) - translation_pred = pose_pred[:, 0:3] # (B, 3) - quaternion_pred = pose_pred[:, 3:7] # (B, 4) - # Normalize quaternion to unit length - quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) - + head = self.pose6d_heads[seg_idx] + pose6d_pred = head(feature_map_with_seg_prob) # (B, 7) + translation_pred = pose6d_pred[:, 0:3] # (B, 3) + quaternion_pred = pose6d_pred[:, 3:7] # (B, 4) + quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) # normalize quat translation_pred_list.append(translation_pred) quaternion_pred_list.append(quaternion_pred) @@ -181,14 +162,7 @@ def forward( quaternion_pred: torch.Tensor, translation_label: torch.Tensor, quaternion_label: torch.Tensor, - valid_mask: torch.Tensor, ) -> torch.Tensor: - # If the segment is too small, don't include it in loss computation - translation_pred = translation_pred[valid_mask, ...] - quaternion_pred = quaternion_pred[valid_mask, ...] - translation_label = translation_label[valid_mask, ...] - quaternion_label = quaternion_label[valid_mask, ...] - # Compute losses translation_loss = self.translation_loss( translation_pred.view(-1, 3), translation_label.view(-1, 3) @@ -196,11 +170,17 @@ def forward( rotation_loss = self.rotation_loss( quaternion_pred.view(-1, 4), quaternion_label.view(-1, 4) ) - return ( + total_loss = ( self.translation_weight * translation_loss + self.rotation_weight * rotation_loss ) + return { + "translation_loss_unweighted": translation_loss, + "rotation_loss_unweighted": rotation_loss, + "total_loss": total_loss, + } + @classmethod def create_from_config( cls, loss_config: config.LossConfig | Path | str diff --git a/src/poseforge/pose/pose6d/pipeline.py b/src/poseforge/pose/pose6d/pipeline.py index 4beac02..ddc32ea 100644 --- a/src/poseforge/pose/pose6d/pipeline.py +++ b/src/poseforge/pose/pose6d/pipeline.py @@ -9,20 +9,52 @@ from tqdm import tqdm from loguru import logger - +import poseforge.util as util +import poseforge.pose.data.synthetic as synth_data import poseforge.pose.pose6d.config as config from poseforge.pose.pose6d.model import Pose6DModel, Pose6DLoss -from poseforge.pose.data.synthetic import ( - init_atomic_dataset_and_dataloader, - atomic_batches_to_simple_batch, -) -from poseforge.util import ( - set_random_seed, - check_mixed_precision_status, - count_optimizer_parameters, - count_module_parameters, - clear_memory_cache, -) +from poseforge.pose.bodyseg import BodySegmentationPipeline +from poseforge.neuromechfly.constants import segments_for_6dpose + + +class ComputeBodysegProbs(torch.nn.Module): + def __init__(self, scale_factor: int): + super(ComputeBodysegProbs, self).__init__() + self.scale_factor = scale_factor + self.pool = torch.nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor) + + self.pose6d_idx_to_bodyseg_idx = [] + for segment in segments_for_6dpose: + if segment.endswith("Tarsus1"): + segment = segment.replace("Tarsus1", "Tarsus") + bodyseg_class_labels = list(BodySegmentationPipeline.class_labels) + if segment not in bodyseg_class_labels: + logger.critical( + f"Segment {segment} in Pose6D model has no matching segment in " + "body segmentation maps." + ) + raise ValueError("Invalid segment name.") + self.pose6d_idx_to_bodyseg_idx.append(bodyseg_class_labels.index(segment)) + + def forward( + self, atomic_batch_frames: torch.Tensor, sim_data: dict[str, torch.Tensor] + ) -> torch.Tensor: + segmaps = sim_data["body_seg_maps"] # (B, H, W), value = seg idx in sim data + bat_size, nrows, ncols = segmaps.size() + + # Convert to one-hot (B, len(pose6d_segments), H, W) + segmaps_onehot = torch.zeros( + (bat_size, len(segments_for_6dpose), nrows, ncols), dtype=torch.float32 + ) + for i, bodyseg_idx in enumerate(self.pose6d_idx_to_bodyseg_idx): + segmaps_onehot[:, i, :, :] = (segmaps == bodyseg_idx).float() + + # Downsample using max pooling + segmaps_onehot_downsampled = self.pool(segmaps_onehot) + sim_data["segmap_probs_label"] = segmaps_onehot_downsampled.to(segmaps.device) + sim_data.pop("body_seg_maps") # remove original segmap + + return atomic_batch_frames, sim_data class Pose6DPipeline: @@ -42,6 +74,12 @@ def __init__( else: self.device_type = "cpu" + # Body segmentation maps are at full working resolution (256x256). The Pose6D + # model performs soft global pooling weighted by segmentation map probabilities + # at 64x64 resolution, so we need to downsample the segmentation maps. Use max + # pooling because it's better to slightly overestimate the presence of segments. + self._compute_bodyseg_probs_transform = ComputeBodysegProbs(scale_factor=4) + def train( self, n_epochs: int, @@ -49,13 +87,23 @@ def train( optimizer_config: config.OptimizerConfig, artifacts_config: config.TrainingArtifactsConfig, seed: int = 42, + half_batch_size_for_debugging: bool = False, ): # Set random seed for reproducibility - set_random_seed(seed) + util.set_random_seed(seed) + + # If half_batch_size_for_debugging, cut batch sizes in half to save memory + self.half_batch_size_for_debugging = half_batch_size_for_debugging + if self.half_batch_size_for_debugging: + logger.warning( + "Debug mode: using half batch sizes for training and validation in " + "order to fit the model in memory for a GeForce RTX 3080 Ti." + ) # Set up training and validation data train_ds, train_loader = self._init_training_dataset_and_dataloader(data_config) val_ds, val_loader = self._init_validation_dataset_and_dataloader(data_config) + n_batches_per_epoch = len(train_loader) # Set up logging dir and logger log_dir = Path(artifacts_config.output_basedir) / "logs" @@ -93,15 +141,185 @@ def train( for step_idx, atomic_batches in enumerate(train_loader): # Format data atomic_batches_frames, atomic_batches_sim_data = atomic_batches - frames, sim_data = atomic_batches_to_simple_batch( + frames, sim_data = synth_data.atomic_batches_to_simple_batch( + atomic_batches_frames, atomic_batches_sim_data, device=self.device + ) + if self.half_batch_size_for_debugging: + frames, sim_data = self._get_half_batch(frames, sim_data) + bodyseg_probs_label = sim_data["segmap_probs_label"] + + # Forward pass with mixed precision + with torch.amp.autocast(self.device_type, enabled=self.use_float16): + pred_pos, pred_quat = self.model(frames, bodyseg_probs_label) + loss_dict = self.loss_func( + pred_pos, pred_quat, sim_data["mesh_pos"], sim_data["mesh_quat"] + ) + + # Check if float16 is used + if epoch_idx == 0 and step_idx == 0: + self._check_amp_status_for_model_params( + grad_scaler, + subtitle="Model parameters at start of training", + ) + self._check_amp_status_during_training( + frames, + pred_pos, + pred_quat, + sim_data["mesh_pos"], + sim_data["mesh_quat"], + grad_scaler, + subtitle="Variables at start of training", + ) + + # Backpropagate and optimize + optimizer.zero_grad(set_to_none=True) + grad_scaler.scale(loss_dict["total_loss"]).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + + # Logging + for key, value in loss_dict.items(): + running_loss_dict[key] += value.item() + if step_idx % artifacts_config.logging_interval == 0 and step_idx > 0: + avg_loss_dict = { + k: x / artifacts_config.logging_interval + for k, x in running_loss_dict.items() + } + time_now = time() + time_elapsed = time_now - running_start_time + throughput = artifacts_config.logging_interval / time_elapsed + + running_loss_dict = defaultdict(lambda: 0.0) + running_start_time = time_now + self._update_logs_training( + writer, + epoch_index=epoch_idx, + within_epoch_step_idx=step_idx, + n_batches_per_epoch=n_batches_per_epoch, + avg_loss_dict=avg_loss_dict, + throughput=throughput, + ) + + # Run validation + if ( + step_idx % artifacts_config.validation_interval == 0 + and step_idx > 0 + ): + del ( + atomic_batches_frames, + atomic_batches_sim_data, + frames, + sim_data, + bodyseg_probs_label, + pred_pos, + pred_quat, + ) + util.clear_memory_cache() + val_loss_dict = self.validate( + val_loader, + max_batches=artifacts_config.n_batches_per_validation, + ) + self._update_logs_validation( + writer, + epoch_idx=epoch_idx, + within_epoch_step_idx=step_idx, + n_batches_per_epoch=n_batches_per_epoch, + val_loss_dict=val_loss_dict, + ) + + # Save checkpoint + # (every log_interval steps and last step of each epoch) + if ( + step_idx % artifacts_config.checkpoint_interval == 0 + and step_idx > 0 + ) or (step_idx == n_batches_per_epoch - 1): + checkpoint_path_stem = ( + checkpoint_dir / f"epoch{epoch_idx}_step{step_idx}" + ) + self._save_checkpoint( + checkpoint_path_stem, + model=self.model, + loss=self.loss_func, + optimizer=optimizer, + grad_scaler=grad_scaler, + ) + logger.info(f"Saved checkpoint to {checkpoint_path_stem}.*.pth") + + epoch_wall_time = time() - epoch_start_time + logger.info(f"Finished epoch {epoch_idx} in {epoch_wall_time:.2f} seconds.") + + writer.close() + + def validate( + self, validation_data_loader: DataLoader, max_batches: int | None = None + ): + if max_batches is None: + max_batches = len(validation_data_loader) + if max_batches <= 0: + raise ValueError("max_batches must be positive or None") + if self.loss_func is None: + raise ValueError("Loss function must be provided for validation") + + total_loss_dict = defaultdict(lambda: 0.0) + n_steps_iterated = 0 + self.model.eval() + with torch.no_grad(): + for step_idx, (atomic_batches_frames, atomic_batches_sim_data) in enumerate( + tqdm(validation_data_loader, desc="Validation", disable=None) + ): + if step_idx >= max_batches: + break + + # Format data + frames, sim_data = synth_data.atomic_batches_to_simple_batch( atomic_batches_frames, atomic_batches_sim_data, device=self.device ) - ... + if self.half_batch_size_for_debugging: + frames, sim_data = self._get_half_batch(frames, sim_data) + bodyseg_probs_label = sim_data["segmap_probs_label"] + + # Run model + with torch.amp.autocast(self.device_type, enabled=self.use_float16): + pred_pos, pred_quat = self.model(frames, bodyseg_probs_label) + loss_dict = self.loss_func( + pred_pos, pred_quat, sim_data["mesh_pos"], sim_data["mesh_quat"] + ) + + # Accumulate losses + for key, loss in loss_dict.items(): + total_loss_dict[key] += loss.item() + n_steps_iterated += 1 + + del ( + atomic_batches_frames, + atomic_batches_sim_data, + frames, + sim_data, + bodyseg_probs_label, + pred_pos, + pred_quat, + ) + util.clear_memory_cache() + self.model.train() + return {k: v / n_steps_iterated for k, v in total_loss_dict.items()} + + def inference( + self, frames: torch.Tensor, bodyseg_probs: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + input_device = frames.device + self.model.eval() + with torch.no_grad(): + frames = frames.to(self.device) + bodyseg_probs = bodyseg_probs.to(self.device) + with torch.amp.autocast(self.device_type, enabled=self.use_float16): + pred_pos, pred_quat = self.model(frames, bodyseg_probs) + self.model.train() + return pred_pos.to(input_device), pred_quat.to(input_device) def _init_training_dataset_and_dataloader( self, data_config: config.TrainingDataConfig ): - return init_atomic_dataset_and_dataloader( + return synth_data.init_atomic_dataset_and_dataloader( data_dirs=data_config.train_data_dirs, atomic_batch_n_samples=data_config.atomic_batch_n_samples, atomic_batch_n_variants=data_config.atomic_batch_n_variants, @@ -109,18 +327,20 @@ def _init_training_dataset_and_dataloader( batch_size=data_config.train_batch_size, load_dof_angles=False, load_keypoint_positions=False, - load_body_segment_maps=False, + load_body_segment_maps=True, + load_mesh_pose6d=True, shuffle=True, n_workers=data_config.n_workers, n_channels=3, pin_memory=True, drop_last=True, + transform=self._compute_bodyseg_probs_transform, ) def _init_validation_dataset_and_dataloader( self, data_config: config.TrainingDataConfig ): - return init_atomic_dataset_and_dataloader( + return synth_data.init_atomic_dataset_and_dataloader( data_dirs=data_config.val_data_dirs, atomic_batch_n_samples=data_config.atomic_batch_n_samples, atomic_batch_n_variants=data_config.atomic_batch_n_variants, @@ -128,12 +348,14 @@ def _init_validation_dataset_and_dataloader( batch_size=data_config.val_batch_size, load_dof_angles=False, load_keypoint_positions=False, - load_body_segment_maps=False, + load_body_segment_maps=True, + load_mesh_pose6d=True, shuffle=False, n_workers=data_config.n_workers, n_channels=3, pin_memory=True, drop_last=True, + transform=self._compute_bodyseg_probs_transform, ) def _create_optimizer(self, optimizer_config: config.OptimizerConfig): @@ -145,7 +367,7 @@ def _create_optimizer(self, optimizer_config: config.OptimizerConfig): { "params": list( chain( - self.model.dec_layer1.parameters(), + # no dec_layer1 because we don't upsample all the way to 128 self.model.dec_layer2.parameters(), self.model.dec_layer3.parameters(), self.model.dec_layer4.parameters(), @@ -168,8 +390,8 @@ def _create_optimizer(self, optimizer_config: config.OptimizerConfig): ) # Check if all parameters are covered - n_params_optimizer = count_optimizer_parameters(optimizer) - n_params_model = count_module_parameters(self.model) + n_params_optimizer = util.count_optimizer_parameters(optimizer) + n_params_model = util.count_module_parameters(self.model) assert n_params_optimizer == n_params_model, ( f"Number of parameters in optimizer ({n_params_optimizer}) does not match " f"number of parameters in model ({n_params_model})." @@ -180,14 +402,13 @@ def _create_optimizer(self, optimizer_config: config.OptimizerConfig): def _check_amp_status_for_model_params( self, grad_scaler: torch.amp.GradScaler, subtitle: str = "Model parameters" ): - return check_mixed_precision_status( + return util.check_mixed_precision_status( self.use_float16, self.device, print_results=True, tensors={ "feature_extractor_params": self.model.feature_extractor.parameters(), "decoder_params": chain( - self.model.dec_layer1.parameters(), self.model.dec_layer2.parameters(), self.model.dec_layer3.parameters(), self.model.dec_layer4.parameters(), @@ -198,24 +419,97 @@ def _check_amp_status_for_model_params( subtitle=subtitle, ) - # def _check_amp_status_during_training( - # self, - # input_images: torch.Tensor, - # target: torch.Tensor, - # pred_dict: torch.Tensor, - # grad_scaler: torch.amp.GradScaler, - # subtitle: str = "Variables during training", - # ): - # return check_mixed_precision_status( - # self.use_float16, - # self.device, - # print_results=True, - # tensors={ - # "input_images": input_images, - # "target": target, - # "pred": pred_dict["logits"], - # "pred_conf": pred_dict["confidence"], - # }, - # grad_scaler=grad_scaler, - # subtitle=subtitle, - # ) + def _check_amp_status_during_training( + self, + input_images: torch.Tensor, + pred_pos: torch.Tensor, + pred_quat: torch.Tensor, + target_pos: torch.Tensor, + target_quat: torch.Tensor, + grad_scaler: torch.amp.GradScaler, + subtitle: str = "Variables during training", + ): + return util.check_mixed_precision_status( + self.use_float16, + self.device, + print_results=True, + tensors={ + "input_images": input_images, + "target_pos": target_pos, + "target_quat": target_quat, + "pred_pos": pred_pos, + "pred_quat": pred_quat, + }, + grad_scaler=grad_scaler, + subtitle=subtitle, + ) + + @staticmethod + def _save_checkpoint( + checkpoint_path_stem: Path, + model: Pose6DModel, + loss: Pose6DLoss | None = None, + optimizer: torch.optim.Optimizer | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> None: + path = checkpoint_path_stem.with_suffix(".model.pth") + torch.save(model.state_dict(), path) + if loss is not None: + path = checkpoint_path_stem.with_suffix(".loss.pth") + torch.save(loss.state_dict(), path) + if optimizer is not None: + path = checkpoint_path_stem.with_suffix(".optimizer.pth") + torch.save(optimizer.state_dict(), path) + if grad_scaler is not None: + path = checkpoint_path_stem.with_suffix(".grad_scaler.pth") + torch.save(grad_scaler.state_dict(), path) + + def _update_logs_training( + self, + writer: SummaryWriter, + *, + epoch_index: int, + within_epoch_step_idx: int, + n_batches_per_epoch: int, + avg_loss_dict: dict[str, float], + throughput: float, + ) -> None: + global_step_idx = epoch_index * n_batches_per_epoch + within_epoch_step_idx + writer.add_scalar("train/epoch", epoch_index, global_step_idx) + log_str = ( + f"Epoch {epoch_index}, step {within_epoch_step_idx}/{n_batches_per_epoch}, " + ) + for key, value in avg_loss_dict.items(): + log_str += f"{key}: {value:.4f}, " + writer.add_scalar(f"train/loss/{key}", value, global_step_idx) + log_str += f"throughput: {throughput:.2f} batches/sec" + logger.info(log_str) + writer.add_scalar("train/sys/throughput", throughput, global_step_idx) + + def _update_logs_validation( + self, + writer: SummaryWriter, + *, + epoch_idx: int, + within_epoch_step_idx: int, + n_batches_per_epoch: int, + val_loss_dict: dict[str, float], + ) -> None: + global_step_idx = epoch_idx * n_batches_per_epoch + within_epoch_step_idx + log_str = ( + f"Validation at epoch {epoch_idx}, " + f"step {within_epoch_step_idx}/{n_batches_per_epoch}, " + ) + for key, value in val_loss_dict.items(): + log_str += f"{key}: {value:.4f}, " + writer.add_scalar(f"val/loss/{key}", value, global_step_idx) + logger.info(log_str) + + def _get_half_batch(self, frames_batch, sim_data_batch): + """Return half of the batch to save memory (for debugging only).""" + half_batch_size = frames_batch.shape[0] // 2 + frames_batch = frames_batch[:half_batch_size, ...] + sim_data_batch = { + k: v[:half_batch_size, ...] for k, v in sim_data_batch.items() + } + return frames_batch, sim_data_batch diff --git a/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py b/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py new file mode 100644 index 0000000..aab8a9d --- /dev/null +++ b/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py @@ -0,0 +1,169 @@ +import torch +from pathlib import Path +from torchsummary import summary +from loguru import logger + +import poseforge.pose.pose6d.config as config +from poseforge.pose.pose6d import Pose6DModel, Pose6DLoss, Pose6DPipeline +from poseforge.util import get_hardware_availability +from poseforge.neuromechfly.constants import segments_for_6dpose + + +def setup_model( + architecture_config: config.ModelArchitectureConfig, + weights_config: config.ModelWeightsConfig | None, +) -> Pose6DModel: + model = Pose6DModel.create_from_config(architecture_config) + logger.info( + f"Model architecture set up based on architecture config {architecture_config}" + ) + if weights_config is not None: + model.load_weights_from_config(weights_config) + logger.info(f"Loaded weights into model based on weights config {weights_config}") + return model + + +def setup_loss_func(loss_config: config.LossConfig) -> Pose6DLoss: + loss_func = Pose6DLoss.create_from_config(loss_config) + logger.info("Set up Pose6D loss function") + return loss_func + + +def save_configs( + configs_dir: Path, + *, + model_architecture_config: config.ModelArchitectureConfig, + loss_config: config.LossConfig, + training_data_config: config.TrainingDataConfig, + optimizer_config: config.OptimizerConfig, + training_artifacts_config: config.TrainingArtifactsConfig, + model_weights_config: config.ModelWeightsConfig | None = None, +) -> None: + configs_dir.mkdir(parents=True, exist_ok=True) + model_architecture_config.save(configs_dir / "model_architecture_config.yaml") + loss_config.save(configs_dir / "loss_config.yaml") + training_data_config.save(configs_dir / "data_config.yaml") + optimizer_config.save(configs_dir / "optimizer_config.yaml") + training_artifacts_config.save(configs_dir / "artifacts_config.yaml") + if model_weights_config is not None: + model_weights_config.save(configs_dir / "model_weights_config.yaml") + + +def train_mesh6d_model( + n_epochs: int, + model_architecture_config: config.ModelArchitectureConfig, + model_weights_config: config.ModelWeightsConfig, + loss_config: config.LossConfig, + training_data_config: config.TrainingDataConfig, + optimizer_config: config.OptimizerConfig, + training_artifacts_config: config.TrainingArtifactsConfig, + seed: int = 42, + half_batch_size_for_debugging: bool = False, +) -> None: + # System setup + hardware_avail = get_hardware_availability(check_gpu=True, print_results=True) + if len(hardware_avail["gpus"]) == 0: + raise RuntimeError("No GPU available for training") + torch.backends.cudnn.benchmark = True + + # Save configs + save_configs( + Path(training_artifacts_config.output_basedir) / "configs", + model_architecture_config=model_architecture_config, + model_weights_config=model_weights_config, + loss_config=loss_config, + training_data_config=training_data_config, + optimizer_config=optimizer_config, + training_artifacts_config=training_artifacts_config, + ) + + # Initialize model and loss function + model = setup_model(model_architecture_config, model_weights_config) + criterion = setup_loss_func(loss_config) + + # Print model summary + print_model_summary(training_data_config, model) + + # Set up loss function + loss_func = setup_loss_func(loss_config) + + # Set up training pipeline + pipeline = Pose6DPipeline( + model=model, loss_func=loss_func, device="cuda", use_float16=True + ) + logger.info("Training pipeline set up") + + # Start training + pipeline.train( + n_epochs=n_epochs, + data_config=training_data_config, + optimizer_config=optimizer_config, + artifacts_config=training_artifacts_config, + seed=seed, + half_batch_size_for_debugging=half_batch_size_for_debugging, + ) + logger.info("Training completed") + + +def print_model_summary(training_data_config, model): + input_image_size = (3, *training_data_config.input_image_size) + attention_feat_map_size = (len(segments_for_6dpose), 64, 64) # hardcoded for now + in_dim = [input_image_size, attention_feat_map_size] + print("============== Full Model Summary ===============") + summary(model, in_dim, device="cpu") + print("=========== Feature Extractor Summary ===========") + summary(model.feature_extractor, input_image_size, device="cpu") + + +if __name__ == "__main__": + # import tyro + + # tyro.cli( + # train_mesh6d_model, + # prog=f"python {Path(__file__).name}", + # description="Train a 6D mesh pose model using pretrained feature extractor.", + # ) + + # Example using native Python function calls: + model_architecture_config = config.ModelArchitectureConfig() + model_weights_config = config.ModelWeightsConfig( + feature_extractor_weights="bulk_data/pose_estimation/contrastive_pretraining/trial_20251125a_lowlr/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth", + model_weights=None, + ) + loss_config = config.LossConfig() + data_basedir = Path("bulk_data/pose_estimation/atomic_batches/4variants") + train_data_dirs = [ + data_basedir / f"BO_Gal4_fly{fly}_trial{trial:03d}" + for fly in range(1, 5) # flies 1-4 + for trial in range(1, 6) # trials 1-5 + ] + val_data_dirs = [data_basedir / f"BO_Gal4_fly5_trial001"] + training_data_config = config.TrainingDataConfig( + train_data_dirs=[str(path) for path in train_data_dirs], + val_data_dirs=[str(path) for path in val_data_dirs], + input_image_size=(256, 256), + atomic_batch_n_samples=32, + atomic_batch_n_variants=4, + train_batch_size=32, + val_batch_size=32, + n_workers=8, + ) + optimizer_config = config.OptimizerConfig() + training_artifacts_config = config.TrainingArtifactsConfig( + output_basedir="bulk_data/pose_estimation/mesh6d/trial_20251130z/", + logging_interval=10, # 1000 + checkpoint_interval=30, # 1000 + validation_interval=30, # 1000 + n_batches_per_validation=30, # 300 + ) + train_mesh6d_model( + n_epochs=10, + model_architecture_config=model_architecture_config, + model_weights_config=model_weights_config, + loss_config=loss_config, + training_data_config=training_data_config, + optimizer_config=optimizer_config, + training_artifacts_config=training_artifacts_config, + seed=42, + half_batch_size_for_debugging=True, + ) From 128e35f4b1ca1274bef33443887e2dc226a3df86 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 00:40:15 +0100 Subject: [PATCH 15/33] [fix] temporary fix for mesh state coords system bug --- src/poseforge/neuromechfly/postprocessing.py | 1 + .../pose/data/synthetic/sim_data_seq.py | 47 +++++++++++++++---- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index fa33870..507d8f3 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -497,6 +497,7 @@ def process_subsegment( data=source_ds[frame_idx_start:frame_idx_end, :, :], dtype="float32", ) + # ! TODO: Convert to camera coords here! def _draw_pose_2d_and_3d( diff --git a/src/poseforge/pose/data/synthetic/sim_data_seq.py b/src/poseforge/pose/data/synthetic/sim_data_seq.py index 6d583a1..7457c5f 100644 --- a/src/poseforge/pose/data/synthetic/sim_data_seq.py +++ b/src/poseforge/pose/data/synthetic/sim_data_seq.py @@ -8,6 +8,10 @@ from poseforge.neuromechfly.constants import segments_for_6dpose +# ! TODO Can be removed later +from scipy.linalg import rq +from scipy.spatial.transform import Rotation + class SimulatedDataSequence: def __init__( @@ -127,8 +131,8 @@ def read_simulated_labels( self._check_frame_indices_validity(frame_indices) labels = {} - with h5py.File(self.simulated_labels_path, "r") as ds: - ds = ds["postprocessed"] + with h5py.File(self.simulated_labels_path, "r") as f: + ds = f["postprocessed"] if load_dof_angles: labels["dof_angles"] = ds["dof_angles"][frame_indices, :] @@ -147,12 +151,6 @@ def read_simulated_labels( if load_mesh_states: pose6d_grp = ds["body_segment_states"] all_avail_segments = pose6d_grp.attrs["keys"] - # segment_indices_lookup = { - # name: i for i, name in enumerate(all_avail_segments) - # } - # segment_indices = [ - # segment_indices_lookup[seg_name] for seg_name in segments_for_6dpose - # ] seg_mask = np.array( [name in segments_for_6dpose for name in all_avail_segments] ) @@ -162,6 +160,39 @@ def read_simulated_labels( mesh_quat = pose6d_grp["quat_global"][frame_indices, :, :] labels["mesh_quat"] = mesh_quat[:, seg_mask, :] + # ! TODO: Move the following logic to neuromechfly.postprocess + frame_range_in_full_sim = ds.attrs["frame_indices_in_full_simulation"] + full_sim_frame_start, full_sim_frame_end = frame_range_in_full_sim + cam_matrices_all = f["raw/camera_matrix"][ + full_sim_frame_start:full_sim_frame_end, :, : + ] + assert cam_matrices_all.shape[0] == pose6d_grp["pos_global"].shape[0] + cam_matrices = cam_matrices_all[frame_indices, :, :] + assert cam_matrices.shape == (len(frame_indices), 3, 4) + for frame_idx in range(len(frame_indices)): + cam_mat = cam_matrices[frame_idx, :, :] + cam_intrinsics, cam_rotation = rq(cam_mat[:, :3]) + _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) + cam_intrinsics = cam_intrinsics @ _sign_multiplier + cam_rotation = _sign_multiplier @ cam_rotation + cam_translation = np.linalg.inv(cam_intrinsics) @ cam_mat[:, 3] + # Ensure proper rotation matrix (det = 1) + if np.linalg.det(cam_rotation) < 0: + cam_rotation = -cam_rotation + cam_intrinsics = -cam_intrinsics + rot_world_to_cam = Rotation.from_matrix(cam_rotation) + + for seg_idx in range(seg_mask.sum()): + glob_pos = labels["mesh_pos"][frame_idx, seg_idx, :] + glob_quat = labels["mesh_quat"][frame_idx, seg_idx, :] + pos_rel_cam = cam_rotation @ glob_pos + cam_translation + rot_mesh = Rotation.from_quat(glob_quat, scalar_first=True) + rot_rel_cam = rot_world_to_cam * rot_mesh + quat_rel_cam = rot_rel_cam.as_quat(scalar_first=True) + labels["mesh_pos"][frame_idx, seg_idx, :] = pos_rel_cam + labels["mesh_quat"][frame_idx, seg_idx, :] = quat_rel_cam + # ! TODO end + if load_body_seg_maps: seg_labels_ds = ds["segmentation_labels"] # Resize to shape of synthetic frames via nearest neighbor resampling From f0d3f299d8c6b1301ff049dd9cfd93e500289c10 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 02:39:47 +0100 Subject: [PATCH 16/33] [run] run atomic batch extraction with multiple torque jobs --- .../gen_batch_scripts.py | 22 ++++++++++ .../atomic_batch_extraction/submit_all.sh | 20 +++++++++ .../atomic_batch_extraction/template.run | 42 +++++++++++++++++++ 3 files changed, 84 insertions(+) create mode 100644 scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py create mode 100644 scripts_on_cluster/atomic_batch_extraction/submit_all.sh create mode 100644 scripts_on_cluster/atomic_batch_extraction/template.run diff --git a/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py b/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py new file mode 100644 index 0000000..60bf884 --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py @@ -0,0 +1,22 @@ +from pathlib import Path + + +# Batch script generation +template_path = Path("template.run") +batch_scripts_dir = Path("batch_scripts/") +batch_scripts_dir.mkdir(exist_ok=True, parents=True) + +# Configs by task +synthetic_videos_basedir = Path("/work/upramdya/sibo_temp/poseforge/bulk_data/style_transfer/production/translated_videos/") +trial_names_all = [x.name for x in synthetic_videos_basedir.glob("BO_Gal4_*")] + +# Generate batch scripts +with open("template.run") as f: + template_str = f.read() + +for trial_name in trial_names_all: + batch_script_str = template_str.replace("<<>>", trial_name) + with open(batch_scripts_dir / f"{trial_name}.run", "w") as f: + f.write(batch_script_str) + +print(f"Generated {len(trial_names_all)} batch scripts under {batch_scripts_dir}") diff --git a/scripts_on_cluster/atomic_batch_extraction/submit_all.sh b/scripts_on_cluster/atomic_batch_extraction/submit_all.sh new file mode 100644 index 0000000..1dc2024 --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/submit_all.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +scripts_dir="./batch_scripts" + +files=($(ls $scripts_dir/*.run | sort)) +file_count=${#files[@]} + +echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " +read -r confirmation +if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then + echo "Submission canceled." + exit 1 +fi + +for file in "${files[@]}"; do + echo "Submitting $file" + sbatch $file +done + +echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/atomic_batch_extraction/template.run b/scripts_on_cluster/atomic_batch_extraction/template.run new file mode 100644 index 0000000..0e562d7 --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/template.run @@ -0,0 +1,42 @@ +#!/bin/bash -l + +#SBATCH --job-name exabat_<<>> +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 36 +#SBATCH --mem 200GB +#SBATCH --time 8:00:00 +#SBATCH --partition=standard +#SBATCH --qos=serial +#SBATCH --output logs/<<>>.log + +echo "Hello from $(hostname)" + +# Set up environment and package +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge_temp + +# Define paths +project_root="/work/upramdya/sibo_temp/poseforge" +extraction_cli_path="src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py" +synthetic_videos_basedir="bulk_data/style_transfer/production/translated_videos/" +nmf_rendering_basedir="bulk_data/nmf_rendering_with_6dpose/" +output_basedir="/work/upramdya/sibo/poseforge/bulk_data/pose_estimation/atomic_batches/4variants_with_6dpose/" + +# Setup +cd $project_root + +# Training +echo "Processing $trial_name at $(date)" +python -u $extraction_cli_path \ + --n_jobs 8 \ + --atomic-batch-nframes 32 \ + --atomic-batch-nvariants-max 4 \ + --minimum-time-diff-frames 60 \ + --original-image-size 464 464 \ + --input-basedir $synthetic_videos_basedir/<<>> \ + --nmf-sim-rendering-basedir $nmf_rendering_basedir \ + --output-dir $output_basedir/<<>> + +echo "Extraction ends at $(date)" From f88888455022e68b99e242a9413ed173aa9bccc8 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 02:47:31 +0100 Subject: [PATCH 17/33] [feature/fix] use hybrid attention-gated/global features, account for camera z offset --- src/poseforge/pose/pose6d/config.py | 9 +- src/poseforge/pose/pose6d/model.py | 112 +++++++++++++----- src/poseforge/pose/pose6d/pipeline.py | 13 +- .../pose/pose6d/scripts/train_mesh6d_model.py | 6 +- 4 files changed, 99 insertions(+), 41 deletions(-) diff --git a/src/poseforge/pose/pose6d/config.py b/src/poseforge/pose/pose6d/config.py index 76b4938..c2bef76 100644 --- a/src/poseforge/pose/pose6d/config.py +++ b/src/poseforge/pose/pose6d/config.py @@ -8,8 +8,12 @@ class ModelArchitectureConfig(SerializableDataClass): # Number of object segments. # Default: 6 legs * (coxa, femur, tibia, tarsus1) + thorax = 25 n_segments: int = 25 - # Number of hidden channels in the final upsampling layer before pose regression - final_upsampler_n_hidden_channels: int = 256 + # Number of feature channels that are gated by a per-segment attention mechanism + n_attention_gated_feature_channels: int = 128 + # Number of feature channels that are not gated by attention + n_global_feature_channels: int = 128 + # Camera distance + camera_distance: float = 100.0 @dataclass(frozen=True) @@ -60,6 +64,7 @@ class TrainingDataConfig(SerializableDataClass): class OptimizerConfig(SerializableDataClass): learning_rate_encoder: float = 3e-5 learning_rate_deconv: float = 3e-4 + learning_rate_attention_heads: float = 3e-4 learning_rate_pose6d_heads: float = 3e-4 weight_decay: float = 1e-5 diff --git a/src/poseforge/pose/pose6d/model.py b/src/poseforge/pose/pose6d/model.py index ceca4f6..42f3ddc 100644 --- a/src/poseforge/pose/pose6d/model.py +++ b/src/poseforge/pose/pose6d/model.py @@ -13,12 +13,22 @@ def __init__( self, n_segments: int, feature_extractor: ResNetFeatureExtractor, - final_upsampler_n_hidden_channels: int, + n_attention_gated_feature_channels: int, + n_global_feature_channels: int, + camera_distance: float, ): super(Pose6DModel, self).__init__() self.n_segments = n_segments self.feature_extractor = feature_extractor - self.final_upsampler_n_hidden_channels = final_upsampler_n_hidden_channels + self.n_attention_gated_feature_channels = n_attention_gated_feature_channels + self.n_global_feature_channels = n_global_feature_channels + self.n_feature_channels_total = ( + n_attention_gated_feature_channels + n_global_feature_channels + ) + self.camera_distance = camera_distance + # MuJoCo convention: in front of camera = negative z, so floor is at + # (*, *, -camera_distance) + self.z_center = -camera_distance # Create decoder (decoder1/2/3/4 mirror encoder layers 1/2/3/4) # Note that when upsampling, we actually run decoder4 first, decoder1 last @@ -28,47 +38,62 @@ def __init__( # information. self.dec_layer4 = DecoderBlock(512, 256, 512) self.dec_layer3 = DecoderBlock(512, 128, 256) - self.dec_layer2 = DecoderBlock(256, 64, final_upsampler_n_hidden_channels) + self.dec_layer2 = DecoderBlock(256, 64, self.n_feature_channels_total) + + # Attention layer + if n_attention_gated_feature_channels > 0: + _attention_heads = [ + nn.Sequential( + nn.Conv2d( + in_channels=self.n_feature_channels_total + n_segments, + out_channels=128, + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1), + nn.Sigmoid(), + ) + for _ in range(n_segments) + ] + self.attention_heads = nn.ModuleList(_attention_heads) + else: + self.attention_heads = None # 6D pose prediction head (one per segment) - _pose6d_heads = [] - for _ in range(n_segments): - attention_conv = nn.Conv2d( - in_channels=self.final_upsampler_n_hidden_channels + n_segments, - out_channels=self.final_upsampler_n_hidden_channels, - kernel_size=3, - padding=1, + _pose6d_heads = [ + nn.Sequential( + nn.Linear(in_features=self.n_feature_channels_total, out_features=512), + nn.ReLU(inplace=True), + nn.Dropout(p=0.3), + nn.Linear(in_features=512, out_features=256), + nn.ReLU(inplace=True), + nn.Dropout(p=0.3), + nn.Linear(in_features=256, out_features=7), ) - pool = nn.AdaptiveAvgPool2d((1, 1)) - linear1 = nn.Linear(self.final_upsampler_n_hidden_channels, 256) - relu = nn.ReLU(inplace=True) - dropout = nn.Dropout(p=0.3) - linear2 = nn.Linear(256, 7) # 3 translation + 4 rotation (quaternion) - pose6d_head = nn.Sequential( - attention_conv, pool, nn.Flatten(), linear1, relu, dropout, linear2 - ) - _pose6d_heads.append(pose6d_head) + for _ in range(n_segments) + ] self.pose6d_heads = nn.ModuleList(_pose6d_heads) @classmethod def create_from_config( - cls, architecture_config: config.ModelArchitectureConfig | Path | str + cls, arch_config: config.ModelArchitectureConfig | Path | str ) -> "Pose6DModel": # Load from file if config is a path - if isinstance(architecture_config, (str, Path)): - architecture_config = config.ModelArchitectureConfig.load( - architecture_config - ) - logger.info(f"Loaded model architecture config from {architecture_config}") + if isinstance(arch_config, (str, Path)): + arch_config = config.ModelArchitectureConfig.load(arch_config) + logger.info(f"Loaded model architecture config from {arch_config}") # Initialize feature extractor (WITHOUT WEIGHTS at this step!) feature_extractor = ResNetFeatureExtractor() # Initialize Pose6DModel (self) from config (WITHOUT WEIGHTS at this step!) obj = cls( - n_segments=architecture_config.n_segments, + n_segments=arch_config.n_segments, feature_extractor=feature_extractor, - final_upsampler_n_hidden_channels=architecture_config.final_upsampler_n_hidden_channels, + n_attention_gated_feature_channels=arch_config.n_attention_gated_feature_channels, + n_global_feature_channels=arch_config.n_global_feature_channels, + camera_distance=arch_config.camera_distance, ) logger.info("Initialized Pose6DModel from architecture config") @@ -123,20 +148,44 @@ def forward(self, image: torch.Tensor, bodyseg_prob: torch.Tensor) -> dict: d4 = e4 # this is just the bottleneck d3 = self.dec_layer4(d4, e3) d2 = self.dec_layer3(d3, e2) - d1 = self.dec_layer2(d2, e1) + features = self.dec_layer2(d2, e1) # (B, n_feature_channels_total, 64, 64) # Attention using bodyseg prediction happens at 64x64 resolution (stride 4) - assert d1.shape == (bat_size, self.final_upsampler_n_hidden_channels, 64, 64) + assert features.shape == (bat_size, self.n_feature_channels_total, 64, 64) assert bodyseg_prob.shape == (bat_size, self.n_segments, 64, 64) # (B, n_channels_feature_map + n_segments, H, W) - feature_map_with_seg_prob = torch.cat([d1, bodyseg_prob], dim=1) + feature_map_with_seg_prob = torch.cat([features, bodyseg_prob], dim=1) # Process each segment separately translation_pred_list = [] quaternion_pred_list = [] for seg_idx in range(self.n_segments): + # Global features + global_features_pooled = torch.mean( + features[:, : self.n_global_feature_channels, :, :], dim=[2, 3] + ) + + # Attention gated features + if self.n_attention_gated_feature_channels > 0: + attn_head = self.attention_heads[seg_idx] + attn_map = attn_head(feature_map_with_seg_prob) # (B, 1, H, W) + gated_features = ( + features[:, -self.n_attention_gated_feature_channels :, :, :] + * attn_map + ) + attn_weight_total = torch.sum(attn_map, dim=[2, 3]) + 1e-6 # (B, 1) + gated_features_pooled = ( + gated_features.sum(dim=[2, 3]) / attn_weight_total + ) # (B, C) + all_features_pooled = torch.cat( + [global_features_pooled, gated_features_pooled], dim=1 + ) + else: + all_features_pooled = global_features_pooled + + # FC layers for 6D pose prediction head = self.pose6d_heads[seg_idx] - pose6d_pred = head(feature_map_with_seg_prob) # (B, 7) + pose6d_pred = head(all_features_pooled) # (B, 7) translation_pred = pose6d_pred[:, 0:3] # (B, 3) quaternion_pred = pose6d_pred[:, 3:7] # (B, 4) quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) # normalize quat @@ -145,6 +194,7 @@ def forward(self, image: torch.Tensor, bodyseg_prob: torch.Tensor) -> dict: # Stack predictions into single tensors of shape (B, n_segments, 3 or 4) translation_pred_all = torch.stack(translation_pred_list, dim=1) + translation_pred_all[:, :, 2] += self.z_center quaternion_pred_all = torch.stack(quaternion_pred_list, dim=1) return translation_pred_all, quaternion_pred_all diff --git a/src/poseforge/pose/pose6d/pipeline.py b/src/poseforge/pose/pose6d/pipeline.py index ddc32ea..dcbdfdc 100644 --- a/src/poseforge/pose/pose6d/pipeline.py +++ b/src/poseforge/pose/pose6d/pipeline.py @@ -376,14 +376,17 @@ def _create_optimizer(self, optimizer_config: config.OptimizerConfig): "lr": optimizer_config.learning_rate_deconv, }, { - "params": list( - chain( - self.model.pose6d_heads.parameters(), - ) - ), + "params": self.model.pose6d_heads.parameters(), "lr": optimizer_config.learning_rate_pose6d_heads, }, ] + if self.model.n_attention_gated_feature_channels > 0: + params.append( + { + "params": self.model.attention_heads.parameters(), + "lr": optimizer_config.learning_rate_attention_heads, + } + ) optimizer = torch.optim.AdamW( params, weight_decay=optimizer_config.weight_decay diff --git a/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py b/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py index aab8a9d..6c4ce39 100644 --- a/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py +++ b/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py @@ -150,10 +150,10 @@ def print_model_summary(training_data_config, model): ) optimizer_config = config.OptimizerConfig() training_artifacts_config = config.TrainingArtifactsConfig( - output_basedir="bulk_data/pose_estimation/mesh6d/trial_20251130z/", + output_basedir="bulk_data/pose_estimation/mesh6d/trial_20251130z2/", logging_interval=10, # 1000 - checkpoint_interval=30, # 1000 - validation_interval=30, # 1000 + checkpoint_interval=100, # 1000 + validation_interval=100, # 1000 n_batches_per_validation=30, # 300 ) train_mesh6d_model( From a9f435954c0c25f5fe1a360fb1cac01e83d35c3e Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 03:22:51 +0100 Subject: [PATCH 18/33] [feature] update training script and configuration for 6D mesh pose model --- scripts_on_cluster/pose6d_training/train.run | 81 +++++++++++++++ src/poseforge/pose/pose6d/config.py | 6 +- src/poseforge/pose/pose6d/pipeline.py | 4 +- ...{train_mesh6d_model.py => run_training.py} | 98 +++++++++---------- 4 files changed, 134 insertions(+), 55 deletions(-) create mode 100644 scripts_on_cluster/pose6d_training/train.run rename src/poseforge/pose/pose6d/scripts/{train_mesh6d_model.py => run_training.py} (66%) diff --git a/scripts_on_cluster/pose6d_training/train.run b/scripts_on_cluster/pose6d_training/train.run new file mode 100644 index 0000000..6a56654 --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251201a.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="output_20251201a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 50 \ + --seed 42 \ + --model-arch-config.n-segments 25 \ + --model-arch-config.n-attention-gated-feature-channels 128 \ + --model-arch-config.n-global-feature-channels 128 \ + --model-arch-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.quaternion-weight 2.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.00003 \ + --optimizer-config.learning_rate_upsample 0.0003 \ + --optimizer-config.learning_rate_attention 0.0003 \ + --optimizer-config.learning_rate_pose6d_heads 0.0003 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/src/poseforge/pose/pose6d/config.py b/src/poseforge/pose/pose6d/config.py index c2bef76..dfacd2a 100644 --- a/src/poseforge/pose/pose6d/config.py +++ b/src/poseforge/pose/pose6d/config.py @@ -56,15 +56,13 @@ class TrainingDataConfig(SerializableDataClass): val_batch_size: int # Number of workers for data loading. Use number of CPU cores if None. n_workers: int | None = None - # Optional kernel size for dilating bodyseg masks - mask_dilation_kernel: int | None = None @dataclass(frozen=True) class OptimizerConfig(SerializableDataClass): learning_rate_encoder: float = 3e-5 - learning_rate_deconv: float = 3e-4 - learning_rate_attention_heads: float = 3e-4 + learning_rate_upsample: float = 3e-4 + learning_rate_attention: float = 3e-4 learning_rate_pose6d_heads: float = 3e-4 weight_decay: float = 1e-5 diff --git a/src/poseforge/pose/pose6d/pipeline.py b/src/poseforge/pose/pose6d/pipeline.py index dcbdfdc..49c96f5 100644 --- a/src/poseforge/pose/pose6d/pipeline.py +++ b/src/poseforge/pose/pose6d/pipeline.py @@ -373,7 +373,7 @@ def _create_optimizer(self, optimizer_config: config.OptimizerConfig): self.model.dec_layer4.parameters(), ) ), - "lr": optimizer_config.learning_rate_deconv, + "lr": optimizer_config.learning_rate_upsample, }, { "params": self.model.pose6d_heads.parameters(), @@ -384,7 +384,7 @@ def _create_optimizer(self, optimizer_config: config.OptimizerConfig): params.append( { "params": self.model.attention_heads.parameters(), - "lr": optimizer_config.learning_rate_attention_heads, + "lr": optimizer_config.learning_rate_attention, } ) diff --git a/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py b/src/poseforge/pose/pose6d/scripts/run_training.py similarity index 66% rename from src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py rename to src/poseforge/pose/pose6d/scripts/run_training.py index 6c4ce39..fe7b3e2 100644 --- a/src/poseforge/pose/pose6d/scripts/train_mesh6d_model.py +++ b/src/poseforge/pose/pose6d/scripts/run_training.py @@ -116,54 +116,54 @@ def print_model_summary(training_data_config, model): if __name__ == "__main__": - # import tyro + import tyro - # tyro.cli( - # train_mesh6d_model, - # prog=f"python {Path(__file__).name}", - # description="Train a 6D mesh pose model using pretrained feature extractor.", - # ) - - # Example using native Python function calls: - model_architecture_config = config.ModelArchitectureConfig() - model_weights_config = config.ModelWeightsConfig( - feature_extractor_weights="bulk_data/pose_estimation/contrastive_pretraining/trial_20251125a_lowlr/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth", - model_weights=None, - ) - loss_config = config.LossConfig() - data_basedir = Path("bulk_data/pose_estimation/atomic_batches/4variants") - train_data_dirs = [ - data_basedir / f"BO_Gal4_fly{fly}_trial{trial:03d}" - for fly in range(1, 5) # flies 1-4 - for trial in range(1, 6) # trials 1-5 - ] - val_data_dirs = [data_basedir / f"BO_Gal4_fly5_trial001"] - training_data_config = config.TrainingDataConfig( - train_data_dirs=[str(path) for path in train_data_dirs], - val_data_dirs=[str(path) for path in val_data_dirs], - input_image_size=(256, 256), - atomic_batch_n_samples=32, - atomic_batch_n_variants=4, - train_batch_size=32, - val_batch_size=32, - n_workers=8, - ) - optimizer_config = config.OptimizerConfig() - training_artifacts_config = config.TrainingArtifactsConfig( - output_basedir="bulk_data/pose_estimation/mesh6d/trial_20251130z2/", - logging_interval=10, # 1000 - checkpoint_interval=100, # 1000 - validation_interval=100, # 1000 - n_batches_per_validation=30, # 300 - ) - train_mesh6d_model( - n_epochs=10, - model_architecture_config=model_architecture_config, - model_weights_config=model_weights_config, - loss_config=loss_config, - training_data_config=training_data_config, - optimizer_config=optimizer_config, - training_artifacts_config=training_artifacts_config, - seed=42, - half_batch_size_for_debugging=True, + tyro.cli( + train_mesh6d_model, + prog=f"python {Path(__file__).name}", + description="Train a 6D mesh pose model using pretrained feature extractor.", ) + + # # Example using native Python function calls: + # model_architecture_config = config.ModelArchitectureConfig() + # model_weights_config = config.ModelWeightsConfig( + # feature_extractor_weights="bulk_data/pose_estimation/contrastive_pretraining/trial_20251125a_lowlr/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth", + # model_weights=None, + # ) + # loss_config = config.LossConfig(translation_weight=1.0, quaternion_weight=2.0) + # data_basedir = Path("bulk_data/pose_estimation/atomic_batches/4variants") + # train_data_dirs = [ + # data_basedir / f"BO_Gal4_fly{fly}_trial{trial:03d}" + # for fly in range(1, 5) # flies 1-4 + # for trial in range(1, 6) # trials 1-5 + # ] + # val_data_dirs = [data_basedir / f"BO_Gal4_fly5_trial001"] + # training_data_config = config.TrainingDataConfig( + # train_data_dirs=[str(path) for path in train_data_dirs], + # val_data_dirs=[str(path) for path in val_data_dirs], + # input_image_size=(256, 256), + # atomic_batch_n_samples=32, + # atomic_batch_n_variants=4, + # train_batch_size=32, + # val_batch_size=32, + # n_workers=8, + # ) + # optimizer_config = config.OptimizerConfig() + # training_artifacts_config = config.TrainingArtifactsConfig( + # output_basedir="bulk_data/pose_estimation/mesh6d/trial_20251130z2/", + # logging_interval=10, # 1000 + # checkpoint_interval=100, # 1000 + # validation_interval=100, # 1000 + # n_batches_per_validation=30, # 300 + # ) + # train_mesh6d_model( + # n_epochs=10, + # model_architecture_config=model_architecture_config, + # model_weights_config=model_weights_config, + # loss_config=loss_config, + # training_data_config=training_data_config, + # optimizer_config=optimizer_config, + # training_artifacts_config=training_artifacts_config, + # seed=42, + # half_batch_size_for_debugging=True, + # ) From 40966eec1727680c2b569664870d1be932e1e258 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 10:13:04 +0100 Subject: [PATCH 19/33] add logging --- .../{train.run => train_20251201a.run} | 14 ++-- .../pose6d_training/train_20251201b.run | 81 +++++++++++++++++++ .../pose6d_training/train_20251201c.run | 81 +++++++++++++++++++ src/poseforge/pose/pose6d/pipeline.py | 12 ++- 4 files changed, 180 insertions(+), 8 deletions(-) rename scripts_on_cluster/pose6d_training/{train.run => train_20251201a.run} (92%) create mode 100644 scripts_on_cluster/pose6d_training/train_20251201b.run create mode 100644 scripts_on_cluster/pose6d_training/train_20251201c.run diff --git a/scripts_on_cluster/pose6d_training/train.run b/scripts_on_cluster/pose6d_training/train_20251201a.run similarity index 92% rename from scripts_on_cluster/pose6d_training/train.run rename to scripts_on_cluster/pose6d_training/train_20251201a.run index 6a56654..5511dcc 100644 --- a/scripts_on_cluster/pose6d_training/train.run +++ b/scripts_on_cluster/pose6d_training/train_20251201a.run @@ -19,7 +19,7 @@ conda activate poseforge cd $HOME/poseforge training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" -training_trial_name="output_20251201a" +training_trial_name="20251201a" contrastive_pretraining_trial_name="trial_20251125a_lowlr" contrastive_pretraining_epoch="epoch009" contrastive_pretraining_local_step="step003055" @@ -27,16 +27,16 @@ contrastive_pretraining_local_step="step003055" echo "Training starting at $(date)" python -u $training_cli_path \ - --n_epochs 50 \ + --n_epochs 15 \ --seed 42 \ - --model-arch-config.n-segments 25 \ - --model-arch-config.n-attention-gated-feature-channels 128 \ - --model-arch-config.n-global-feature-channels 128 \ - --model-arch-config.camera-distance 100.0 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 128 \ + --model-architecture-config.n-global-feature-channels 128 \ + --model-architecture-config.camera-distance 100.0 \ --model-weights-config.feature-extractor-weights \ "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ --loss-config.translation-weight 1.0 \ - --loss-config.quaternion-weight 2.0 \ + --loss-config.rotation-weight 2.0 \ --training-data-config.train-data-dirs \ "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ diff --git a/scripts_on_cluster/pose6d_training/train_20251201b.run b/scripts_on_cluster/pose6d_training/train_20251201b.run new file mode 100644 index 0000000..15f058d --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251201b.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251201b.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251201b" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 128 \ + --model-architecture-config.n-global-feature-channels 128 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 10.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.00003 \ + --optimizer-config.learning_rate_upsample 0.0003 \ + --optimizer-config.learning_rate_attention 0.0003 \ + --optimizer-config.learning_rate_pose6d_heads 0.0003 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/pose6d_training/train_20251201c.run b/scripts_on_cluster/pose6d_training/train_20251201c.run new file mode 100644 index 0000000..f822166 --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251201c.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251201c.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251201c" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 192 \ + --model-architecture-config.n-global-feature-channels 64 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 10.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.0001 \ + --optimizer-config.learning_rate_upsample 0.001 \ + --optimizer-config.learning_rate_attention 0.001 \ + --optimizer-config.learning_rate_pose6d_heads 0.001 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/src/poseforge/pose/pose6d/pipeline.py b/src/poseforge/pose/pose6d/pipeline.py index 49c96f5..9fa10e4 100644 --- a/src/poseforge/pose/pose6d/pipeline.py +++ b/src/poseforge/pose/pose6d/pipeline.py @@ -155,8 +155,8 @@ def train( pred_pos, pred_quat, sim_data["mesh_pos"], sim_data["mesh_quat"] ) - # Check if float16 is used if epoch_idx == 0 and step_idx == 0: + # Check if float16 is used self._check_amp_status_for_model_params( grad_scaler, subtitle="Model parameters at start of training", @@ -171,6 +171,16 @@ def train( subtitle="Variables at start of training", ) + # Check magnitude of predictions + for i, seg_name in enumerate(segments_for_6dpose): + _pos = pred_pos[0, i, :].detach().cpu().tolist() + _quat = pred_quat[0, i, :].detach().cpu().tolist() + logger.info( + f"Sample {i}, {seg_name}: " + f"pred_pos={[round(x, 3) for x in _pos]}, " + f"pred_quat={[round(x, 3) for x in _quat]}" + ) + # Backpropagate and optimize optimizer.zero_grad(set_to_none=True) grad_scaler.scale(loss_dict["total_loss"]).backward() From f337127503c953e764e25727da10ce2eb2704a76 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 10:24:11 +0100 Subject: [PATCH 20/33] [fix] compute relative 6dpose in postprocessing --- src/poseforge/neuromechfly/postprocessing.py | 78 +++++++++++++++++--- 1 file changed, 69 insertions(+), 9 deletions(-) diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index 507d8f3..68727d8 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -10,6 +10,8 @@ from pathlib import Path from tqdm import tqdm from joblib import Parallel, delayed +from scipy.linalg import rq +from scipy.spatial.transform import Rotation import poseforge.neuromechfly.constants as constants from poseforge.util.plot import ( @@ -304,9 +306,63 @@ def process_single_frame( # Save object segmentation masks seg_labels = segment_label_parser(rendered_images_transformed) + # Transform mesh states to camera's coordinate system + pos_glob = h5_file["body_segment_states/pos_global"][frame_idx, :, :] + quat_glob = h5_file["body_segment_states/quat_global"][frame_idx, :, :] + cam_projmat = h5_file["camera_matrix"][frame_idx, :, :] # 3x4 mat from mujoco + derived_variables["pos_rel_cam"], derived_variables["quat_rel_cam"] = ( + calculate_6dpose_relative_to_camera(pos_glob, quat_glob, cam_projmat) + ) + return rendered_images_transformed, derived_variables, seg_labels +def calculate_6dpose_relative_to_camera(pos_global, quat_global, cam_projmat): + """Convert global mesh positions and orientations to camera-relative coordinates. + + Args: + pos_global: (num_segments, 3) array of global positions + quat_global: (num_segments, 4) array of global quaternions (scalar first) + cam_projmat: (3, 4) camera projection matrix + + Returns: + pos_rel_cam: (num_segments, 3) array of positions relative to camera + quat_rel_cam: (num_segments, 4) array of quaternions relative to camera + """ + assert pos_global.shape[0] == quat_global.shape[0], "Number of segments mismatch." + n_segments = pos_global.shape[0] + + # Decompose camera projection matrix to get camera intrinsics, rotation, translation + cam_intrinsics, cam_rotation = rq(cam_projmat[:, :3]) + _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) + cam_intrinsics = cam_intrinsics @ _sign_multiplier + cam_rotation = _sign_multiplier @ cam_rotation + if np.linalg.det(cam_rotation) < 0: + cam_rotation = -cam_rotation # ensure proper rotation matrix (det = 1) + cam_intrinsics = -cam_intrinsics + cam_translation = np.linalg.inv(cam_intrinsics) @ cam_projmat[:, 3] + + # Compute rotation from world to camera coordinates + rot_world_to_cam = Rotation.from_matrix(cam_rotation) + + # Convert each segment's position and orientation + pos_rel_cam = np.zeros_like(pos_global) + quat_rel_cam = np.zeros_like(quat_global) + for seg_idx in range(n_segments): + this_pos_glob = pos_global[seg_idx, :] + this_quat_glob = quat_global[seg_idx, :] + + this_pos_rel_cam = cam_rotation @ this_pos_glob + cam_translation + mesh_rot = Rotation.from_quat(this_quat_glob, scalar_first=True) + mesh_rot_rel_cam = rot_world_to_cam * mesh_rot + this_quat_rel_cam = mesh_rot_rel_cam.as_quat(scalar_first=True) + + pos_rel_cam[seg_idx, :] = this_pos_rel_cam + quat_rel_cam[seg_idx, :] = this_quat_rel_cam + + return pos_rel_cam, quat_rel_cam + + def process_subsegment( frames_by_color_coding: list[list[np.ndarray]], segment_h5_file_path: Path, @@ -488,16 +544,20 @@ def process_subsegment( ) # Add mesh state labels - seg_states_grp = postprocessed_group.create_group("body_segment_states") + seg_states_grp = postprocessed_group.create_group("mesh_pose6d_rel_camera") + seg_states_grp.create_dataset( + "pos_rel_cam", + data=np.array(derived_variables_by_key["pos_rel_cam"]), + dtype="float32", + compression="lzf", + ) + seg_states_grp.create_dataset( + "quat_rel_cam", + data=np.array(derived_variables_by_key["quat_rel_cam"]), + dtype="float32", + compression="lzf", + ) seg_states_grp.attrs.update(source_h5_file["body_segment_states"].attrs) - for sensor_type in source_h5_file["body_segment_states"].keys(): - source_ds = source_h5_file["body_segment_states"][sensor_type] - seg_states_grp.create_dataset( - sensor_type, - data=source_ds[frame_idx_start:frame_idx_end, :, :], - dtype="float32", - ) - # ! TODO: Convert to camera coords here! def _draw_pose_2d_and_3d( From e0f5350938e3ab8484030e0082213a69eab1c222 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 1 Dec 2025 11:49:38 +0100 Subject: [PATCH 21/33] [fix] rely on pose6d extraction from nmf postprocessing instead of atomic batch extraction --- .../pose/data/synthetic/sim_data_seq.py | 43 ++----------------- 1 file changed, 3 insertions(+), 40 deletions(-) diff --git a/src/poseforge/pose/data/synthetic/sim_data_seq.py b/src/poseforge/pose/data/synthetic/sim_data_seq.py index 7457c5f..0782124 100644 --- a/src/poseforge/pose/data/synthetic/sim_data_seq.py +++ b/src/poseforge/pose/data/synthetic/sim_data_seq.py @@ -8,10 +8,6 @@ from poseforge.neuromechfly.constants import segments_for_6dpose -# ! TODO Can be removed later -from scipy.linalg import rq -from scipy.spatial.transform import Rotation - class SimulatedDataSequence: def __init__( @@ -149,50 +145,17 @@ def read_simulated_labels( labels["keypoint_pos"] = keypoint_pos if load_mesh_states: - pose6d_grp = ds["body_segment_states"] + pose6d_grp = ds["mesh_pose6d_rel_camera"] all_avail_segments = pose6d_grp.attrs["keys"] seg_mask = np.array( [name in segments_for_6dpose for name in all_avail_segments] ) # h5py only supports fancy indexing along one axis - mesh_pos = pose6d_grp["pos_global"][frame_indices, :, :] + mesh_pos = pose6d_grp["pos_rel_cam"][frame_indices, :, :] labels["mesh_pos"] = mesh_pos[:, seg_mask, :] - mesh_quat = pose6d_grp["quat_global"][frame_indices, :, :] + mesh_quat = pose6d_grp["quat_rel_cam"][frame_indices, :, :] labels["mesh_quat"] = mesh_quat[:, seg_mask, :] - # ! TODO: Move the following logic to neuromechfly.postprocess - frame_range_in_full_sim = ds.attrs["frame_indices_in_full_simulation"] - full_sim_frame_start, full_sim_frame_end = frame_range_in_full_sim - cam_matrices_all = f["raw/camera_matrix"][ - full_sim_frame_start:full_sim_frame_end, :, : - ] - assert cam_matrices_all.shape[0] == pose6d_grp["pos_global"].shape[0] - cam_matrices = cam_matrices_all[frame_indices, :, :] - assert cam_matrices.shape == (len(frame_indices), 3, 4) - for frame_idx in range(len(frame_indices)): - cam_mat = cam_matrices[frame_idx, :, :] - cam_intrinsics, cam_rotation = rq(cam_mat[:, :3]) - _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) - cam_intrinsics = cam_intrinsics @ _sign_multiplier - cam_rotation = _sign_multiplier @ cam_rotation - cam_translation = np.linalg.inv(cam_intrinsics) @ cam_mat[:, 3] - # Ensure proper rotation matrix (det = 1) - if np.linalg.det(cam_rotation) < 0: - cam_rotation = -cam_rotation - cam_intrinsics = -cam_intrinsics - rot_world_to_cam = Rotation.from_matrix(cam_rotation) - - for seg_idx in range(seg_mask.sum()): - glob_pos = labels["mesh_pos"][frame_idx, seg_idx, :] - glob_quat = labels["mesh_quat"][frame_idx, seg_idx, :] - pos_rel_cam = cam_rotation @ glob_pos + cam_translation - rot_mesh = Rotation.from_quat(glob_quat, scalar_first=True) - rot_rel_cam = rot_world_to_cam * rot_mesh - quat_rel_cam = rot_rel_cam.as_quat(scalar_first=True) - labels["mesh_pos"][frame_idx, seg_idx, :] = pos_rel_cam - labels["mesh_quat"][frame_idx, seg_idx, :] = quat_rel_cam - # ! TODO end - if load_body_seg_maps: seg_labels_ds = ds["segmentation_labels"] # Resize to shape of synthetic frames via nearest neighbor resampling From c34370f54307cdb9ed2b37bf0deb39e3f1bfb8aa Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Tue, 2 Dec 2025 18:05:37 +0100 Subject: [PATCH 22/33] [run] fix paths in batch script --- scripts_on_cluster/atomic_batch_extraction/template.run | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts_on_cluster/atomic_batch_extraction/template.run b/scripts_on_cluster/atomic_batch_extraction/template.run index 0e562d7..99fb8d3 100644 --- a/scripts_on_cluster/atomic_batch_extraction/template.run +++ b/scripts_on_cluster/atomic_batch_extraction/template.run @@ -15,14 +15,14 @@ echo "Hello from $(hostname)" # Set up environment and package . ~/spack/share/spack/setup-env.sh spack load ffmpeg -conda activate poseforge_temp +conda activate poseforge # Define paths -project_root="/work/upramdya/sibo_temp/poseforge" +project_root="/home/sibwang/poseforge" extraction_cli_path="src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py" synthetic_videos_basedir="bulk_data/style_transfer/production/translated_videos/" -nmf_rendering_basedir="bulk_data/nmf_rendering_with_6dpose/" -output_basedir="/work/upramdya/sibo/poseforge/bulk_data/pose_estimation/atomic_batches/4variants_with_6dpose/" +nmf_rendering_basedir="bulk_data/nmf_rendering/" +output_basedir="/home/sibwang/poseforge/bulk_data/pose_estimation/atomic_batches/4variants/" # Setup cd $project_root From 11a9f1479881c20c7444f6d1341488f91e681760 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 8 Dec 2025 18:33:35 +0100 Subject: [PATCH 23/33] [feature] add script to replay simulated 6d pose sequence --- .../neuromechfly/scripts/visualize_meshes.py | 400 +++++++++++++----- src/poseforge/util/data.py | 6 + 2 files changed, 294 insertions(+), 112 deletions(-) diff --git a/src/poseforge/neuromechfly/scripts/visualize_meshes.py b/src/poseforge/neuromechfly/scripts/visualize_meshes.py index cb6554e..db69096 100644 --- a/src/poseforge/neuromechfly/scripts/visualize_meshes.py +++ b/src/poseforge/neuromechfly/scripts/visualize_meshes.py @@ -1,125 +1,301 @@ import numpy as np -import pandas as pd import pyvista as pv import h5py +from tqdm import trange from xml.etree import ElementTree from pathlib import Path +from dataclasses import dataclass from scipy.spatial.transform import Rotation -from scipy.linalg import rq +from pyvista import PolyData +from loguru import logger from poseforge.neuromechfly.constants import legs, all_segment_names_per_leg +from poseforge.util.data import bulk_data_dir -# Define paths -subsegment_dir = Path( - "bulk_data/nmf_rendering/BO_Gal4_fly1_trial001/segment_000/subsegment_000" -) -sim_data_path = subsegment_dir / "processed_simulation_data.h5" -flygym_data_dir = Path("~/projects/flygym/flygym").expanduser() / "data" -nmf_mesh_dir = flygym_data_dir / "mesh" -mjcf_path = flygym_data_dir / "mjcf/neuromechfly_seqik_kinorder_ypr.xml" - -# Load NeuroMechFly model -mjcf_tree = ElementTree.parse(mjcf_path) -worldbody = mjcf_tree.find("worldbody") -body_attributes = {body.attrib["name"]: body.attrib for body in worldbody.iter("body")} - -# Load simulation data -with h5py.File(sim_data_path, "r") as f: - all_seg_pos_global = f["raw/body_segment_states/pos_global"][:] - all_seg_quat_global = f["raw/body_segment_states/quat_global"][:] - all_cam_matrices = f["raw/camera_matrix"][:] - all_seg_names = list(f["raw/body_segment_states"].attrs["keys"]) - -n_frames = all_seg_pos_global.shape[0] -segments_to_include = [ - f"{leg}{seg}" for leg in legs for seg in all_segment_names_per_leg -] -segments_to_include += ["Thorax"] - -# Load original meshes once (before any transformations) -original_meshes = {} -for seg_name in segments_to_include: - mesh_file = nmf_mesh_dir / f"{seg_name}.stl" - original_meshes[seg_name] = pv.read(mesh_file) - -# Create plotter -plotter = pv.Plotter() -plotter.set_background("black") -plotter.show_axes() - -# Add all meshes to plotter initially -current_meshes = {} -for seg_name in segments_to_include: - current_meshes[seg_name] = original_meshes[seg_name].copy() - plotter.add_mesh( - current_meshes[seg_name], - show_edges=False, - name=seg_name, - smooth_shading=True - ) +def load_neuromechfly_meshes(flygym_path: Path | str) -> dict[str, PolyData]: + """Load NeuroMechFly meshes from the specified FlyGym installation path.""" + # Define paths + mesh_dir = Path(flygym_path) / "data/mesh" + mjcf_path = Path(flygym_path) / "data/mjcf/neuromechfly_seqik_kinorder_ypr.xml" + + # Load NeuroMechFly MJCF file (which is just an XML file) + mjcf_tree = ElementTree.parse(mjcf_path) + + # Load body segment meshes + mesh_lookup = {} + for mesh_element in mjcf_tree.findall("asset/mesh"): # findall iterates in order + mesh_path = mesh_dir / Path(mesh_element.attrib["file"]).name + mesh_name = mesh_element.attrib["name"] + if not mesh_name.startswith("mesh_"): + raise RuntimeError( + "Unexpected mesh specified in MJCF: name doesn't start with 'mesh_'" + ) + mesh_name = mesh_name[len("mesh_") :] + mesh_scale_str = mesh_element.attrib["scale"] + try: + mesh_scale_xyz = [float(s) for s in mesh_scale_str.split()] + except ValueError as e: + raise RuntimeError( + "Unexpected mesh specified in MJCF: scale not following format " + "'xscale yscale zscale' (each a number)" + ) from e -# Current frame tracker -current_frame = [0] - - -def update_frame(): - """Update all meshes to the current frame""" - frame_idx = current_frame[0] - cam_matrix = all_cam_matrices[frame_idx, :, :] - - # Compute camera transformation once per frame - cam_intrinsics, cam_rotation = rq(cam_matrix[:, :3]) - _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) - cam_intrinsics = cam_intrinsics @ _sign_multiplier - cam_rotation = _sign_multiplier @ cam_rotation - cam_translation = np.linalg.inv(cam_intrinsics) @ cam_matrix[:, 3] - - transform_world2cam = np.eye(4) - transform_world2cam[:3, :3] = cam_rotation - transform_world2cam[:3, 3] = cam_translation - transform_cam2world = np.linalg.inv(transform_world2cam) - - # Update each segment mesh - for seg_name in segments_to_include: - seg_idx = all_seg_names.index(seg_name) - translation = all_seg_pos_global[frame_idx, seg_idx, :] - quaternion = all_seg_quat_global[frame_idx, seg_idx, :] - - # Start with original mesh - mesh = original_meshes[seg_name].copy() - - # Scale + # Load mesh and apply initial scaling transform + # (lengths in NeuroMechFly simulations are in mm instead of m - this is how NMF + # tells MuJoCo to scale everything accordingly) + mesh: PolyData = pv.read(mesh_path) + if mesh is None: + raise RuntimeError(f"Mesh '{mesh_name}' from '{mesh_path}' is empty") scale_transform = np.eye(4) - np.fill_diagonal(scale_transform, [1000, 1000, 1000, 1]) + np.fill_diagonal(scale_transform, [*mesh_scale_xyz, 1]) mesh = mesh.transform(scale_transform, inplace=False) - - # Apply MuJoCo state transformation - placement_transform = np.eye(4) - rotation_object = Rotation.from_quat(quaternion, scalar_first=True) - placement_transform[:3, :3] = rotation_object.as_matrix() - placement_transform[:3, 3] = translation - mesh = mesh.transform(placement_transform, inplace=False) - - # Transform to camera coordinates - mesh = mesh.transform(transform_cam2world, inplace=False) - - # Update the mesh points in place - current_meshes[seg_name].points[:] = mesh.points - - # Update frame counter - current_frame[0] = (current_frame[0] + 1) % n_frames - - # Update title to show current frame - plotter.add_text(f"Frame: {frame_idx}/{n_frames}", name="frame_counter", position="upper_left") - - -# Initialize first frame -update_frame() -plotter.reset_camera() - -# Add timer callback for animation (30 fps) -# The callback receives a step argument, so we need to accept it -plotter.add_timer_event(max_steps=n_frames, duration=int(1000/30), callback=lambda step: update_frame()) - -plotter.show() \ No newline at end of file + mesh_lookup[mesh_name] = mesh + + return mesh_lookup + + +@dataclass +class Pose6DSequence: + # XYZ position relative to camera (n_frames, n_segments, 3) + pos_ts: np.ndarray + # Quaternion rotation relative to camera in scalar-first order (wxyz) + # (n_frames, n_segments, 4) + quat_ts: np.ndarray + # Names of body segments included in pos_ts and quat_ts + segments: list[str] + + @classmethod + def from_processed_simulation_data( + cls, + data_path: Path, + segments: list[str] | None, + relative_to_camera: bool = True, + ) -> "Pose6DSequence": + """ + Load 6D pose sequence from processed NeuroMechFly simulation data file. + + Args: + data_path: Path to processed simulation data HDF5 file + segments: List of segment names to load. If None, load all recorded segments + relative_to_camera: If True, load segment poses relative to camera. + If False, load raw global segment poses (pre-postprocessing). + + Returns: + An instance of Pose6DSequence containing the loaded data + """ + pos_ts, quat_ts, segments = _load_simulation_data( + data_path, segments, load_raw=not relative_to_camera + ) + return cls(pos_ts=pos_ts, quat_ts=quat_ts, segments=segments) + + def __len__(self): + return self.pos_ts.shape[0] + + def render( + self, + mesh_assets: dict[str, PolyData], + render_fps: int, + output_path: Path | str | None = None, + display_live: bool = False, + theme: str = "dark", + disable_pbar: bool = False, + ) -> None: + """Render the 6D pose sequence using NeuroMechFly meshes in PyVista. + + Args: + mesh_assets: Dictionary mapping segment names to PyVista PolyData meshes. + render_fps: Frames per second for rendering. + output_path: If display_live is False, path to save the rendered video. + display_live: If True, display the rendering live instead of saving to file. + theme: Display theme, either "dark" or "light". + disable_pbar: If True, disable the progress bar when saving to file. + """ + # Filter & check mesh files from NeuroMechFly + mesh_assets = {k: v for k, v in mesh_assets.items() if k in self.segments} + missing_keys = set(self.segments) - set(mesh_assets.keys()) + if missing_keys: + logger.critical( + "The following segments from Pose6DSequence are not found in " + f"NeuroMechFly meshes: {missing_keys}" + ) + raise KeyError("Some meshes are missing") + + # Set up rendering + plotter, plotted_meshes = _set_up_renderer(mesh_assets, display_live, theme) + + def _update_scene_by_frameid(frameid): + _transform_all_meshes( + mesh_assets, + plotted_meshes, + self.pos_ts[frameid, :, :], + self.quat_ts[frameid, :, :], + self.segments, + ) + + # Set up first frame + _update_scene_by_frameid(0) + plotter.reset_camera() + + # Render each frame + if display_live: + if output_path is not None: + logger.warning( + "Output path is specified but display_live is set to true. " + "Nothing will be saved and the output path will be ignored." + ) + + def _update_with_cleanup(step): + _update_scene_by_frameid(step) + if step == len(self) - 1: + plotter.close() + + plotter.add_timer_event( + max_steps=len(self), + duration=int(1000 / render_fps), # in ms + callback=_update_with_cleanup, + ) + plotter.show() + else: + if output_path is None: + logger.critical( + "When Pose6DSequenceRenderer.display_live is set to false, " + "output_path must be specified." + ) + raise ValueError("Video output path not specified.") + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + plotter.open_movie(output_path, framerate=render_fps) + for i in trange(len(self), disable=disable_pbar): + _update_scene_by_frameid(i) + plotter.write_frame() + plotter.close() + + +def _load_simulation_data( + processed_data_path: Path | str, segments: list[str] | None, load_raw: bool = False +) -> tuple[np.ndarray, np.ndarray, list[str]]: + """Load 6D pose data from processed NeuroMechFly simulation data file. + + Args: + processed_data_path: Path to processed simulation data HDF5 file. + segments: List of segment names to load. If None, load all recorded segments. + load_raw: If True, load raw global segment poses (pre-postprocessing). If False, + load postprocessed segment poses relative to camera. + + Returns: + pos_ts: Array of shape (n_frames, n_segments, 3) with XYZ positions. + quat_ts: Array of shape (n_frames, n_segments, 4) with quaternion rotations + in scalar-first (wxyz) order. + segments: List of segment names corresponding to the loaded data. + """ + with h5py.File(processed_data_path, "r") as f: + if load_raw: + grp = f["raw/body_segment_states/"] + all_recorded_segments = list(grp.attrs["keys"]) + pos_ts = grp["pos_global"][:] + quat_ts = grp["quat_global"][:] + else: + grp = f["postprocessed/mesh_pose6d_rel_camera"] # relative to camera + all_recorded_segments = list(grp.attrs["keys"]) + # grp["pos_rel_cam"][:, seg_indices, :] would be better, but h5py only + # supports index-based fancy indexing if the indices are monotonic, so meh + pos_ts = grp["pos_rel_cam"][:] + quat_ts = grp["quat_rel_cam"][:] + if segments is None: + segments = all_recorded_segments + else: + seg2idx = {seg: i for i, seg in enumerate(all_recorded_segments)} + try: + seg_indices = np.array([seg2idx[seg] for seg in segments]) + except KeyError as e: + raise KeyError(f"Some segments not found in simulation data") from e + pos_ts = pos_ts[:, seg_indices, :] + quat_ts = quat_ts[:, seg_indices, :] + + n_frames, n_segments, _ = pos_ts.shape + assert pos_ts.shape == (n_frames, n_segments, 3), "Invalid shape for mesh pos_ts" + assert quat_ts.shape == (n_frames, n_segments, 4), "Invalid shape for mesh quat_ts" + return pos_ts, quat_ts, segments + + +def _set_up_renderer( + nmf_meshes: dict[str, PolyData], display_live: bool = False, theme: str = "dark" +) -> tuple[pv.Plotter, dict[str, PolyData]]: + """Set up PyVista plotter and add NeuroMechFly meshes to it""" + # Set up plotter + plotter = pv.Plotter(off_screen=not display_live) + plotter.show_axes() + + # Apply display theme + if theme.lower() == "dark": + plotter.set_background("black") + mesh_color = "#eeeeee" + elif theme.lower() == "light": + plotter.set_background("white") + mesh_color = "#bbbbbb" + else: + raise ValueError(f"Undefined display theme '{theme}'") + + # Add meshes initially (don't care about positions) + plotted_meshes = {} + for seg_name, mesh in nmf_meshes.items(): + mesh = nmf_meshes[seg_name].copy() + plotter.add_mesh(mesh, show_edges=False, name=seg_name, color=mesh_color) + plotted_meshes[seg_name] = mesh + + # Set up camera + plotter.camera.position = (0, 0, 0) + plotter.camera.focal_point = (0, 0, 1) # look down +Z axis + plotter.camera.up = (0, -1, 0) # -Y is up (OpenCV/MuJoCo convention) + + return plotter, plotted_meshes + + +def _transform_one_mesh(mesh: PolyData, pos: np.ndarray, quat: np.ndarray) -> PolyData: + """Apply 6D pose (translation + quaternion rotation) to a mesh""" + rot = Rotation.from_quat(quat, scalar_first=True) # MuJoCo uses wxyz order + transform = np.eye(4) + transform[:3, :3] = rot.as_matrix() + transform[:3, 3] = pos + return mesh.transform(transform, inplace=False) + + +def _transform_all_meshes( + mesh_assets: dict[str, PolyData], + plotted_meshes: dict[str, PolyData], + pos_all_segments: np.ndarray, + quat_all_segments: np.ndarray, + segment_names: list[str], +) -> None: + """Update all meshes to the specified 6D poses""" + assert pos_all_segments.shape == (len(segment_names), 3) + assert quat_all_segments.shape == (len(segment_names), 4) + for i, seg_name in enumerate(segment_names): + transformed_mesh = _transform_one_mesh( + mesh_assets[seg_name], pos_all_segments[i, :], quat_all_segments[i, :] + ) + plotted_meshes[seg_name].points = transformed_mesh.points + + +if __name__ == "__main__": + flygym_dir = Path("~/projects/flygym/flygym").expanduser() + sample_sim_path = ( + bulk_data_dir + / "nmf_rendering/BO_Gal4_fly1_trial001/segment_000/subsegment_001/processed_simulation_data.h5" + ) + replayed_segments = [ + f"{leg}{seg}" for leg in legs for seg in all_segment_names_per_leg + ] + ["Thorax"] + + nmf_mesh_assets = load_neuromechfly_meshes(flygym_dir) + sim_data = Pose6DSequence.from_processed_simulation_data( + sample_sim_path, segments=replayed_segments, relative_to_camera=True + ) + sim_data.render( + mesh_assets=nmf_mesh_assets, + render_fps=33, + output_path=sample_sim_path.parent / "pose6d_render.mp4", + display_live=True, + theme="light", + ) diff --git a/src/poseforge/util/data.py b/src/poseforge/util/data.py index 464b0dc..d618ff6 100644 --- a/src/poseforge/util/data.py +++ b/src/poseforge/util/data.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Hashable, Callable, Any +import poseforge + @dataclasses.dataclass(frozen=True) class SerializableDataClass: @@ -143,3 +145,7 @@ def n_open_buckets(self) -> int: @property def n_data_total(self) -> int: return sum(len(buf) for buf in self.buffers.values()) + + +assert len(poseforge.__path__) == 1, "poseforge.__path__ contains multiple paths" +bulk_data_dir = Path(poseforge.__path__[0]).parent.parent / "bulk_data" From 67723b76ace91abcd7ddbe3dc760d4e2606eecdf Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 8 Dec 2025 18:46:03 +0100 Subject: [PATCH 24/33] [housekeeping] update dependencies, gitignore, etc --- .gitignore | 133 +++++++++---------------------------------------- poetry.lock | 84 ++++++++++++++++++++++++++++++- pyproject.toml | 7 ++- 3 files changed, 111 insertions(+), 113 deletions(-) diff --git a/.gitignore b/.gitignore index a5d337c..64f5412 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,24 @@ -# Created by https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,jupyternotebooks,visualstudiocode,vim,git,pycharm+all -# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,macos,windows,jupyternotebooks,visualstudiocode,vim,git,pycharm+all +bulk_data +production_models/** +!production_models/**/ +!production_models/**/README.md +logs/ +wandb/ +*.mp4 +*.jpg +*.pth +*.png +temp/ +batch_scripts/ +outputs/ +notebooks/ + +# I want all JetBrains-specific file to be ignored +# (gitignore.io defaults allow style and run settings to be included) +.idea/ + +# Created by https://www.toptal.com/developers/gitignore/api/git,vim,linux,macos,python,windows,jupyternotebooks,visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=git,vim,linux,macos,python,windows,jupyternotebooks,visualstudiocode ### Git ### # Created by git for backups. To disable backups in Git: @@ -78,94 +97,6 @@ Temporary Items # iCloud generated files *.icloud -### PyCharm+all ### -# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider -# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 - -# User-specific stuff -.idea/**/workspace.xml -.idea/**/tasks.xml -.idea/**/usage.statistics.xml -.idea/**/dictionaries -.idea/**/shelf - -# AWS User-specific -.idea/**/aws.xml - -# Generated files -.idea/**/contentModel.xml - -# Sensitive or high-churn files -.idea/**/dataSources/ -.idea/**/dataSources.ids -.idea/**/dataSources.local.xml -.idea/**/sqlDataSources.xml -.idea/**/dynamic.xml -.idea/**/uiDesigner.xml -.idea/**/dbnavigator.xml - -# Gradle -.idea/**/gradle.xml -.idea/**/libraries - -# Gradle and Maven with auto-import -# When using Gradle or Maven with auto-import, you should exclude module files, -# since they will be recreated, and may cause churn. Uncomment if using -# auto-import. -# .idea/artifacts -# .idea/compiler.xml -# .idea/jarRepositories.xml -# .idea/modules.xml -# .idea/*.iml -# .idea/modules -# *.iml -# *.ipr - -# CMake -cmake-build-*/ - -# Mongo Explorer plugin -.idea/**/mongoSettings.xml - -# File-based project format -*.iws - -# IntelliJ -out/ - -# mpeltonen/sbt-idea plugin -.idea_modules/ - -# JIRA plugin -atlassian-ide-plugin.xml - -# Cursive Clojure plugin -.idea/replstate.xml - -# SonarLint plugin -.idea/sonarlint/ - -# Crashlytics plugin (for Android Studio and IntelliJ) -com_crashlytics_export_strings.xml -crashlytics.properties -crashlytics-build.properties -fabric.properties - -# Editor-based Rest Client -.idea/httpRequests - -# Android studio 3.1+ serialized cache file -.idea/caches/build_file_checksums.ser - -### PyCharm+all Patch ### -# Ignore everything but code style settings and run configurations -# that are supposed to be shared within teams. - -.idea/* - -!.idea/codeStyles -!.idea/runConfigurations - ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ @@ -288,7 +219,7 @@ celerybeat.pid # Environments .env .venv -# env/ +env/ venv/ ENV/ env.bak/ @@ -357,7 +288,7 @@ tags ### VisualStudioCode ### .vscode/* -# !.vscode/settings.json +!.vscode/settings.json !.vscode/tasks.json !.vscode/launch.json !.vscode/extensions.json @@ -400,20 +331,4 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk -# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,jupyternotebooks,visualstudiocode,vim,git,pycharm+all - - -bulk_data -production_models/** -!production_models/**/ -!production_models/**/README.md -logs/ -wandb/ -*.mp4 -*.jpg -*.pth -*.png -temp/ -batch_scripts/ -outputs/ -notebooks/ +# End of https://www.toptal.com/developers/gitignore/api/git,vim,linux,macos,python,windows,jupyternotebooks,visualstudiocode \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index bcdb877..1fc35b2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -1288,6 +1288,21 @@ files = [ {file = "numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78"}, ] +[[package]] +name = "numpy-typing-compat" +version = "20250818.2.0" +description = "Static typing compatibility layer for older versions of NumPy" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "numpy_typing_compat-20250818.2.0-py3-none-any.whl", hash = "sha256:042da86a786b6eb164f900efdfc3ba132f4371a2e44a93109976b1d7538253ed"}, + {file = "numpy_typing_compat-20250818.2.0.tar.gz", hash = "sha256:3f77ba873ec9668e9b7bd15ae083cc16c82aa732b651ed2bf5aa284cdd0dc71d"}, +] + +[package.dependencies] +numpy = ">=2.0,<2.1" + [[package]] name = "nvidia-cublas-cu12" version = "12.8.4.1" @@ -1529,6 +1544,25 @@ files = [ [package.dependencies] numpy = {version = ">=2,<2.3.0", markers = "python_version >= \"3.9\""} +[[package]] +name = "optype" +version = "0.14.0" +description = "Building Blocks for Precise & Flexible Type Hints" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "optype-0.14.0-py3-none-any.whl", hash = "sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b"}, + {file = "optype-0.14.0.tar.gz", hash = "sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6"}, +] + +[package.dependencies] +numpy = {version = ">=1.25", optional = true, markers = "extra == \"numpy\""} +numpy-typing-compat = {version = ">=20250818.1.25,<20250819", optional = true, markers = "extra == \"numpy\""} + +[package.extras] +numpy = ["numpy (>=1.25)", "numpy-typing-compat (>=20250818.1.25,<20250819)"] + [[package]] name = "packaging" version = "25.0" @@ -1637,6 +1671,22 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.3.3.251201" +description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pandas_stubs-2.3.3.251201-py3-none-any.whl", hash = "sha256:eb5c9b6138bd8492fd74a47b09c9497341a278fcfbc8633ea4b35b230ebf4be5"}, + {file = "pandas_stubs-2.3.3.251201.tar.gz", hash = "sha256:7a980f4f08cff2a6d7e4c6d6d26f4c5fcdb82a6f6531489b2f75c81567fe4536"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "pillow" version = "11.3.0" @@ -2372,6 +2422,24 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.19.1)", "jupytext", "linkify-it-py", "matplotlib (>=3.5)", "myst-nb (>=1.2.0)", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.2.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] test = ["Cython", "array-api-strict (>=2.3.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest (>=8.0.0)", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "scipy-stubs" +version = "1.16.3.2" +description = "Type annotations for SciPy" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "scipy_stubs-1.16.3.2-py3-none-any.whl", hash = "sha256:c8df8975d69d77237a12e7b32f0bce08a7cc27bbb53ff7f4671ef1e9de033490"}, + {file = "scipy_stubs-1.16.3.2.tar.gz", hash = "sha256:04c498da2bfb445cf1dce06cb012bbaafd30fa19d9c7b03852a0574e7e83df3a"}, +] + +[package.dependencies] +optype = {version = ">=0.14.0,<0.15", extras = ["numpy"]} + +[package.extras] +scipy = ["scipy (>=1.16.3,<1.17)"] + [[package]] name = "scooby" version = "0.10.2" @@ -2813,6 +2881,18 @@ files = [ [package.dependencies] typing_extensions = ">=4.14.0" +[[package]] +name = "types-pytz" +version = "2025.2.0.20251108" +description = "Typing stubs for pytz" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c"}, + {file = "types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb"}, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -3022,4 +3102,4 @@ dev = ["black (>=19.3b0) ; python_version >= \"3.6\"", "pytest (>=4.6.2)"] [metadata] lock-version = "2.1" python-versions = ">=3.13,<3.14" -content-hash = "6b6d3ec0385494898f80d1d821a1b48f1bfca0f081be0600d583487e664119cb" +content-hash = "faa526e8c9a8209823ec8cabb8cc87dd42446459bb0c237366c2d3a94487d28c" diff --git a/pyproject.toml b/pyproject.toml index fe4a71e..668ba53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,8 +9,8 @@ readme = "README.md" requires-python = ">=3.13,<3.14" dependencies = [ "numpy>=2.0,<3", - "scipy>=1.16.1,<2", - "pandas>=2.3,<3", + "scipy>=1.16.2,<1.17.0", + "pandas>=2.3.3,<2.4.0", "matplotlib>=3.10,<4", "torch==2.9.*", # specific version required by torchvision and torchcodec! "torchvision==0.24.*", # specific to torch 2.9 @@ -41,6 +41,9 @@ dependencies = [ "filelock>=3.20.0,<4", "av<16", # imageio depends on pyav but no pyav 16.0.0 wheel is available for my hardware yet "loguru==0.7.*", + # PEP561 stub packages for better static analysis + "pandas-stubs (>=2.3.3,<2.4.0)", + "scipy-stubs (>=1.16.2,<1.17.0)" # The following should be installed manually (pip install -e .): # flygym & dm_control # parallel-video-io From ad1f686ec703ff59d527e6b265cf72ca93691dec Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Mon, 8 Dec 2025 19:25:51 +0100 Subject: [PATCH 25/33] [fix] remove buggy auto-close at the end in live mode --- src/poseforge/neuromechfly/scripts/visualize_meshes.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/poseforge/neuromechfly/scripts/visualize_meshes.py b/src/poseforge/neuromechfly/scripts/visualize_meshes.py index db69096..368c5c1 100644 --- a/src/poseforge/neuromechfly/scripts/visualize_meshes.py +++ b/src/poseforge/neuromechfly/scripts/visualize_meshes.py @@ -145,15 +145,10 @@ def _update_scene_by_frameid(frameid): "Nothing will be saved and the output path will be ignored." ) - def _update_with_cleanup(step): - _update_scene_by_frameid(step) - if step == len(self) - 1: - plotter.close() - plotter.add_timer_event( max_steps=len(self), duration=int(1000 / render_fps), # in ms - callback=_update_with_cleanup, + callback=_update_scene_by_frameid, ) plotter.show() else: From 104f571a3d8d5ebcd31877e65bd9d2408c7e3a9e Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 11 Dec 2025 00:15:39 +0100 Subject: [PATCH 26/33] [refactor] fix style/docs/typehints for nmf postprocessing module --- src/poseforge/neuromechfly/postprocessing.py | 45 +++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index 68727d8..eefd151 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np import matplotlib.pyplot as plt import scipy.ndimage as ndimage @@ -8,6 +10,8 @@ import h5py from collections import defaultdict from pathlib import Path + +from numpy import ndarray, dtype from tqdm import tqdm from joblib import Parallel, delayed from scipy.linalg import rq @@ -80,7 +84,9 @@ def __init__(self): self.label_colors_6d = np.array(self.label_colors_6d) - def __call__(self, images_by_color_coding: list[np.ndarray]): + def __call__( + self, images_by_color_coding: list[np.ndarray] | np.ndarray + ) -> np.ndarray: assert ( len(images_by_color_coding) == 2 ), "Expecting two images each with a different color coding." @@ -89,8 +95,8 @@ def __call__(self, images_by_color_coding: list[np.ndarray]): ), "Color coding images must have the same shape." if not isinstance(images_by_color_coding, list): images_by_color_coding = [ - images_by_color_coding[0, :, :, :], - images_by_color_coding[1, :, :, :], + images_by_color_coding[0, ...], + images_by_color_coding[1, ...], ] image_6d = np.concatenate(images_by_color_coding, axis=-1) @@ -103,7 +109,7 @@ def __call__(self, images_by_color_coding: list[np.ndarray]): return label_indices -def load_video_frames(video_path: Path) -> list[np.ndarray]: +def load_video_frames(video_path: Path) -> tuple[list[np.ndarray], float]: """Load video frames from a video file.""" if not video_path.is_file(): raise FileNotFoundError(f"{video_path} is not a file.") @@ -122,7 +128,7 @@ def load_video_frames(video_path: Path) -> list[np.ndarray]: return frames, fps -def get_rotation_angle_and_matrix(forward_vector: np.ndarray) -> np.ndarray: +def get_rotation_angle_and_matrix(forward_vector: np.ndarray) -> tuple[float, ndarray]: """Get a rotation matrix that rotates the forward vector to the z-axis.""" orientation = np.arctan2(forward_vector[1], forward_vector[0]) rotation_matrix = np.array( @@ -141,7 +147,9 @@ def rotate_image(image: np.ndarray, rotation_angle) -> np.ndarray: ) -def center_square_crop_image(image: np.ndarray, side_length) -> np.ndarray: +def center_square_crop_image( + image: np.ndarray, side_length: int +) -> tuple[np.ndarray, int, int]: """Crop the image to a square of given side length, centered on the image.""" height, width = image.shape[:2] start_col = (width - side_length) // 2 @@ -206,7 +214,7 @@ def rotate_keypoint_positions_camera( image_shape: tuple[int, int], start_col: int, start_row: int, -): +) -> dict[str, np.ndarray]: """ Apply rotation and cropping transformations to camera/image coordinates. Uses a simplified approach based on the working code snippet. @@ -215,7 +223,12 @@ def rotate_keypoint_positions_camera( keypoints_pos_lookup: Dict mapping segment names to (x, y, depth) positions rotation_matrix: 3x3 rotation matrix used for coordinate transformation image_shape: (height, width) of the original image before rotation - start_col, start_row: Cropping offsets after rotation + start_col: Cropping offsets after rotation (column) + start_row: Cropping offsets after rotation (row) + + Returns: + keypoints_pos_lookup_rotated: Dict mapping segment names to transformed + (x, y, depth) positions """ keypoints_pos_lookup_rotated = {} height, width = image_shape @@ -239,7 +252,7 @@ def rotate_keypoint_positions_camera( def rotate_keypoint_positions_world( keypoints_pos_lookup: dict[str, np.ndarray], rotation_matrix: np.ndarray -): +) -> dict[str, np.ndarray]: keypoints_pos_lookup_rotated = {} for body_segment_name, pos_before_rotation in keypoints_pos_lookup.items(): # Apply the rotation matrix to the keypoints @@ -368,7 +381,7 @@ def process_subsegment( segment_h5_file_path: Path, frames_range: tuple[int, int], processed_subsegment_dir: Path, - fps: int, + fps: int | float, crop_size: int = 464, num_color_codings: int = 2, n_jobs: int = -1, @@ -724,16 +737,18 @@ def visualize_single_frame( def visualize_subsegment( processed_subsegment_dir: Path, - fps: int, + render_fps: float | None = None, camera_elevation: float = 30.0, max_abs_azimuth: float = 30.0, azimuth_rotation_period: float = 300.0, n_jobs: int = -1, # Default to all available cores ) -> None: # Load video frames - frames, fps = load_video_frames( + frames, input_fps = load_video_frames( processed_subsegment_dir / "processed_nmf_sim_render_colorcode_0.mp4" ) + if render_fps is None: + render_fps = input_fps # Find processed simulation data processed_data_path = processed_subsegment_dir / "processed_simulation_data.h5" @@ -789,7 +804,7 @@ def visualize_subsegment( # Merge video with imageio.get_writer( str(processed_subsegment_dir / "visualization.mp4"), - fps=fps, + fps=render_fps, codec="libx264", quality=10, # 10 is highest for imageio, lower is lower quality ffmpeg_params=["-crf", "18", "-preset", "slow"], # lower crf = higher quality @@ -803,7 +818,7 @@ def visualize_subsegment( def select_subsegments( - upward_cardinal_vectors: np.array, + upward_cardinal_vectors: np.ndarray, max_tilt_angle_deg: float, mask_morph_closing_size_sec: float, min_subsegment_duration_sec: float, @@ -898,7 +913,7 @@ def postprocess_segment( if visualize: visualize_subsegment( processed_subsegment_dir=output_dir, - fps=fps, + render_fps=fps, camera_elevation=camera_elevation, max_abs_azimuth=max_abs_azimuth, azimuth_rotation_period=azimuth_rotation_period, From c21961ca21ba733f428169ad772ce50d41a41faa Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 11 Dec 2025 00:41:07 +0100 Subject: [PATCH 27/33] [feature/fix] apply alignment rotation to mesh 6d states --- src/poseforge/neuromechfly/postprocessing.py | 13 ++++++++++++- .../neuromechfly/scripts/run_simulation.py | 10 +++++----- src/poseforge/util/__init__.py | 3 ++- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index eefd151..eac379c 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -323,9 +323,20 @@ def process_single_frame( pos_glob = h5_file["body_segment_states/pos_global"][frame_idx, :, :] quat_glob = h5_file["body_segment_states/quat_global"][frame_idx, :, :] cam_projmat = h5_file["camera_matrix"][frame_idx, :, :] # 3x4 mat from mujoco - derived_variables["pos_rel_cam"], derived_variables["quat_rel_cam"] = ( + pos_rel_cam_pre_alignment, quat_rel_cam_pre_alignment = ( calculate_6dpose_relative_to_camera(pos_glob, quat_glob, cam_projmat) ) + # As before, we're rotating the camera capture post hoc so that the fly is + # always upright. Therefore, upon computing the position and rotation of the + # mesh relative to the camera, we still need to rotate them around the z axis + alignment_rotation = z_rotation = Rotation.from_rotvec([0, 0, rotation_angle]) + pos_rel_cam_aligned = alignment_rotation.apply(pos_rel_cam_pre_alignment) + quat_rel_cam_aligned = ( + alignment_rotation + * Rotation.from_quat(quat_rel_cam_pre_alignment, scalar_first=True) + ).as_quat(scalar_first=True) + derived_variables["pos_rel_cam"] = pos_rel_cam_aligned + derived_variables["quat_rel_cam"] = quat_rel_cam_aligned return rendered_images_transformed, derived_variables, seg_labels diff --git a/src/poseforge/neuromechfly/scripts/run_simulation.py b/src/poseforge/neuromechfly/scripts/run_simulation.py index c875731..95a074c 100644 --- a/src/poseforge/neuromechfly/scripts/run_simulation.py +++ b/src/poseforge/neuromechfly/scripts/run_simulation.py @@ -55,7 +55,7 @@ from poseforge.neuromechfly.data import load_kinematic_recording # from poseforge.neuromechfly.simulate import simulate_one_segment # TODO: revert from poseforge.neuromechfly.postprocessing import postprocess_segment -from poseforge.util import get_hardware_availability +from poseforge.util import get_hardware_availability, bulk_data_dir def simulate_using_kinematic_prior( @@ -155,7 +155,7 @@ def simulate_using_kinematic_prior( def run_sequentially_for_testing(): """Run everything sequentially (for debugging)""" # Configs - output_basedir = Path("bulk_data/nmf_rendering_new/") # TODO: change back to *_test + output_basedir = bulk_data_dir / "nmf_rendering/" # TODO: change back to *_test input_timestep = 0.01 sim_timestep = 0.0001 # trial_paths = [ @@ -163,7 +163,7 @@ def run_sequentially_for_testing(): # Path("bulk_data/kinematic_prior/aymanns2022/trials/BO_Gal4_fly1_trial001.pkl") # ] trial_paths = sorted( # TODO: revert - Path("bulk_data/kinematic_prior/aymanns2022/trials/").glob("*.pkl") + (bulk_data_dir / "kinematic_prior/aymanns2022/trials/").glob("*.pkl") ) # Limit scope of simulation as this is only for testing @@ -191,7 +191,7 @@ def run_sequentially_for_testing(): get_hardware_availability(check_gpu=False, print_results=True) # Run the CLI - tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI + # tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI # Run everything sequentially (for debugging) # TODO: disable testing - # run_sequentially_for_testing() + run_sequentially_for_testing() diff --git a/src/poseforge/util/__init__.py b/src/poseforge/util/__init__.py index 83730ca..ceb54a2 100644 --- a/src/poseforge/util/__init__.py +++ b/src/poseforge/util/__init__.py @@ -5,7 +5,7 @@ check_mixed_precision_status, ) from .plot import configure_matplotlib_style -from .data import SerializableDataClass, OutputBuffer +from .data import SerializableDataClass, OutputBuffer, bulk_data_dir from .ml import count_optimizer_parameters, count_module_parameters @@ -17,6 +17,7 @@ "configure_matplotlib_style", "SerializableDataClass", "OutputBuffer", + "bulk_data_dir", "count_optimizer_parameters", "count_module_parameters", ] From 2f8c931ea129e123bf429a8d58fc5d8330492f8b Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 11 Dec 2025 01:13:03 +0100 Subject: [PATCH 28/33] move atomic batches preextraction from contrastive pretraining module to data module --- scripts_on_cluster/atomic_batch_extraction/process_all.run | 2 +- scripts_on_cluster/atomic_batch_extraction/template.run | 2 +- .../{contrast => data}/scripts/preextract_atomic_batches.py | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename src/poseforge/pose/{contrast => data}/scripts/preextract_atomic_batches.py (100%) diff --git a/scripts_on_cluster/atomic_batch_extraction/process_all.run b/scripts_on_cluster/atomic_batch_extraction/process_all.run index c7f1760..1fc124e 100644 --- a/scripts_on_cluster/atomic_batch_extraction/process_all.run +++ b/scripts_on_cluster/atomic_batch_extraction/process_all.run @@ -19,7 +19,7 @@ conda activate poseforge # Define paths project_root="/home/$USER/poseforge" -extraction_cli_path="src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py" +extraction_cli_path="src/poseforge/pose/data/scripts/preextract_atomic_batches.py" all_trials_basedir="/home/sibwang/poseforge/bulk_data/style_transfer/production/translated_videos/" # Setup diff --git a/scripts_on_cluster/atomic_batch_extraction/template.run b/scripts_on_cluster/atomic_batch_extraction/template.run index 99fb8d3..b883511 100644 --- a/scripts_on_cluster/atomic_batch_extraction/template.run +++ b/scripts_on_cluster/atomic_batch_extraction/template.run @@ -19,7 +19,7 @@ conda activate poseforge # Define paths project_root="/home/sibwang/poseforge" -extraction_cli_path="src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py" +extraction_cli_path="src/poseforge/pose/data/scripts/preextract_atomic_batches.py" synthetic_videos_basedir="bulk_data/style_transfer/production/translated_videos/" nmf_rendering_basedir="bulk_data/nmf_rendering/" output_basedir="/home/sibwang/poseforge/bulk_data/pose_estimation/atomic_batches/4variants/" diff --git a/src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py b/src/poseforge/pose/data/scripts/preextract_atomic_batches.py similarity index 100% rename from src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py rename to src/poseforge/pose/data/scripts/preextract_atomic_batches.py From 576d18334283836d86acb224b2322444cfd59285 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 11 Dec 2025 01:13:32 +0100 Subject: [PATCH 29/33] fix temp path; undo temporary test script --- .../atomic_batch_extraction/gen_batch_scripts.py | 2 +- scripts_on_cluster/atomic_batch_extraction/process_all.run | 4 ++-- scripts_on_cluster/atomic_batch_extraction/template.run | 4 ++-- src/poseforge/neuromechfly/scripts/run_simulation.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py b/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py index 60bf884..ada441e 100644 --- a/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py +++ b/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py @@ -7,7 +7,7 @@ batch_scripts_dir.mkdir(exist_ok=True, parents=True) # Configs by task -synthetic_videos_basedir = Path("/work/upramdya/sibo_temp/poseforge/bulk_data/style_transfer/production/translated_videos/") +synthetic_videos_basedir = Path("/home/sibwang/poseforge/bulk_data/style_transfer/production/translated_videos/") trial_names_all = [x.name for x in synthetic_videos_basedir.glob("BO_Gal4_*")] # Generate batch scripts diff --git a/scripts_on_cluster/atomic_batch_extraction/process_all.run b/scripts_on_cluster/atomic_batch_extraction/process_all.run index 1fc124e..133754b 100644 --- a/scripts_on_cluster/atomic_batch_extraction/process_all.run +++ b/scripts_on_cluster/atomic_batch_extraction/process_all.run @@ -18,9 +18,9 @@ spack load ffmpeg conda activate poseforge # Define paths -project_root="/home/$USER/poseforge" +project_root="$HOME/poseforge" extraction_cli_path="src/poseforge/pose/data/scripts/preextract_atomic_batches.py" -all_trials_basedir="/home/sibwang/poseforge/bulk_data/style_transfer/production/translated_videos/" +all_trials_basedir="bulk_data/style_transfer/production/translated_videos/" # Setup cd $project_root diff --git a/scripts_on_cluster/atomic_batch_extraction/template.run b/scripts_on_cluster/atomic_batch_extraction/template.run index b883511..a6aad3e 100644 --- a/scripts_on_cluster/atomic_batch_extraction/template.run +++ b/scripts_on_cluster/atomic_batch_extraction/template.run @@ -18,11 +18,11 @@ spack load ffmpeg conda activate poseforge # Define paths -project_root="/home/sibwang/poseforge" +project_root="$HOME/poseforge" extraction_cli_path="src/poseforge/pose/data/scripts/preextract_atomic_batches.py" synthetic_videos_basedir="bulk_data/style_transfer/production/translated_videos/" nmf_rendering_basedir="bulk_data/nmf_rendering/" -output_basedir="/home/sibwang/poseforge/bulk_data/pose_estimation/atomic_batches/4variants/" +output_basedir="bulk_data/pose_estimation/atomic_batches/4variants/" # Setup cd $project_root diff --git a/src/poseforge/neuromechfly/scripts/run_simulation.py b/src/poseforge/neuromechfly/scripts/run_simulation.py index 95a074c..2969a0b 100644 --- a/src/poseforge/neuromechfly/scripts/run_simulation.py +++ b/src/poseforge/neuromechfly/scripts/run_simulation.py @@ -191,7 +191,7 @@ def run_sequentially_for_testing(): get_hardware_availability(check_gpu=False, print_results=True) # Run the CLI - # tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI + tyro.cli(simulate_using_kinematic_prior) # TODO: enable CLI # Run everything sequentially (for debugging) # TODO: disable testing - run_sequentially_for_testing() + # run_sequentially_for_testing() From 5da520978ae15f47c081cb7422b1d68b42fe3e05 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 11 Dec 2025 18:23:28 +0100 Subject: [PATCH 30/33] [refactor] refactor scitas scripts - consolidate submission scripts and add script to compress output on compute nodes --- .../compress_results.run | 66 ++++++++++++++++++ .../submit_all.sh | 0 .../submit_all.sh | 20 ------ .../nmf_simulation/compress_results.run | 67 +++++++++++++++++++ .../nmf_simulation/submit_all.sh | 20 ------ .../style_transfer_inference/submit_all.sh | 20 ------ .../20250903_parameter_sweep/submit_all.sh | 20 ------ .../20250905_continued_training/submit_all.sh | 20 ------ .../20250905_parameter_sweep/submit_all.sh | 20 ------ 9 files changed, 133 insertions(+), 120 deletions(-) create mode 100644 scripts_on_cluster/atomic_batch_extraction/compress_results.run rename scripts_on_cluster/{atomic_batch_extraction => common}/submit_all.sh (100%) delete mode 100644 scripts_on_cluster/contrastive_pretraining_inference/submit_all.sh create mode 100644 scripts_on_cluster/nmf_simulation/compress_results.run delete mode 100644 scripts_on_cluster/nmf_simulation/submit_all.sh delete mode 100644 scripts_on_cluster/style_transfer_inference/submit_all.sh delete mode 100644 scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh delete mode 100644 scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh delete mode 100644 scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh diff --git a/scripts_on_cluster/atomic_batch_extraction/compress_results.run b/scripts_on_cluster/atomic_batch_extraction/compress_results.run new file mode 100644 index 0000000..94281ee --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/compress_results.run @@ -0,0 +1,66 @@ +#!/bin/bash -l + +#SBATCH --job-name=compress_output +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 36 +#SBATCH --mem 96GB +#SBATCH --time 1:00:00 +#SBATCH --partition=standard +#SBATCH --qos=serial +#SBATCH --output=logs/compress_output.log + +# ------------------------------------------------------------------- +# CRASH CHECK: Ensure SLURM_CPUS_PER_TASK is set +# ------------------------------------------------------------------- + +# Get the requested number of CPUs from the Slurm environment +max_jobs=$SLURM_CPUS_PER_TASK + +# Check if the variable is empty or null. +if [ -z "$max_jobs" ]; then + echo "ERROR: SLURM_CPUS_PER_TASK environment variable is not set." >&2 + echo "This script must be submitted using 'sbatch' and requires the --cpus-per-task flag in the header." >&2 + exit 1 +fi + +echo "Running up to $max_jobs concurrent tar jobs ($(date))." +echo "----------------------------------------------" + +input_basedir="$HOME/poseforge/bulk_data/pose_estimation/atomic_batches/4variants" +output_basedir="$HOME/poseforge/bulk_data/pose_estimation/atomic_batches/4variants_tarballs" +mkdir -p $output_basedir + +# Loop through all subdirectories +job_count=0 +for dir_path in "$input_basedir"/*/; do + if [ -d "$dir_path" ]; then + + dir_name=$(basename "$dir_path") + output_file="$output_basedir/${dir_name}.tar.gz" + + # Launch the tar command in the background + # Note the -C flag to change directory to the parent + # *before* archiving. This keeps the archive clean. + + tar -czf "$output_file" -C "$dir_path/.." "$dir_name" & + + echo "Launched job for $dir_name (PID: $!)" + job_count=$((job_count + 1)) + + # When we hit the max number of concurrent jobs, we wait for *one* + # of the currently running background jobs to complete (wait -n). + if [ "$job_count" -ge "$max_jobs" ]; then + wait -n + job_count=$((job_count - 1)) + echo "A background job finished. Resuming launch..." + fi + fi +done + +# Wait for ALL remaining background jobs to finish +echo "----------------------------------------------" +echo "All jobs launched. Waiting for final completion..." +wait + +echo "All parallel tar operations complete ($(date))." diff --git a/scripts_on_cluster/atomic_batch_extraction/submit_all.sh b/scripts_on_cluster/common/submit_all.sh similarity index 100% rename from scripts_on_cluster/atomic_batch_extraction/submit_all.sh rename to scripts_on_cluster/common/submit_all.sh diff --git a/scripts_on_cluster/contrastive_pretraining_inference/submit_all.sh b/scripts_on_cluster/contrastive_pretraining_inference/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/contrastive_pretraining_inference/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/nmf_simulation/compress_results.run b/scripts_on_cluster/nmf_simulation/compress_results.run new file mode 100644 index 0000000..8bd810a --- /dev/null +++ b/scripts_on_cluster/nmf_simulation/compress_results.run @@ -0,0 +1,67 @@ +#!/bin/bash -l + +#SBATCH --job-name=compress_output +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 36 +#SBATCH --mem 96GB +#SBATCH --time 1:00:00 +#SBATCH --partition=standard +#SBATCH --qos=serial +#SBATCH --output=logs/compress_output.log + +# ------------------------------------------------------------------- +# CRASH CHECK: Ensure SLURM_CPUS_PER_TASK is set +# ------------------------------------------------------------------- + +# Get the requested number of CPUs from the Slurm environment +max_jobs=$SLURM_CPUS_PER_TASK + +# Check if the variable is empty or null. +if [ -z "$max_jobs" ]; then + echo "ERROR: SLURM_CPUS_PER_TASK environment variable is not set." >&2 + echo "This script must be submitted using 'sbatch' and requires the --cpus-per-task flag in the header." >&2 + exit 1 +fi + +echo "Running up to $max_jobs concurrent tar jobs ($(date))." +echo "----------------------------------------------" + +input_basedir="$HOME/poseforge/bulk_data/nmf_rendering" +output_basedir="$HOME/poseforge/bulk_data/nmf_rendering_tarballs" +mkdir -p $output_basedir + +# Loop through all subdirectories +job_count=0 +for dir_path in "$input_basedir"/*/; do + # Check if the path is actually a directory + if [ -d "$dir_path" ]; then + + dir_name=$(basename "$dir_path") + output_file="$output_basedir/${dir_name}.tar.gz" + + # --- Launch the tar command in the background --- + # Note the -C flag to change directory to the parent + # *before* archiving. This keeps the archive clean. + + tar -czf "$output_file" -C "$dir_path/.." "$dir_name" & + + echo "Launched job for $dir_name (PID: $!)" + job_count=$((job_count + 1)) + + # --- Job Control: Wait for a background job to finish --- + # When we hit the max number of concurrent jobs, we wait for *one* # of the currently running background jobs to complete (wait -n). + if [ "$job_count" -ge "$max_jobs" ]; then + wait -n + job_count=$((job_count - 1)) + echo "A background job finished. Resuming launch..." + fi + fi +done + +# --- Final Wait: Wait for ALL remaining background jobs to finish --- +echo "----------------------------------------------" +echo "All jobs launched. Waiting for final completion..." +wait + +echo "All parallel tar operations complete ($(date))." diff --git a/scripts_on_cluster/nmf_simulation/submit_all.sh b/scripts_on_cluster/nmf_simulation/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/nmf_simulation/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_inference/submit_all.sh b/scripts_on_cluster/style_transfer_inference/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_inference/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh b/scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh b/scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh b/scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" From 7d5ff4f9f0b39cc88ac60528573bfb274501cc5a Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Thu, 11 Dec 2025 18:47:24 +0100 Subject: [PATCH 31/33] [run] rerun pose6d training with correctly rotated labels --- .../pose6d_training/train_20251211a.run | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 scripts_on_cluster/pose6d_training/train_20251211a.run diff --git a/scripts_on_cluster/pose6d_training/train_20251211a.run b/scripts_on_cluster/pose6d_training/train_20251211a.run new file mode 100644 index 0000000..a7c7258 --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251211a.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251211a.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251211a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 192 \ + --model-architecture-config.n-global-feature-channels 64 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 10.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.0001 \ + --optimizer-config.learning_rate_upsample 0.001 \ + --optimizer-config.learning_rate_attention 0.001 \ + --optimizer-config.learning_rate_pose6d_heads 0.001 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" From b25403f0c9c4067e80198d6472f6f679840bb013 Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sat, 13 Dec 2025 20:21:15 +0100 Subject: [PATCH 32/33] [feature] add legends metadata to atomic batches files --- .../pose/data/scripts/_attach_metadata.py | 275 ++++++++++++++++++ .../data/scripts/preextract_atomic_batches.py | 13 +- .../pose/data/synthetic/atomic_batch.py | 10 +- src/poseforge/pose/data/synthetic/sampler.py | 21 +- .../pose/data/synthetic/sim_data_seq.py | 15 +- 5 files changed, 311 insertions(+), 23 deletions(-) create mode 100644 src/poseforge/pose/data/scripts/_attach_metadata.py diff --git a/src/poseforge/pose/data/scripts/_attach_metadata.py b/src/poseforge/pose/data/scripts/_attach_metadata.py new file mode 100644 index 0000000..9eee3c8 --- /dev/null +++ b/src/poseforge/pose/data/scripts/_attach_metadata.py @@ -0,0 +1,275 @@ +import h5py +from pathlib import Path +from tqdm import tqdm + +KEYS = { + "body_seg_maps": [ + "Background", + "OtherSegments", + "Thorax", + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus", + "LAntenna", + "RAntenna", + ], + "dof_angles": [ + "LFThC_pitch", + "LFThC_roll", + "LFThC_yaw", + "LFCTr_pitch", + "LFCTr_roll", + "LFFTi_pitch", + "LFTiTa_pitch", + "LMThC_pitch", + "LMThC_roll", + "LMThC_yaw", + "LMCTr_pitch", + "LMCTr_roll", + "LMFTi_pitch", + "LMTiTa_pitch", + "LHThC_pitch", + "LHThC_roll", + "LHThC_yaw", + "LHCTr_pitch", + "LHCTr_roll", + "LHFTi_pitch", + "LHTiTa_pitch", + "RFThC_pitch", + "RFThC_roll", + "RFThC_yaw", + "RFCTr_pitch", + "RFCTr_roll", + "RFFTi_pitch", + "RFTiTa_pitch", + "RMThC_pitch", + "RMThC_roll", + "RMThC_yaw", + "RMCTr_pitch", + "RMCTr_roll", + "RMFTi_pitch", + "RMTiTa_pitch", + "RHThC_pitch", + "RHThC_roll", + "RHThC_yaw", + "RHCTr_pitch", + "RHCTr_roll", + "RHFTi_pitch", + "RHTiTa_pitch", + ], + "keypoint_pos": [ + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus1", + "LFTarsus5", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus1", + "LMTarsus5", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus1", + "LHTarsus5", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus1", + "RFTarsus5", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus1", + "RMTarsus5", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus1", + "RHTarsus5", + "LPedicel", + "RPedicel", + ], + "mesh_pos": [ + "Thorax", + "A1A2", + "A3", + "A4", + "A5", + "A6", + "LHaltere", + "LWing", + "RHaltere", + "RWing", + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus1", + "LFTarsus2", + "LFTarsus3", + "LFTarsus4", + "LFTarsus5", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus1", + "LHTarsus2", + "LHTarsus3", + "LHTarsus4", + "LHTarsus5", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus1", + "LMTarsus2", + "LMTarsus3", + "LMTarsus4", + "LMTarsus5", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus1", + "RHTarsus2", + "RHTarsus3", + "RHTarsus4", + "RHTarsus5", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus1", + "RMTarsus2", + "RMTarsus3", + "RMTarsus4", + "RMTarsus5", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus1", + "RFTarsus2", + "RFTarsus3", + "RFTarsus4", + "RFTarsus5", + "Head", + "LEye", + "REye", + "Rostrum", + "Haustellum", + "LPedicel", + "LFuniculus", + "LArista", + "RPedicel", + "RFuniculus", + "RArista", + ], + "mesh_quat": [ + "Thorax", + "A1A2", + "A3", + "A4", + "A5", + "A6", + "LHaltere", + "LWing", + "RHaltere", + "RWing", + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus1", + "LFTarsus2", + "LFTarsus3", + "LFTarsus4", + "LFTarsus5", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus1", + "LHTarsus2", + "LHTarsus3", + "LHTarsus4", + "LHTarsus5", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus1", + "LMTarsus2", + "LMTarsus3", + "LMTarsus4", + "LMTarsus5", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus1", + "RHTarsus2", + "RHTarsus3", + "RHTarsus4", + "RHTarsus5", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus1", + "RMTarsus2", + "RMTarsus3", + "RMTarsus4", + "RMTarsus5", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus1", + "RFTarsus2", + "RFTarsus3", + "RFTarsus4", + "RFTarsus5", + "Head", + "LEye", + "REye", + "Rostrum", + "Haustellum", + "LPedicel", + "LFuniculus", + "LArista", + "RPedicel", + "RFuniculus", + "RArista", + ], +} + + +def attach_metadata(path: Path): + with h5py.File(path, "a") as f: + for k, keys in KEYS.items(): + f[k].attrs["keys"] = keys + + +if __name__ == "__main__": + basedir = Path(".") + all_files = sorted(basedir.glob("**/*_labels.h5")) + print(f"Found {len(all_files)} files. Continue?") + resp = input("Type 'yes' to continue: ") + if resp.lower() != "yes": + print("Aborting.") + exit(0) + for file_path in tqdm(all_files): + attach_metadata(file_path) diff --git a/src/poseforge/pose/data/scripts/preextract_atomic_batches.py b/src/poseforge/pose/data/scripts/preextract_atomic_batches.py index 05979fe..91f8c88 100644 --- a/src/poseforge/pose/data/scripts/preextract_atomic_batches.py +++ b/src/poseforge/pose/data/scripts/preextract_atomic_batches.py @@ -172,11 +172,11 @@ def process_batch(batch_idx: int): AtomicBatchDataset.save_atomic_batch_frames( atomic_batch_frames, output_dir / f"{filename_stem}_frames.mp4", - fps=sampler.fps, - spacing=video_spacing, + sampler.fps, + video_spacing, ) AtomicBatchDataset.save_atomic_batch_sim_data( - labels, output_dir / f"{filename_stem}_labels.h5" + output_dir / f"{filename_stem}_labels.h5", labels, sampler.label_keys ) if batch_idx % logging_interval == 0: elapsed = time() - start_time @@ -222,13 +222,14 @@ def process_batch(batch_idx: int): # --output-dir bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001 # # Example usage not via CLI + # from poseforge.util import bulk_data_dir # extract_atomic_batches( # atomic_batch_nframes=32, # atomic_batch_nvariants_max=4, # minimum_time_diff_frames=60, - # input_basedir="bulk_data/style_transfer/production/translated_videos", - # nmf_sim_rendering_basedir="bulk_data/nmf_rendering/", - # output_dir="bulk_data/pose_estimation/atomic_batches_test", + # input_basedir=bulk_data_dir / "style_transfer/production/translated_videos", + # nmf_sim_rendering_basedir=bulk_data_dir / "nmf_rendering/", + # output_dir=bulk_data_dir / "pose_estimation/atomic_batches_test", # original_image_size=(464, 464), # n_jobs=-1, # logging_interval=100, diff --git a/src/poseforge/pose/data/synthetic/atomic_batch.py b/src/poseforge/pose/data/synthetic/atomic_batch.py index 20be9de..b2bbba5 100644 --- a/src/poseforge/pose/data/synthetic/atomic_batch.py +++ b/src/poseforge/pose/data/synthetic/atomic_batch.py @@ -208,9 +208,9 @@ def load_atomic_batch_frames( @staticmethod def save_atomic_batch_sim_data( - sim_data: dict[str, np.ndarray], output_path: Path, - metadata: dict | None = None, + sim_data: dict[str, np.ndarray], + label_keys: dict[str, list[str]], ): """Save simulation data for an atomic batch to an HDF5 file. @@ -226,11 +226,9 @@ def save_atomic_batch_sim_data( # these files will be accessed very frequently during training, so we # use lzf (faster than gzip) and no shuffling to optimize for speed compression = "lzf" if key == "body_seg_maps" else None - f.create_dataset(key, data=value, compression=compression) + ds = f.create_dataset(key, data=value, compression=compression) + ds.attrs["keys"] = label_keys[key] f.attrs["n_frames"] = next(iter(sim_data.values())).shape[0] - if metadata is not None: - for key, value in metadata.items(): - f.attrs[key] = value @staticmethod def load_atomic_batch_sim_data( diff --git a/src/poseforge/pose/data/synthetic/sampler.py b/src/poseforge/pose/data/synthetic/sampler.py index d0e3daf..3beb873 100644 --- a/src/poseforge/pose/data/synthetic/sampler.py +++ b/src/poseforge/pose/data/synthetic/sampler.py @@ -105,6 +105,19 @@ def __init__( assert self.sampling_stride >= 1, "`sampling_stride` must be >= 1" assert len(simulated_data_sequences) > 0, "At least 1 simulation required" + # Check if all simulations have the same metadata + metadata_li = [seq.get_sim_data_metadata() for seq in simulated_data_sequences] + md_ref = metadata_li[0] + for metadata in metadata_li[1:]: + assert metadata == md_ref, "All simulations must have the same metadata" + self.label_keys = { + "dof_angles": md_ref["dof_angles"]["keys"], + "keypoint_pos": md_ref["keypoint_pos"]["keys"], + "mesh_pos": md_ref["mesh_pose6d"]["keys"], + "mesh_quat": md_ref["mesh_pose6d"]["keys"], + "body_seg_maps": md_ref["segmentation_labels"]["keys"], + } + # Check number of variants per frame assert ( len(set([seq.n_variants for seq in simulated_data_sequences])) == 1 @@ -113,14 +126,14 @@ def __init__( # Check image size and FPS _image_sizes = set() - _fpss = set() + _fps = set() for seq in simulated_data_sequences: _image_sizes.add(seq.frame_size) - _fpss.add(seq.fps) + _fps.add(seq.fps) assert len(_image_sizes) == 1, "All simulations must have the same image size" - assert len(_fpss) == 1, "All simulations must have the same FPS" + assert len(_fps) == 1, "All simulations must have the same FPS" self.frame_size = _image_sizes.pop() - self.fps = _fpss.pop() + self.fps = _fps.pop() # Check number of frames per simulation and in total self.n_frames_per_sim = np.array( diff --git a/src/poseforge/pose/data/synthetic/sim_data_seq.py b/src/poseforge/pose/data/synthetic/sim_data_seq.py index 0782124..9c33d02 100644 --- a/src/poseforge/pose/data/synthetic/sim_data_seq.py +++ b/src/poseforge/pose/data/synthetic/sim_data_seq.py @@ -56,13 +56,14 @@ def __init__( def get_sim_data_metadata(self) -> dict: metadata = {} - with h5py.File(self.simulated_labels_path, "r") as ds: - postprocessed_ds = ds["postprocessed"] - metadata["dof_angles"] = dict(postprocessed_ds["dof_angles"].attrs) - metadata["keypoint_pos"] = dict(postprocessed_ds["keypoint_pos"].attrs) - metadata["segmentation_labels"] = dict( - postprocessed_ds["segmentation_labels"].attrs - ) + with h5py.File(self.simulated_labels_path, "r") as f: + ds = f["postprocessed"] + metadata["dof_angles"] = dict(ds["dof_angles"].attrs) + metadata["keypoint_pos"] = dict(ds["keypoint_pos"].attrs) + metadata["segmentation_labels"] = dict(ds["segmentation_labels"].attrs) + metadata["mesh_pose6d"] = dict(ds["mesh_pose6d_rel_camera"].attrs) + for inner_dict in metadata.values(): + inner_dict["keys"] = list(inner_dict["keys"]) return metadata def __len__(self) -> int: From a73ca9b501564ef0546e7d7556d2acda86b971dc Mon Sep 17 00:00:00 2001 From: Sibo Wang Date: Sun, 14 Dec 2025 20:15:14 +0100 Subject: [PATCH 33/33] [fix] fix hard-attached metadata --- .../pose/data/scripts/_attach_metadata.py | 88 ------------------- 1 file changed, 88 deletions(-) diff --git a/src/poseforge/pose/data/scripts/_attach_metadata.py b/src/poseforge/pose/data/scripts/_attach_metadata.py index 9eee3c8..31db7c7 100644 --- a/src/poseforge/pose/data/scripts/_attach_metadata.py +++ b/src/poseforge/pose/data/scripts/_attach_metadata.py @@ -114,145 +114,57 @@ ], "mesh_pos": [ "Thorax", - "A1A2", - "A3", - "A4", - "A5", - "A6", - "LHaltere", - "LWing", - "RHaltere", - "RWing", "LFCoxa", "LFFemur", "LFTibia", "LFTarsus1", - "LFTarsus2", - "LFTarsus3", - "LFTarsus4", - "LFTarsus5", "LHCoxa", "LHFemur", "LHTibia", "LHTarsus1", - "LHTarsus2", - "LHTarsus3", - "LHTarsus4", - "LHTarsus5", "LMCoxa", "LMFemur", "LMTibia", "LMTarsus1", - "LMTarsus2", - "LMTarsus3", - "LMTarsus4", - "LMTarsus5", "RHCoxa", "RHFemur", "RHTibia", "RHTarsus1", - "RHTarsus2", - "RHTarsus3", - "RHTarsus4", - "RHTarsus5", "RMCoxa", "RMFemur", "RMTibia", "RMTarsus1", - "RMTarsus2", - "RMTarsus3", - "RMTarsus4", - "RMTarsus5", "RFCoxa", "RFFemur", "RFTibia", "RFTarsus1", - "RFTarsus2", - "RFTarsus3", - "RFTarsus4", - "RFTarsus5", - "Head", - "LEye", - "REye", - "Rostrum", - "Haustellum", - "LPedicel", - "LFuniculus", - "LArista", - "RPedicel", - "RFuniculus", - "RArista", ], "mesh_quat": [ "Thorax", - "A1A2", - "A3", - "A4", - "A5", - "A6", - "LHaltere", - "LWing", - "RHaltere", - "RWing", "LFCoxa", "LFFemur", "LFTibia", "LFTarsus1", - "LFTarsus2", - "LFTarsus3", - "LFTarsus4", - "LFTarsus5", "LHCoxa", "LHFemur", "LHTibia", "LHTarsus1", - "LHTarsus2", - "LHTarsus3", - "LHTarsus4", - "LHTarsus5", "LMCoxa", "LMFemur", "LMTibia", "LMTarsus1", - "LMTarsus2", - "LMTarsus3", - "LMTarsus4", - "LMTarsus5", "RHCoxa", "RHFemur", "RHTibia", "RHTarsus1", - "RHTarsus2", - "RHTarsus3", - "RHTarsus4", - "RHTarsus5", "RMCoxa", "RMFemur", "RMTibia", "RMTarsus1", - "RMTarsus2", - "RMTarsus3", - "RMTarsus4", - "RMTarsus5", "RFCoxa", "RFFemur", "RFTibia", "RFTarsus1", - "RFTarsus2", - "RFTarsus3", - "RFTarsus4", - "RFTarsus5", - "Head", - "LEye", - "REye", - "Rostrum", - "Haustellum", - "LPedicel", - "LFuniculus", - "LArista", - "RPedicel", - "RFuniculus", - "RArista", ], }