diff --git a/.gitignore b/.gitignore index a5d337c..64f5412 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,24 @@ -# Created by https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,jupyternotebooks,visualstudiocode,vim,git,pycharm+all -# Edit at https://www.toptal.com/developers/gitignore?templates=python,linux,macos,windows,jupyternotebooks,visualstudiocode,vim,git,pycharm+all +bulk_data +production_models/** +!production_models/**/ +!production_models/**/README.md +logs/ +wandb/ +*.mp4 +*.jpg +*.pth +*.png +temp/ +batch_scripts/ +outputs/ +notebooks/ + +# I want all JetBrains-specific file to be ignored +# (gitignore.io defaults allow style and run settings to be included) +.idea/ + +# Created by https://www.toptal.com/developers/gitignore/api/git,vim,linux,macos,python,windows,jupyternotebooks,visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=git,vim,linux,macos,python,windows,jupyternotebooks,visualstudiocode ### Git ### # Created by git for backups. 
To disable backups in Git: @@ -78,94 +97,6 @@ Temporary Items # iCloud generated files *.icloud -### PyCharm+all ### -# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider -# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 - -# User-specific stuff -.idea/**/workspace.xml -.idea/**/tasks.xml -.idea/**/usage.statistics.xml -.idea/**/dictionaries -.idea/**/shelf - -# AWS User-specific -.idea/**/aws.xml - -# Generated files -.idea/**/contentModel.xml - -# Sensitive or high-churn files -.idea/**/dataSources/ -.idea/**/dataSources.ids -.idea/**/dataSources.local.xml -.idea/**/sqlDataSources.xml -.idea/**/dynamic.xml -.idea/**/uiDesigner.xml -.idea/**/dbnavigator.xml - -# Gradle -.idea/**/gradle.xml -.idea/**/libraries - -# Gradle and Maven with auto-import -# When using Gradle or Maven with auto-import, you should exclude module files, -# since they will be recreated, and may cause churn. Uncomment if using -# auto-import. -# .idea/artifacts -# .idea/compiler.xml -# .idea/jarRepositories.xml -# .idea/modules.xml -# .idea/*.iml -# .idea/modules -# *.iml -# *.ipr - -# CMake -cmake-build-*/ - -# Mongo Explorer plugin -.idea/**/mongoSettings.xml - -# File-based project format -*.iws - -# IntelliJ -out/ - -# mpeltonen/sbt-idea plugin -.idea_modules/ - -# JIRA plugin -atlassian-ide-plugin.xml - -# Cursive Clojure plugin -.idea/replstate.xml - -# SonarLint plugin -.idea/sonarlint/ - -# Crashlytics plugin (for Android Studio and IntelliJ) -com_crashlytics_export_strings.xml -crashlytics.properties -crashlytics-build.properties -fabric.properties - -# Editor-based Rest Client -.idea/httpRequests - -# Android studio 3.1+ serialized cache file -.idea/caches/build_file_checksums.ser - -### PyCharm+all Patch ### -# Ignore everything but code style settings and run configurations -# that are supposed to be shared within teams. 
- -.idea/* - -!.idea/codeStyles -!.idea/runConfigurations - ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ @@ -288,7 +219,7 @@ celerybeat.pid # Environments .env .venv -# env/ +env/ venv/ ENV/ env.bak/ @@ -357,7 +288,7 @@ tags ### VisualStudioCode ### .vscode/* -# !.vscode/settings.json +!.vscode/settings.json !.vscode/tasks.json !.vscode/launch.json !.vscode/extensions.json @@ -400,20 +331,4 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk -# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,jupyternotebooks,visualstudiocode,vim,git,pycharm+all - - -bulk_data -production_models/** -!production_models/**/ -!production_models/**/README.md -logs/ -wandb/ -*.mp4 -*.jpg -*.pth -*.png -temp/ -batch_scripts/ -outputs/ -notebooks/ +# End of https://www.toptal.com/developers/gitignore/api/git,vim,linux,macos,python,windows,jupyternotebooks,visualstudiocode \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index bcdb877..1fc35b2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -1288,6 +1288,21 @@ files = [ {file = "numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78"}, ] +[[package]] +name = "numpy-typing-compat" +version = "20250818.2.0" +description = "Static typing compatibility layer for older versions of NumPy" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "numpy_typing_compat-20250818.2.0-py3-none-any.whl", hash = "sha256:042da86a786b6eb164f900efdfc3ba132f4371a2e44a93109976b1d7538253ed"}, + {file = "numpy_typing_compat-20250818.2.0.tar.gz", hash = "sha256:3f77ba873ec9668e9b7bd15ae083cc16c82aa732b651ed2bf5aa284cdd0dc71d"}, +] + +[package.dependencies] +numpy = ">=2.0,<2.1" + [[package]] name = "nvidia-cublas-cu12" version = "12.8.4.1" @@ -1529,6 +1544,25 @@ files = [ [package.dependencies] numpy = {version = ">=2,<2.3.0", markers = "python_version >= \"3.9\""} +[[package]] +name = "optype" +version = "0.14.0" +description = "Building Blocks for Precise & Flexible Type Hints" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "optype-0.14.0-py3-none-any.whl", hash = "sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b"}, + {file = "optype-0.14.0.tar.gz", hash = "sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6"}, +] + +[package.dependencies] +numpy = {version = ">=1.25", optional = true, markers = "extra == \"numpy\""} +numpy-typing-compat = {version = ">=20250818.1.25,<20250819", optional = true, markers = "extra == \"numpy\""} + +[package.extras] +numpy = ["numpy (>=1.25)", "numpy-typing-compat (>=20250818.1.25,<20250819)"] + [[package]] name = "packaging" version = "25.0" @@ -1637,6 +1671,22 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pandas-stubs" +version = "2.3.3.251201" 
+description = "Type annotations for pandas" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pandas_stubs-2.3.3.251201-py3-none-any.whl", hash = "sha256:eb5c9b6138bd8492fd74a47b09c9497341a278fcfbc8633ea4b35b230ebf4be5"}, + {file = "pandas_stubs-2.3.3.251201.tar.gz", hash = "sha256:7a980f4f08cff2a6d7e4c6d6d26f4c5fcdb82a6f6531489b2f75c81567fe4536"}, +] + +[package.dependencies] +numpy = ">=1.23.5" +types-pytz = ">=2022.1.1" + [[package]] name = "pillow" version = "11.3.0" @@ -2372,6 +2422,24 @@ dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodest doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.19.1)", "jupytext", "linkify-it-py", "matplotlib (>=3.5)", "myst-nb (>=1.2.0)", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.2.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] test = ["Cython", "array-api-strict (>=2.3.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest (>=8.0.0)", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] +[[package]] +name = "scipy-stubs" +version = "1.16.3.2" +description = "Type annotations for SciPy" +optional = false +python-versions = ">=3.11" +groups = ["main"] +files = [ + {file = "scipy_stubs-1.16.3.2-py3-none-any.whl", hash = "sha256:c8df8975d69d77237a12e7b32f0bce08a7cc27bbb53ff7f4671ef1e9de033490"}, + {file = "scipy_stubs-1.16.3.2.tar.gz", hash = "sha256:04c498da2bfb445cf1dce06cb012bbaafd30fa19d9c7b03852a0574e7e83df3a"}, +] + +[package.dependencies] +optype = {version = ">=0.14.0,<0.15", extras = ["numpy"]} + +[package.extras] +scipy = ["scipy (>=1.16.3,<1.17)"] + [[package]] name = "scooby" version = "0.10.2" @@ -2813,6 +2881,18 @@ files = [ [package.dependencies] typing_extensions = ">=4.14.0" +[[package]] +name = "types-pytz" +version = "2025.2.0.20251108" +description = "Typing stubs for pytz" 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c"}, + {file = "types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb"}, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -3022,4 +3102,4 @@ dev = ["black (>=19.3b0) ; python_version >= \"3.6\"", "pytest (>=4.6.2)"] [metadata] lock-version = "2.1" python-versions = ">=3.13,<3.14" -content-hash = "6b6d3ec0385494898f80d1d821a1b48f1bfca0f081be0600d583487e664119cb" +content-hash = "faa526e8c9a8209823ec8cabb8cc87dd42446459bb0c237366c2d3a94487d28c" diff --git a/pyproject.toml b/pyproject.toml index fe4a71e..668ba53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,8 +9,8 @@ readme = "README.md" requires-python = ">=3.13,<3.14" dependencies = [ "numpy>=2.0,<3", - "scipy>=1.16.1,<2", - "pandas>=2.3,<3", + "scipy>=1.16.2,<1.17.0", + "pandas>=2.3.3,<2.4.0", "matplotlib>=3.10,<4", "torch==2.9.*", # specific version required by torchvision and torchcodec! 
"torchvision==0.24.*", # specific to torch 2.9 @@ -41,6 +41,9 @@ dependencies = [ "filelock>=3.20.0,<4", "av<16", # imageio depends on pyav but no pyav 16.0.0 wheel is available for my hardware yet "loguru==0.7.*", + # PEP561 stub packages for better static analysis + "pandas-stubs (>=2.3.3,<2.4.0)", + "scipy-stubs (>=1.16.2,<1.17.0)" # The following should be installed manually (pip install -e .): # flygym & dm_control # parallel-video-io diff --git a/scripts_on_cluster/atomic_batch_extraction/compress_results.run b/scripts_on_cluster/atomic_batch_extraction/compress_results.run new file mode 100644 index 0000000..94281ee --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/compress_results.run @@ -0,0 +1,66 @@ +#!/bin/bash -l + +#SBATCH --job-name=compress_output +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 36 +#SBATCH --mem 96GB +#SBATCH --time 1:00:00 +#SBATCH --partition=standard +#SBATCH --qos=serial +#SBATCH --output=logs/compress_output.log + +# ------------------------------------------------------------------- +# CRASH CHECK: Ensure SLURM_CPUS_PER_TASK is set +# ------------------------------------------------------------------- + +# Get the requested number of CPUs from the Slurm environment +max_jobs=$SLURM_CPUS_PER_TASK + +# Check if the variable is empty or null. +if [ -z "$max_jobs" ]; then + echo "ERROR: SLURM_CPUS_PER_TASK environment variable is not set." >&2 + echo "This script must be submitted using 'sbatch' and requires the --cpus-per-task flag in the header." >&2 + exit 1 +fi + +echo "Running up to $max_jobs concurrent tar jobs ($(date))." 
+echo "----------------------------------------------" + +input_basedir="$HOME/poseforge/bulk_data/pose_estimation/atomic_batches/4variants" +output_basedir="$HOME/poseforge/bulk_data/pose_estimation/atomic_batches/4variants_tarballs" +mkdir -p $output_basedir + +# Loop through all subdirectories +job_count=0 +for dir_path in "$input_basedir"/*/; do + if [ -d "$dir_path" ]; then + + dir_name=$(basename "$dir_path") + output_file="$output_basedir/${dir_name}.tar.gz" + + # Launch the tar command in the background + # Note the -C flag to change directory to the parent + # *before* archiving. This keeps the archive clean. + + tar -czf "$output_file" -C "$dir_path/.." "$dir_name" & + + echo "Launched job for $dir_name (PID: $!)" + job_count=$((job_count + 1)) + + # When we hit the max number of concurrent jobs, we wait for *one* + # of the currently running background jobs to complete (wait -n). + if [ "$job_count" -ge "$max_jobs" ]; then + wait -n + job_count=$((job_count - 1)) + echo "A background job finished. Resuming launch..." + fi + fi +done + +# Wait for ALL remaining background jobs to finish +echo "----------------------------------------------" +echo "All jobs launched. Waiting for final completion..." +wait + +echo "All parallel tar operations complete ($(date))." 
diff --git a/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py b/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py new file mode 100644 index 0000000..ada441e --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/gen_batch_scripts.py @@ -0,0 +1,22 @@ +from pathlib import Path + + +# Batch script generation +template_path = Path("template.run") +batch_scripts_dir = Path("batch_scripts/") +batch_scripts_dir.mkdir(exist_ok=True, parents=True) + +# Configs by task +synthetic_videos_basedir = Path("/home/sibwang/poseforge/bulk_data/style_transfer/production/translated_videos/") +trial_names_all = [x.name for x in synthetic_videos_basedir.glob("BO_Gal4_*")] + +# Generate batch scripts +with open("template.run") as f: + template_str = f.read() + +for trial_name in trial_names_all: + batch_script_str = template_str.replace("<<>>", trial_name) + with open(batch_scripts_dir / f"{trial_name}.run", "w") as f: + f.write(batch_script_str) + +print(f"Generated {len(trial_names_all)} batch scripts under {batch_scripts_dir}") diff --git a/scripts_on_cluster/atomic_batch_extraction/process_all.run b/scripts_on_cluster/atomic_batch_extraction/process_all.run index c7f1760..133754b 100644 --- a/scripts_on_cluster/atomic_batch_extraction/process_all.run +++ b/scripts_on_cluster/atomic_batch_extraction/process_all.run @@ -18,9 +18,9 @@ spack load ffmpeg conda activate poseforge # Define paths -project_root="/home/$USER/poseforge" -extraction_cli_path="src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py" -all_trials_basedir="/home/sibwang/poseforge/bulk_data/style_transfer/production/translated_videos/" +project_root="$HOME/poseforge" +extraction_cli_path="src/poseforge/pose/data/scripts/preextract_atomic_batches.py" +all_trials_basedir="bulk_data/style_transfer/production/translated_videos/" # Setup cd $project_root diff --git a/scripts_on_cluster/atomic_batch_extraction/template.run 
b/scripts_on_cluster/atomic_batch_extraction/template.run new file mode 100644 index 0000000..a6aad3e --- /dev/null +++ b/scripts_on_cluster/atomic_batch_extraction/template.run @@ -0,0 +1,42 @@ +#!/bin/bash -l + +#SBATCH --job-name exabat_<<>> +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 36 +#SBATCH --mem 200GB +#SBATCH --time 8:00:00 +#SBATCH --partition=standard +#SBATCH --qos=serial +#SBATCH --output logs/<<>>.log + +echo "Hello from $(hostname)" + +# Set up environment and package +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge + +# Define paths +project_root="$HOME/poseforge" +extraction_cli_path="src/poseforge/pose/data/scripts/preextract_atomic_batches.py" +synthetic_videos_basedir="bulk_data/style_transfer/production/translated_videos/" +nmf_rendering_basedir="bulk_data/nmf_rendering/" +output_basedir="bulk_data/pose_estimation/atomic_batches/4variants/" + +# Setup +cd $project_root + +# Training +echo "Processing $trial_name at $(date)" +python -u $extraction_cli_path \ + --n_jobs 8 \ + --atomic-batch-nframes 32 \ + --atomic-batch-nvariants-max 4 \ + --minimum-time-diff-frames 60 \ + --original-image-size 464 464 \ + --input-basedir $synthetic_videos_basedir/<<>> \ + --nmf-sim-rendering-basedir $nmf_rendering_basedir \ + --output-dir $output_basedir/<<>> + +echo "Extraction ends at $(date)" diff --git a/scripts_on_cluster/bodyseg_training/job.run b/scripts_on_cluster/bodyseg_training/job.run index 7a1154c..8838990 100644 --- a/scripts_on_cluster/bodyseg_training/job.run +++ b/scripts_on_cluster/bodyseg_training/job.run @@ -18,7 +18,7 @@ spack load ffmpeg conda activate poseforge cd $HOME/poseforge -training_cli_path="src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py" +training_cli_path="src/poseforge/pose/bodyseg/scripts/run_training.py" training_trial_name="trial_20251127a" contrastive_pretraining_trial_name="trial_20251125a_lowlr" contrastive_pretraining_epoch="epoch009" diff --git 
a/scripts_on_cluster/contrastive_pretraining_inference/submit_all.sh b/scripts_on_cluster/common/submit_all.sh similarity index 100% rename from scripts_on_cluster/contrastive_pretraining_inference/submit_all.sh rename to scripts_on_cluster/common/submit_all.sh diff --git a/scripts_on_cluster/keypoints3d_training/job.run b/scripts_on_cluster/keypoints3d_training/job.run index 5e623d4..7b057f0 100644 --- a/scripts_on_cluster/keypoints3d_training/job.run +++ b/scripts_on_cluster/keypoints3d_training/job.run @@ -18,7 +18,7 @@ spack load ffmpeg conda activate poseforge cd $HOME/poseforge -training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py" +training_cli_path="src/poseforge/pose/keypoints3d/scripts/run_training.py" training_trial_name="trial_20251127a" contrastive_pretraining_trial_name="trial_20251125a_lowlr" contrastive_pretraining_epoch="epoch009" diff --git a/scripts_on_cluster/nmf_simulation/compress_results.run b/scripts_on_cluster/nmf_simulation/compress_results.run new file mode 100644 index 0000000..8bd810a --- /dev/null +++ b/scripts_on_cluster/nmf_simulation/compress_results.run @@ -0,0 +1,67 @@ +#!/bin/bash -l + +#SBATCH --job-name=compress_output +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 36 +#SBATCH --mem 96GB +#SBATCH --time 1:00:00 +#SBATCH --partition=standard +#SBATCH --qos=serial +#SBATCH --output=logs/compress_output.log + +# ------------------------------------------------------------------- +# CRASH CHECK: Ensure SLURM_CPUS_PER_TASK is set +# ------------------------------------------------------------------- + +# Get the requested number of CPUs from the Slurm environment +max_jobs=$SLURM_CPUS_PER_TASK + +# Check if the variable is empty or null. +if [ -z "$max_jobs" ]; then + echo "ERROR: SLURM_CPUS_PER_TASK environment variable is not set." >&2 + echo "This script must be submitted using 'sbatch' and requires the --cpus-per-task flag in the header." 
>&2 + exit 1 +fi + +echo "Running up to $max_jobs concurrent tar jobs ($(date))." +echo "----------------------------------------------" + +input_basedir="$HOME/poseforge/bulk_data/nmf_rendering" +output_basedir="$HOME/poseforge/bulk_data/nmf_rendering_tarballs" +mkdir -p $output_basedir + +# Loop through all subdirectories +job_count=0 +for dir_path in "$input_basedir"/*/; do + # Check if the path is actually a directory + if [ -d "$dir_path" ]; then + + dir_name=$(basename "$dir_path") + output_file="$output_basedir/${dir_name}.tar.gz" + + # --- Launch the tar command in the background --- + # Note the -C flag to change directory to the parent + # *before* archiving. This keeps the archive clean. + + tar -czf "$output_file" -C "$dir_path/.." "$dir_name" & + + echo "Launched job for $dir_name (PID: $!)" + job_count=$((job_count + 1)) + + # --- Job Control: Wait for a background job to finish --- + # When we hit the max number of concurrent jobs, we wait for *one* # of the currently running background jobs to complete (wait -n). + if [ "$job_count" -ge "$max_jobs" ]; then + wait -n + job_count=$((job_count - 1)) + echo "A background job finished. Resuming launch..." + fi + fi +done + +# --- Final Wait: Wait for ALL remaining background jobs to finish --- +echo "----------------------------------------------" +echo "All jobs launched. Waiting for final completion..." +wait + +echo "All parallel tar operations complete ($(date))." diff --git a/scripts_on_cluster/nmf_simulation/submit_all.sh b/scripts_on_cluster/nmf_simulation/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/nmf_simulation/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? 
(y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/pose6d_training/train_20251201a.run b/scripts_on_cluster/pose6d_training/train_20251201a.run new file mode 100644 index 0000000..5511dcc --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251201a.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251201a.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251201a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 128 \ + --model-architecture-config.n-global-feature-channels 128 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 2.0 \ + --training-data-config.train-data-dirs \ + 
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + 
--optimizer-config.learning_rate_encoder 0.00003 \ + --optimizer-config.learning_rate_upsample 0.0003 \ + --optimizer-config.learning_rate_attention 0.0003 \ + --optimizer-config.learning_rate_pose6d_heads 0.0003 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/pose6d_training/train_20251201b.run b/scripts_on_cluster/pose6d_training/train_20251201b.run new file mode 100644 index 0000000..15f058d --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251201b.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251201b.log + +echo "Hello from $(hostname)" + +. 
~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251201b" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 128 \ + --model-architecture-config.n-global-feature-channels 128 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 10.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + 
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.00003 \ + --optimizer-config.learning_rate_upsample 0.0003 \ + --optimizer-config.learning_rate_attention 0.0003 \ + --optimizer-config.learning_rate_pose6d_heads 0.0003 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/pose6d_training/train_20251201c.run b/scripts_on_cluster/pose6d_training/train_20251201c.run new file mode 100644 index 0000000..f822166 --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251201c.run @@ -0,0 
+1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251201c.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251201c" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 192 \ + --model-architecture-config.n-global-feature-channels 64 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 10.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + 
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.0001 \ + --optimizer-config.learning_rate_upsample 0.001 \ + --optimizer-config.learning_rate_attention 0.001 \ + --optimizer-config.learning_rate_pose6d_heads 0.001 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + 
--training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/pose6d_training/train_20251211a.run b/scripts_on_cluster/pose6d_training/train_20251211a.run new file mode 100644 index 0000000..a7c7258 --- /dev/null +++ b/scripts_on_cluster/pose6d_training/train_20251211a.run @@ -0,0 +1,81 @@ +#!/bin/bash -l + +#SBATCH --job-name pose6d-training +#SBATCH --nodes 1 +#SBATCH --ntasks 1 +#SBATCH --cpus-per-task 32 +#SBATCH --mem 92GB +#SBATCH --time 72:00:00 +#SBATCH --partition=h100 +#SBATCH --qos=normal +#SBATCH --gres=gpu:1 +#SBATCH --output /home/sibwang/poseforge/scripts_on_cluster/pose6d_training/output_20251211a.log + +echo "Hello from $(hostname)" + +. ~/spack/share/spack/setup-env.sh +spack load ffmpeg +conda activate poseforge +cd $HOME/poseforge + +training_cli_path="src/poseforge/pose/pose6d/scripts/run_training.py" +training_trial_name="20251211a" +contrastive_pretraining_trial_name="trial_20251125a_lowlr" +contrastive_pretraining_epoch="epoch009" +contrastive_pretraining_local_step="step003055" + +echo "Training starting at $(date)" + +python -u $training_cli_path \ + --n_epochs 15 \ + --seed 42 \ + --model-architecture-config.n-segments 25 \ + --model-architecture-config.n-attention-gated-feature-channels 192 \ + --model-architecture-config.n-global-feature-channels 64 \ + --model-architecture-config.camera-distance 100.0 \ + --model-weights-config.feature-extractor-weights \ + "bulk_data/pose_estimation/contrastive_pretraining/$contrastive_pretraining_trial_name/checkpoints/checkpoint_${contrastive_pretraining_epoch}_${contrastive_pretraining_local_step}.feature_extractor.pth" \ + --loss-config.translation-weight 1.0 \ + --loss-config.rotation-weight 10.0 \ + --training-data-config.train-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial001" \ + 
"bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly1_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly2_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly3_trial005" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial001" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial002" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial003" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial004" \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly4_trial005" \ + --training-data-config.val-data-dirs \ + "bulk_data/pose_estimation/atomic_batches/4variants/BO_Gal4_fly5_trial001" \ + --training-data-config.input-image-size 256 256 \ + --training-data-config.atomic-batch-n-samples 32 \ + --training-data-config.atomic-batch-n-variants 4 \ + --training-data-config.train-batch-size 128 \ + --training-data-config.val-batch-size 512 \ + --training-data-config.n-workers 8 \ + --optimizer-config.learning_rate_encoder 0.0001 \ + 
--optimizer-config.learning_rate_upsample 0.001 \ + --optimizer-config.learning_rate_attention 0.001 \ + --optimizer-config.learning_rate_pose6d_heads 0.001 \ + --optimizer-config.weight-decay 0.00001 \ + --training-artifacts-config.output-basedir \ + "bulk_data/pose_estimation/pose6d/$training_trial_name/" \ + --training-artifacts-config.logging-interval 10 \ + --training-artifacts-config.checkpoint-interval 1000 \ + --training-artifacts-config.validation-interval 300 \ + --training-artifacts-config.n-batches-per-validation 30 + +echo "Training ends at $(date)" diff --git a/scripts_on_cluster/style_transfer_inference/submit_all.sh b/scripts_on_cluster/style_transfer_inference/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_inference/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh b/scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_training/20250903_parameter_sweep/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." 
- exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh b/scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_training/20250905_continued_training/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." - exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh b/scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh deleted file mode 100644 index 1dc2024..0000000 --- a/scripts_on_cluster/style_transfer_training/20250905_parameter_sweep/submit_all.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -scripts_dir="./batch_scripts" - -files=($(ls $scripts_dir/*.run | sort)) -file_count=${#files[@]} - -echo -n "Are you sure you want to submit $file_count jobs? (y/Y to confirm) " -read -r confirmation -if [[ "$confirmation" != "y" && "$confirmation" != "Y" ]]; then - echo "Submission canceled." 
- exit 1 -fi - -for file in "${files[@]}"; do - echo "Submitting $file" - sbatch $file -done - -echo "Submitted $file_count files to the scheduler" diff --git a/src/poseforge/neuromechfly/constants.py b/src/poseforge/neuromechfly/constants.py index dfccbfc..c2b246f 100644 --- a/src/poseforge/neuromechfly/constants.py +++ b/src/poseforge/neuromechfly/constants.py @@ -79,6 +79,26 @@ "RAntenna": np.array([50, 120, 32]) / 255, } +segments_for_6dpose_per_leg = ["Coxa", "Femur", "Tibia", "Tarsus1"] +segments_for_6dpose = [ + f"{leg}{seg}" for leg in legs for seg in segments_for_6dpose_per_leg +] + ["Thorax"] + +# fmt: off +# all_body_segments = [ +# "Thorax", "A1A2", "A3", "A4", "A5", "A6", +# "LHaltere", "LWing", "RHaltere", "RWing", +# "LFCoxa", "LFFemur", "LFTibia", "LFTarsus1", "LFTarsus2", "LFTarsus3", "LFTarsus4", "LFTarsus5", +# "LHCoxa", "LHFemur", "LHTibia", "LHTarsus1", "LHTarsus2", "LHTarsus3", "LHTarsus4", "LHTarsus5", +# "LMCoxa", "LMFemur", "LMTibia", "LMTarsus1", "LMTarsus2", "LMTarsus3", "LMTarsus4", "LMTarsus5", +# "RHCoxa", "RHFemur", "RHTibia", "RHTarsus1", "RHTarsus2", "RHTarsus3", "RHTarsus4", "RHTarsus5", +# "RMCoxa", "RMFemur", "RMTibia", "RMTarsus1", "RMTarsus2", "RMTarsus3", "RMTarsus4", "RMTarsus5", +# "RFCoxa", "RFFemur", "RFTibia", "RFTarsus1", "RFTarsus2", "RFTarsus3", "RFTarsus4", "RFTarsus5", +# "Head", "LEye", "REye", "Rostrum", "Haustellum", +# "LPedicel", "LFuniculus", "LArista", "RPedicel", "RFuniculus", "RArista", +# ] +# fmt: on + ########################################################################### ## COLORS FOR BODY SEGMENT RENDERING BELOW ## diff --git a/src/poseforge/neuromechfly/postprocessing.py b/src/poseforge/neuromechfly/postprocessing.py index fa33870..eac379c 100644 --- a/src/poseforge/neuromechfly/postprocessing.py +++ b/src/poseforge/neuromechfly/postprocessing.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np import matplotlib.pyplot as plt import scipy.ndimage as ndimage @@ -8,8 +10,12 @@ import h5py 
from collections import defaultdict from pathlib import Path + +from numpy import ndarray, dtype from tqdm import tqdm from joblib import Parallel, delayed +from scipy.linalg import rq +from scipy.spatial.transform import Rotation import poseforge.neuromechfly.constants as constants from poseforge.util.plot import ( @@ -78,7 +84,9 @@ def __init__(self): self.label_colors_6d = np.array(self.label_colors_6d) - def __call__(self, images_by_color_coding: list[np.ndarray]): + def __call__( + self, images_by_color_coding: list[np.ndarray] | np.ndarray + ) -> np.ndarray: assert ( len(images_by_color_coding) == 2 ), "Expecting two images each with a different color coding." @@ -87,8 +95,8 @@ def __call__(self, images_by_color_coding: list[np.ndarray]): ), "Color coding images must have the same shape." if not isinstance(images_by_color_coding, list): images_by_color_coding = [ - images_by_color_coding[0, :, :, :], - images_by_color_coding[1, :, :, :], + images_by_color_coding[0, ...], + images_by_color_coding[1, ...], ] image_6d = np.concatenate(images_by_color_coding, axis=-1) @@ -101,7 +109,7 @@ def __call__(self, images_by_color_coding: list[np.ndarray]): return label_indices -def load_video_frames(video_path: Path) -> list[np.ndarray]: +def load_video_frames(video_path: Path) -> tuple[list[np.ndarray], float]: """Load video frames from a video file.""" if not video_path.is_file(): raise FileNotFoundError(f"{video_path} is not a file.") @@ -120,7 +128,7 @@ def load_video_frames(video_path: Path) -> list[np.ndarray]: return frames, fps -def get_rotation_angle_and_matrix(forward_vector: np.ndarray) -> np.ndarray: +def get_rotation_angle_and_matrix(forward_vector: np.ndarray) -> tuple[float, ndarray]: """Get a rotation matrix that rotates the forward vector to the z-axis.""" orientation = np.arctan2(forward_vector[1], forward_vector[0]) rotation_matrix = np.array( @@ -139,7 +147,9 @@ def rotate_image(image: np.ndarray, rotation_angle) -> np.ndarray: ) -def 
center_square_crop_image(image: np.ndarray, side_length) -> np.ndarray: +def center_square_crop_image( + image: np.ndarray, side_length: int +) -> tuple[np.ndarray, int, int]: """Crop the image to a square of given side length, centered on the image.""" height, width = image.shape[:2] start_col = (width - side_length) // 2 @@ -204,7 +214,7 @@ def rotate_keypoint_positions_camera( image_shape: tuple[int, int], start_col: int, start_row: int, -): +) -> dict[str, np.ndarray]: """ Apply rotation and cropping transformations to camera/image coordinates. Uses a simplified approach based on the working code snippet. @@ -213,7 +223,12 @@ def rotate_keypoint_positions_camera( keypoints_pos_lookup: Dict mapping segment names to (x, y, depth) positions rotation_matrix: 3x3 rotation matrix used for coordinate transformation image_shape: (height, width) of the original image before rotation - start_col, start_row: Cropping offsets after rotation + start_col: Cropping offsets after rotation (column) + start_row: Cropping offsets after rotation (row) + + Returns: + keypoints_pos_lookup_rotated: Dict mapping segment names to transformed + (x, y, depth) positions """ keypoints_pos_lookup_rotated = {} height, width = image_shape @@ -237,7 +252,7 @@ def rotate_keypoint_positions_camera( def rotate_keypoint_positions_world( keypoints_pos_lookup: dict[str, np.ndarray], rotation_matrix: np.ndarray -): +) -> dict[str, np.ndarray]: keypoints_pos_lookup_rotated = {} for body_segment_name, pos_before_rotation in keypoints_pos_lookup.items(): # Apply the rotation matrix to the keypoints @@ -304,15 +319,80 @@ def process_single_frame( # Save object segmentation masks seg_labels = segment_label_parser(rendered_images_transformed) + # Transform mesh states to camera's coordinate system + pos_glob = h5_file["body_segment_states/pos_global"][frame_idx, :, :] + quat_glob = h5_file["body_segment_states/quat_global"][frame_idx, :, :] + cam_projmat = h5_file["camera_matrix"][frame_idx, :, :] # 3x4 
mat from mujoco + pos_rel_cam_pre_alignment, quat_rel_cam_pre_alignment = ( + calculate_6dpose_relative_to_camera(pos_glob, quat_glob, cam_projmat) + ) + # As before, we're rotating the camera capture post hoc so that the fly is + # always upright. Therefore, upon computing the position and rotation of the + # mesh relative to the camera, we still need to rotate them around the z axis + alignment_rotation = z_rotation = Rotation.from_rotvec([0, 0, rotation_angle]) + pos_rel_cam_aligned = alignment_rotation.apply(pos_rel_cam_pre_alignment) + quat_rel_cam_aligned = ( + alignment_rotation + * Rotation.from_quat(quat_rel_cam_pre_alignment, scalar_first=True) + ).as_quat(scalar_first=True) + derived_variables["pos_rel_cam"] = pos_rel_cam_aligned + derived_variables["quat_rel_cam"] = quat_rel_cam_aligned + return rendered_images_transformed, derived_variables, seg_labels +def calculate_6dpose_relative_to_camera(pos_global, quat_global, cam_projmat): + """Convert global mesh positions and orientations to camera-relative coordinates. + + Args: + pos_global: (num_segments, 3) array of global positions + quat_global: (num_segments, 4) array of global quaternions (scalar first) + cam_projmat: (3, 4) camera projection matrix + + Returns: + pos_rel_cam: (num_segments, 3) array of positions relative to camera + quat_rel_cam: (num_segments, 4) array of quaternions relative to camera + """ + assert pos_global.shape[0] == quat_global.shape[0], "Number of segments mismatch." 
+ n_segments = pos_global.shape[0] + + # Decompose camera projection matrix to get camera intrinsics, rotation, translation + cam_intrinsics, cam_rotation = rq(cam_projmat[:, :3]) + _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) + cam_intrinsics = cam_intrinsics @ _sign_multiplier + cam_rotation = _sign_multiplier @ cam_rotation + if np.linalg.det(cam_rotation) < 0: + cam_rotation = -cam_rotation # ensure proper rotation matrix (det = 1) + cam_intrinsics = -cam_intrinsics + cam_translation = np.linalg.inv(cam_intrinsics) @ cam_projmat[:, 3] + + # Compute rotation from world to camera coordinates + rot_world_to_cam = Rotation.from_matrix(cam_rotation) + + # Convert each segment's position and orientation + pos_rel_cam = np.zeros_like(pos_global) + quat_rel_cam = np.zeros_like(quat_global) + for seg_idx in range(n_segments): + this_pos_glob = pos_global[seg_idx, :] + this_quat_glob = quat_global[seg_idx, :] + + this_pos_rel_cam = cam_rotation @ this_pos_glob + cam_translation + mesh_rot = Rotation.from_quat(this_quat_glob, scalar_first=True) + mesh_rot_rel_cam = rot_world_to_cam * mesh_rot + this_quat_rel_cam = mesh_rot_rel_cam.as_quat(scalar_first=True) + + pos_rel_cam[seg_idx, :] = this_pos_rel_cam + quat_rel_cam[seg_idx, :] = this_quat_rel_cam + + return pos_rel_cam, quat_rel_cam + + def process_subsegment( frames_by_color_coding: list[list[np.ndarray]], segment_h5_file_path: Path, frames_range: tuple[int, int], processed_subsegment_dir: Path, - fps: int, + fps: int | float, crop_size: int = 464, num_color_codings: int = 2, n_jobs: int = -1, @@ -488,15 +568,20 @@ def process_subsegment( ) # Add mesh state labels - seg_states_grp = postprocessed_group.create_group("body_segment_states") + seg_states_grp = postprocessed_group.create_group("mesh_pose6d_rel_camera") + seg_states_grp.create_dataset( + "pos_rel_cam", + data=np.array(derived_variables_by_key["pos_rel_cam"]), + dtype="float32", + compression="lzf", + ) + seg_states_grp.create_dataset( + 
"quat_rel_cam", + data=np.array(derived_variables_by_key["quat_rel_cam"]), + dtype="float32", + compression="lzf", + ) seg_states_grp.attrs.update(source_h5_file["body_segment_states"].attrs) - for sensor_type in source_h5_file["body_segment_states"].keys(): - source_ds = source_h5_file["body_segment_states"][sensor_type] - seg_states_grp.create_dataset( - sensor_type, - data=source_ds[frame_idx_start:frame_idx_end, :, :], - dtype="float32", - ) def _draw_pose_2d_and_3d( @@ -663,16 +748,18 @@ def visualize_single_frame( def visualize_subsegment( processed_subsegment_dir: Path, - fps: int, + render_fps: float | None = None, camera_elevation: float = 30.0, max_abs_azimuth: float = 30.0, azimuth_rotation_period: float = 300.0, n_jobs: int = -1, # Default to all available cores ) -> None: # Load video frames - frames, fps = load_video_frames( + frames, input_fps = load_video_frames( processed_subsegment_dir / "processed_nmf_sim_render_colorcode_0.mp4" ) + if render_fps is None: + render_fps = input_fps # Find processed simulation data processed_data_path = processed_subsegment_dir / "processed_simulation_data.h5" @@ -728,7 +815,7 @@ def visualize_subsegment( # Merge video with imageio.get_writer( str(processed_subsegment_dir / "visualization.mp4"), - fps=fps, + fps=render_fps, codec="libx264", quality=10, # 10 is highest for imageio, lower is lower quality ffmpeg_params=["-crf", "18", "-preset", "slow"], # lower crf = higher quality @@ -742,7 +829,7 @@ def visualize_subsegment( def select_subsegments( - upward_cardinal_vectors: np.array, + upward_cardinal_vectors: np.ndarray, max_tilt_angle_deg: float, mask_morph_closing_size_sec: float, min_subsegment_duration_sec: float, @@ -837,7 +924,7 @@ def postprocess_segment( if visualize: visualize_subsegment( processed_subsegment_dir=output_dir, - fps=fps, + render_fps=fps, camera_elevation=camera_elevation, max_abs_azimuth=max_abs_azimuth, azimuth_rotation_period=azimuth_rotation_period, diff --git 
a/src/poseforge/neuromechfly/scripts/run_simulation.py b/src/poseforge/neuromechfly/scripts/run_simulation.py index c875731..2969a0b 100644 --- a/src/poseforge/neuromechfly/scripts/run_simulation.py +++ b/src/poseforge/neuromechfly/scripts/run_simulation.py @@ -55,7 +55,7 @@ from poseforge.neuromechfly.data import load_kinematic_recording # from poseforge.neuromechfly.simulate import simulate_one_segment # TODO: revert from poseforge.neuromechfly.postprocessing import postprocess_segment -from poseforge.util import get_hardware_availability +from poseforge.util import get_hardware_availability, bulk_data_dir def simulate_using_kinematic_prior( @@ -155,7 +155,7 @@ def simulate_using_kinematic_prior( def run_sequentially_for_testing(): """Run everything sequentially (for debugging)""" # Configs - output_basedir = Path("bulk_data/nmf_rendering_new/") # TODO: change back to *_test + output_basedir = bulk_data_dir / "nmf_rendering/" # TODO: change back to *_test input_timestep = 0.01 sim_timestep = 0.0001 # trial_paths = [ @@ -163,7 +163,7 @@ def run_sequentially_for_testing(): # Path("bulk_data/kinematic_prior/aymanns2022/trials/BO_Gal4_fly1_trial001.pkl") # ] trial_paths = sorted( # TODO: revert - Path("bulk_data/kinematic_prior/aymanns2022/trials/").glob("*.pkl") + (bulk_data_dir / "kinematic_prior/aymanns2022/trials/").glob("*.pkl") ) # Limit scope of simulation as this is only for testing diff --git a/src/poseforge/neuromechfly/scripts/visualize_meshes.py b/src/poseforge/neuromechfly/scripts/visualize_meshes.py index cb6554e..368c5c1 100644 --- a/src/poseforge/neuromechfly/scripts/visualize_meshes.py +++ b/src/poseforge/neuromechfly/scripts/visualize_meshes.py @@ -1,125 +1,296 @@ import numpy as np -import pandas as pd import pyvista as pv import h5py +from tqdm import trange from xml.etree import ElementTree from pathlib import Path +from dataclasses import dataclass from scipy.spatial.transform import Rotation -from scipy.linalg import rq +from pyvista import 
PolyData +from loguru import logger from poseforge.neuromechfly.constants import legs, all_segment_names_per_leg +from poseforge.util.data import bulk_data_dir -# Define paths -subsegment_dir = Path( - "bulk_data/nmf_rendering/BO_Gal4_fly1_trial001/segment_000/subsegment_000" -) -sim_data_path = subsegment_dir / "processed_simulation_data.h5" -flygym_data_dir = Path("~/projects/flygym/flygym").expanduser() / "data" -nmf_mesh_dir = flygym_data_dir / "mesh" -mjcf_path = flygym_data_dir / "mjcf/neuromechfly_seqik_kinorder_ypr.xml" - -# Load NeuroMechFly model -mjcf_tree = ElementTree.parse(mjcf_path) -worldbody = mjcf_tree.find("worldbody") -body_attributes = {body.attrib["name"]: body.attrib for body in worldbody.iter("body")} - -# Load simulation data -with h5py.File(sim_data_path, "r") as f: - all_seg_pos_global = f["raw/body_segment_states/pos_global"][:] - all_seg_quat_global = f["raw/body_segment_states/quat_global"][:] - all_cam_matrices = f["raw/camera_matrix"][:] - all_seg_names = list(f["raw/body_segment_states"].attrs["keys"]) - -n_frames = all_seg_pos_global.shape[0] -segments_to_include = [ - f"{leg}{seg}" for leg in legs for seg in all_segment_names_per_leg -] -segments_to_include += ["Thorax"] - -# Load original meshes once (before any transformations) -original_meshes = {} -for seg_name in segments_to_include: - mesh_file = nmf_mesh_dir / f"{seg_name}.stl" - original_meshes[seg_name] = pv.read(mesh_file) - -# Create plotter -plotter = pv.Plotter() -plotter.set_background("black") -plotter.show_axes() - -# Add all meshes to plotter initially -current_meshes = {} -for seg_name in segments_to_include: - current_meshes[seg_name] = original_meshes[seg_name].copy() - plotter.add_mesh( - current_meshes[seg_name], - show_edges=False, - name=seg_name, - smooth_shading=True - ) +def load_neuromechfly_meshes(flygym_path: Path | str) -> dict[str, PolyData]: + """Load NeuroMechFly meshes from the specified FlyGym installation path.""" + # Define paths + mesh_dir = 
Path(flygym_path) / "data/mesh" + mjcf_path = Path(flygym_path) / "data/mjcf/neuromechfly_seqik_kinorder_ypr.xml" + + # Load NeuroMechFly MJCF file (which is just an XML file) + mjcf_tree = ElementTree.parse(mjcf_path) + + # Load body segment meshes + mesh_lookup = {} + for mesh_element in mjcf_tree.findall("asset/mesh"): # findall iterates in order + mesh_path = mesh_dir / Path(mesh_element.attrib["file"]).name + mesh_name = mesh_element.attrib["name"] + if not mesh_name.startswith("mesh_"): + raise RuntimeError( + "Unexpected mesh specified in MJCF: name doesn't start with 'mesh_'" + ) + mesh_name = mesh_name[len("mesh_") :] + mesh_scale_str = mesh_element.attrib["scale"] + try: + mesh_scale_xyz = [float(s) for s in mesh_scale_str.split()] + except ValueError as e: + raise RuntimeError( + "Unexpected mesh specified in MJCF: scale not following format " + "'xscale yscale zscale' (each a number)" + ) from e -# Current frame tracker -current_frame = [0] - - -def update_frame(): - """Update all meshes to the current frame""" - frame_idx = current_frame[0] - cam_matrix = all_cam_matrices[frame_idx, :, :] - - # Compute camera transformation once per frame - cam_intrinsics, cam_rotation = rq(cam_matrix[:, :3]) - _sign_multiplier = np.diag(np.sign(np.diag(cam_intrinsics))) - cam_intrinsics = cam_intrinsics @ _sign_multiplier - cam_rotation = _sign_multiplier @ cam_rotation - cam_translation = np.linalg.inv(cam_intrinsics) @ cam_matrix[:, 3] - - transform_world2cam = np.eye(4) - transform_world2cam[:3, :3] = cam_rotation - transform_world2cam[:3, 3] = cam_translation - transform_cam2world = np.linalg.inv(transform_world2cam) - - # Update each segment mesh - for seg_name in segments_to_include: - seg_idx = all_seg_names.index(seg_name) - translation = all_seg_pos_global[frame_idx, seg_idx, :] - quaternion = all_seg_quat_global[frame_idx, seg_idx, :] - - # Start with original mesh - mesh = original_meshes[seg_name].copy() - - # Scale + # Load mesh and apply initial scaling 
transform + # (lengths in NeuroMechFly simulations are in mm instead of m - this is how NMF + # tells MuJoCo to scale everything accordingly) + mesh: PolyData = pv.read(mesh_path) + if mesh is None: + raise RuntimeError(f"Mesh '{mesh_name}' from '{mesh_path}' is empty") scale_transform = np.eye(4) - np.fill_diagonal(scale_transform, [1000, 1000, 1000, 1]) + np.fill_diagonal(scale_transform, [*mesh_scale_xyz, 1]) mesh = mesh.transform(scale_transform, inplace=False) - - # Apply MuJoCo state transformation - placement_transform = np.eye(4) - rotation_object = Rotation.from_quat(quaternion, scalar_first=True) - placement_transform[:3, :3] = rotation_object.as_matrix() - placement_transform[:3, 3] = translation - mesh = mesh.transform(placement_transform, inplace=False) - - # Transform to camera coordinates - mesh = mesh.transform(transform_cam2world, inplace=False) - - # Update the mesh points in place - current_meshes[seg_name].points[:] = mesh.points - - # Update frame counter - current_frame[0] = (current_frame[0] + 1) % n_frames - - # Update title to show current frame - plotter.add_text(f"Frame: {frame_idx}/{n_frames}", name="frame_counter", position="upper_left") - - -# Initialize first frame -update_frame() -plotter.reset_camera() - -# Add timer callback for animation (30 fps) -# The callback receives a step argument, so we need to accept it -plotter.add_timer_event(max_steps=n_frames, duration=int(1000/30), callback=lambda step: update_frame()) - -plotter.show() \ No newline at end of file + mesh_lookup[mesh_name] = mesh + + return mesh_lookup + + +@dataclass +class Pose6DSequence: + # XYZ position relative to camera (n_frames, n_segments, 3) + pos_ts: np.ndarray + # Quaternion rotation relative to camera in scalar-first order (wxyz) + # (n_frames, n_segments, 4) + quat_ts: np.ndarray + # Names of body segments included in pos_ts and quat_ts + segments: list[str] + + @classmethod + def from_processed_simulation_data( + cls, + data_path: Path, + segments: 
list[str] | None, + relative_to_camera: bool = True, + ) -> "Pose6DSequence": + """ + Load 6D pose sequence from processed NeuroMechFly simulation data file. + + Args: + data_path: Path to processed simulation data HDF5 file + segments: List of segment names to load. If None, load all recorded segments + relative_to_camera: If True, load segment poses relative to camera. + If False, load raw global segment poses (pre-postprocessing). + + Returns: + An instance of Pose6DSequence containing the loaded data + """ + pos_ts, quat_ts, segments = _load_simulation_data( + data_path, segments, load_raw=not relative_to_camera + ) + return cls(pos_ts=pos_ts, quat_ts=quat_ts, segments=segments) + + def __len__(self): + return self.pos_ts.shape[0] + + def render( + self, + mesh_assets: dict[str, PolyData], + render_fps: int, + output_path: Path | str | None = None, + display_live: bool = False, + theme: str = "dark", + disable_pbar: bool = False, + ) -> None: + """Render the 6D pose sequence using NeuroMechFly meshes in PyVista. + + Args: + mesh_assets: Dictionary mapping segment names to PyVista PolyData meshes. + render_fps: Frames per second for rendering. + output_path: If display_live is False, path to save the rendered video. + display_live: If True, display the rendering live instead of saving to file. + theme: Display theme, either "dark" or "light". + disable_pbar: If True, disable the progress bar when saving to file. 
+ """ + # Filter & check mesh files from NeuroMechFly + mesh_assets = {k: v for k, v in mesh_assets.items() if k in self.segments} + missing_keys = set(self.segments) - set(mesh_assets.keys()) + if missing_keys: + logger.critical( + "The following segments from Pose6DSequence are not found in " + f"NeuroMechFly meshes: {missing_keys}" + ) + raise KeyError("Some meshes are missing") + + # Set up rendering + plotter, plotted_meshes = _set_up_renderer(mesh_assets, display_live, theme) + + def _update_scene_by_frameid(frameid): + _transform_all_meshes( + mesh_assets, + plotted_meshes, + self.pos_ts[frameid, :, :], + self.quat_ts[frameid, :, :], + self.segments, + ) + + # Set up first frame + _update_scene_by_frameid(0) + plotter.reset_camera() + + # Render each frame + if display_live: + if output_path is not None: + logger.warning( + "Output path is specified but display_live is set to true. " + "Nothing will be saved and the output path will be ignored." + ) + + plotter.add_timer_event( + max_steps=len(self), + duration=int(1000 / render_fps), # in ms + callback=_update_scene_by_frameid, + ) + plotter.show() + else: + if output_path is None: + logger.critical( + "When Pose6DSequenceRenderer.display_live is set to false, " + "output_path must be specified." + ) + raise ValueError("Video output path not specified.") + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + plotter.open_movie(output_path, framerate=render_fps) + for i in trange(len(self), disable=disable_pbar): + _update_scene_by_frameid(i) + plotter.write_frame() + plotter.close() + + +def _load_simulation_data( + processed_data_path: Path | str, segments: list[str] | None, load_raw: bool = False +) -> tuple[np.ndarray, np.ndarray, list[str]]: + """Load 6D pose data from processed NeuroMechFly simulation data file. + + Args: + processed_data_path: Path to processed simulation data HDF5 file. + segments: List of segment names to load. 
If None, load all recorded segments. + load_raw: If True, load raw global segment poses (pre-postprocessing). If False, + load postprocessed segment poses relative to camera. + + Returns: + pos_ts: Array of shape (n_frames, n_segments, 3) with XYZ positions. + quat_ts: Array of shape (n_frames, n_segments, 4) with quaternion rotations + in scalar-first (wxyz) order. + segments: List of segment names corresponding to the loaded data. + """ + with h5py.File(processed_data_path, "r") as f: + if load_raw: + grp = f["raw/body_segment_states/"] + all_recorded_segments = list(grp.attrs["keys"]) + pos_ts = grp["pos_global"][:] + quat_ts = grp["quat_global"][:] + else: + grp = f["postprocessed/mesh_pose6d_rel_camera"] # relative to camera + all_recorded_segments = list(grp.attrs["keys"]) + # grp["pos_rel_cam"][:, seg_indices, :] would be better, but h5py only + # supports index-based fancy indexing if the indices are monotonic, so meh + pos_ts = grp["pos_rel_cam"][:] + quat_ts = grp["quat_rel_cam"][:] + if segments is None: + segments = all_recorded_segments + else: + seg2idx = {seg: i for i, seg in enumerate(all_recorded_segments)} + try: + seg_indices = np.array([seg2idx[seg] for seg in segments]) + except KeyError as e: + raise KeyError(f"Some segments not found in simulation data") from e + pos_ts = pos_ts[:, seg_indices, :] + quat_ts = quat_ts[:, seg_indices, :] + + n_frames, n_segments, _ = pos_ts.shape + assert pos_ts.shape == (n_frames, n_segments, 3), "Invalid shape for mesh pos_ts" + assert quat_ts.shape == (n_frames, n_segments, 4), "Invalid shape for mesh quat_ts" + return pos_ts, quat_ts, segments + + +def _set_up_renderer( + nmf_meshes: dict[str, PolyData], display_live: bool = False, theme: str = "dark" +) -> tuple[pv.Plotter, dict[str, PolyData]]: + """Set up PyVista plotter and add NeuroMechFly meshes to it""" + # Set up plotter + plotter = pv.Plotter(off_screen=not display_live) + plotter.show_axes() + + # Apply display theme + if theme.lower() == "dark": 
+ plotter.set_background("black") + mesh_color = "#eeeeee" + elif theme.lower() == "light": + plotter.set_background("white") + mesh_color = "#bbbbbb" + else: + raise ValueError(f"Undefined display theme '{theme}'") + + # Add meshes initially (don't care about positions) + plotted_meshes = {} + for seg_name, mesh in nmf_meshes.items(): + mesh = nmf_meshes[seg_name].copy() + plotter.add_mesh(mesh, show_edges=False, name=seg_name, color=mesh_color) + plotted_meshes[seg_name] = mesh + + # Set up camera + plotter.camera.position = (0, 0, 0) + plotter.camera.focal_point = (0, 0, 1) # look down +Z axis + plotter.camera.up = (0, -1, 0) # -Y is up (OpenCV/MuJoCo convention) + + return plotter, plotted_meshes + + +def _transform_one_mesh(mesh: PolyData, pos: np.ndarray, quat: np.ndarray) -> PolyData: + """Apply 6D pose (translation + quaternion rotation) to a mesh""" + rot = Rotation.from_quat(quat, scalar_first=True) # MuJoCo uses wxyz order + transform = np.eye(4) + transform[:3, :3] = rot.as_matrix() + transform[:3, 3] = pos + return mesh.transform(transform, inplace=False) + + +def _transform_all_meshes( + mesh_assets: dict[str, PolyData], + plotted_meshes: dict[str, PolyData], + pos_all_segments: np.ndarray, + quat_all_segments: np.ndarray, + segment_names: list[str], +) -> None: + """Update all meshes to the specified 6D poses""" + assert pos_all_segments.shape == (len(segment_names), 3) + assert quat_all_segments.shape == (len(segment_names), 4) + for i, seg_name in enumerate(segment_names): + transformed_mesh = _transform_one_mesh( + mesh_assets[seg_name], pos_all_segments[i, :], quat_all_segments[i, :] + ) + plotted_meshes[seg_name].points = transformed_mesh.points + + +if __name__ == "__main__": + flygym_dir = Path("~/projects/flygym/flygym").expanduser() + sample_sim_path = ( + bulk_data_dir + / "nmf_rendering/BO_Gal4_fly1_trial001/segment_000/subsegment_001/processed_simulation_data.h5" + ) + replayed_segments = [ + f"{leg}{seg}" for leg in legs for seg in 
all_segment_names_per_leg + ] + ["Thorax"] + + nmf_mesh_assets = load_neuromechfly_meshes(flygym_dir) + sim_data = Pose6DSequence.from_processed_simulation_data( + sample_sim_path, segments=replayed_segments, relative_to_camera=True + ) + sim_data.render( + mesh_assets=nmf_mesh_assets, + render_fps=33, + output_path=sample_sim_path.parent / "pose6d_render.mp4", + display_live=True, + theme="light", + ) diff --git a/src/poseforge/pose/bodyseg/model.py b/src/poseforge/pose/bodyseg/model.py index b8cf7af..a45df5c 100644 --- a/src/poseforge/pose/bodyseg/model.py +++ b/src/poseforge/pose/bodyseg/model.py @@ -103,11 +103,12 @@ def load_weights_from_config( and weights_config.model_weights is None ): logging.warning("weights_config contains nothing useful. No action taken.") + return # If full model weights are provided, load them directly if weights_config.model_weights is not None: checkpoint_path = Path(weights_config.model_weights) - if not checkpoint_path.is_file(): + if not checkpoint_path.exists(): raise ValueError(f"Model weights path {checkpoint_path} is not a file") weights = torch.load(checkpoint_path, map_location="cpu") self.load_state_dict(weights) @@ -157,9 +158,7 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: └──(bottleneck/identity)──┘ """ # Run feature extractor - e0, e1, e2, e3, e4 = self.feature_extractor.forward( - x, return_intermediates=True - ) + e0, e1, e2, e3, e4 = self.feature_extractor.forward_with_intermediates(x) d4 = e4 # this is just the bottleneck diff --git a/src/poseforge/pose/bodyseg/pipeline.py b/src/poseforge/pose/bodyseg/pipeline.py index c185595..485d620 100644 --- a/src/poseforge/pose/bodyseg/pipeline.py +++ b/src/poseforge/pose/bodyseg/pipeline.py @@ -98,7 +98,7 @@ def train( # Set up optimizer optimizer = self._create_optimizer(optimizer_config) - # Set up mixed-point training + # Set up mixed-precision training grad_scaler = torch.amp.GradScaler(self.device_type, enabled=self.use_float16) 
self._check_amp_status_for_model_params( grad_scaler, subtitle="Model parameters before training" @@ -118,10 +118,9 @@ def train( epoch_start_time = time() running_start_time = time() - for step_idx, (atomic_batches_frames, atomic_batches_sim_data) in enumerate( - train_loader - ): + for step_idx, atomic_batches in enumerate(train_loader): # Format data + atomic_batches_frames, atomic_batches_sim_data = atomic_batches frames, sim_data = atomic_batches_to_simple_batch( atomic_batches_frames, atomic_batches_sim_data, device=self.device ) @@ -169,9 +168,8 @@ def train( for k, x in running_loss_dict.items() } time_now = time() - throughput = artifacts_config.logging_interval / ( - time_now - running_start_time - ) + time_elapsed = time_now - running_start_time + throughput = artifacts_config.logging_interval / time_elapsed running_loss_dict = defaultdict(lambda: 0.0) running_start_time = time_now @@ -245,6 +243,7 @@ def validate( raise ValueError("Loss function must be provided for validation") total_loss_dict = defaultdict(lambda: 0.0) + n_steps_iterated = 0 self.model.eval() with torch.no_grad(): for step_idx, (atomic_batches_frames, atomic_batches_sim_data) in enumerate( @@ -274,6 +273,7 @@ def validate( # Accumulate losses for key, loss in loss_dict.items(): total_loss_dict[key] += loss.item() + n_steps_iterated += 1 del ( atomic_batches_frames, @@ -284,7 +284,6 @@ def validate( ) clear_memory_cache() self.model.train() - n_steps_iterated = step_idx + 1 return {k: v / n_steps_iterated for k, v in total_loss_dict.items()} def inference(self, frames: torch.Tensor) -> dict[str, torch.Tensor]: diff --git a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py b/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py deleted file mode 100644 index 0e5ca59..0000000 --- a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -import h5py -from pathlib import Path -from torchvision.transforms import Resize -from 
torchsummary import summary -from tqdm import tqdm -from pvio.torch_tools import SimpleVideoCollectionLoader - -import poseforge.pose.bodyseg.config as config -from poseforge.pose.bodyseg import BodySegmentationModel, BodySegmentationPipeline -from poseforge.util.sys import get_hardware_availability -from poseforge.util.data import OutputBuffer - - -def test_bodyseg_model( - input_basedir: Path, - model_dir: Path, - model_checkpoint_path: Path, - output_basedir: Path | None = None, - batch_size: int = 512, - n_workers: int = 16, - inference_image_size: tuple[int, int] = (256, 256), - output_buffer_log_interval: int = 10, -): - # System setup - hardware_avail = get_hardware_availability(check_gpu=True, print_results=True) - if len(hardware_avail["gpus"]) == 0: - raise RuntimeError("No GPU available for testing") - torch.backends.cudnn.benchmark = True - - # Find all trials to process - input_trials = list(input_basedir.glob("*/model_prediction/not_flipped/")) - print(f"Found {len(list(input_trials))} trials to process") - input_trials = [trial for trial in input_trials if len(list(trial.iterdir())) > 0] - print(f"{len(input_trials)} trials have images to process") - - # Create dataset and dataloader - transform = Resize(inference_image_size) - dataloader = SimpleVideoCollectionLoader( - input_trials, transform=transform, batch_size=batch_size, num_workers=n_workers - ) - print(f"Found {len(dataloader.dataset)} frames to process") - print( - f"Using batch size {dataloader.batch_size} with {dataloader.num_workers} " - f"workers. This will generate {len(dataloader)} batches." 
- ) - - # Create model and learning pipeline - architecture_config_path = model_dir / "configs/model_architecture_config.yaml" - model_weights = config.ModelWeightsConfig(model_weights=model_checkpoint_path) - model = BodySegmentationModel.create_architecture_from_config( - architecture_config_path - ).cuda() - model.load_weights_from_config(model_weights) - summary(model, (3, *inference_image_size)) - pipeline = BodySegmentationPipeline(model, device="cuda", use_float16=True) - - # Make an output buffer - output data for multiple videos will arrive out of sync - def save_predictions(input_video_idx, data_items): - video_obj = dataloader.dataset.videos[input_video_idx] - input_video_path = video_obj.path - exp_trial_name = "_".join(input_video_path.parts[-3:]) - out_dir = output_basedir / exp_trial_name - out_dir.mkdir(parents=True, exist_ok=True) - with h5py.File(out_dir / f"bodyseg_pred.h5", "w") as f: - pred_segmaps = torch.stack([x[0] for x in data_items], dim=0).cpu().numpy() - ds = f.create_dataset( - "pred_segmap", - data=pred_segmaps, - dtype="uint8", - compression="gzip", - shuffle=True, - ) - ds.attrs["class_labels"] = pipeline.class_labels - confs = torch.stack([x[1] for x in data_items], dim=0).cpu().numpy() - ds = f.create_dataset( - "pred_confidence", - data=confs, - dtype="uint8", - compression="gzip", - shuffle=True, - ) - # Confidence is predicted in 0-1, but we store it in 0-100 as uint8 - ds.attrs["scale"] = 100 - ds.attrs["method"] = model.confidence_method - frame_ids = [ - int(p.stem.split("_")[1]) - for p in video_obj.phy_frame_id_to_path.values() - ] - # These are the actual, raw frame IDs from the original video assigned by - # the Spotlight recording software. They may not be contiguous because - # frames where the fly is upside down or too close to the edge, etc. are - # already removed. 
- f.create_dataset( - "frame_ids", - data=frame_ids, - dtype="int", - compression="gzip", - shuffle=True, - ) - - buckets_and_sizes = { - i: n_frames for i, n_frames in enumerate(dataloader.dataset.n_frames_by_video) - } - output_buffer = OutputBuffer( - buckets_and_expected_sizes=buckets_and_sizes, - closing_func=save_predictions, - ) - - # Run inference - for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): - # No need to move data to and from the GPU, pipeline will do that - pred_dict = pipeline.inference(batch["frames"]) - logits = pred_dict["logits"] - pred_seg = torch.argmax(logits, dim=1).to(torch.uint8).detach().cpu() - confidence = (pred_dict["confidence"] * 100).to(torch.uint8).detach().cpu() - - for i in range(logits.shape[0]): - data_item = (pred_seg[i, :, :], confidence[i, :, :]) - output_buffer.add_data( - bucket=batch["video_indices"][i], - index=batch["frame_indices"][i], - data=data_item, - ) - if (batch_idx + 1) % output_buffer_log_interval == 0: - print( - f"{batch_idx + 1}/{len(dataloader)} batches - " - f"{output_buffer.n_open_buckets} partially processed videos, " - f"{output_buffer.n_data_total} total frames in buffer" - ) - - assert output_buffer.n_data_total == 0 - assert output_buffer.n_open_buckets == 0 - print("Inference complete") - - -if __name__ == "__main__": - input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") - model_dir = Path("bulk_data/pose_estimation/bodyseg/trial_20251118a") - batch_size = 192 - n_workers = 16 - inference_image_size = (256, 256) - output_buffer_log_interval = 10 - epoch = 14 # chosen by validation performance and visual inspection - step = 12000 # last step of each epoch - - model_checkpoint_path = model_dir / f"checkpoints/epoch{epoch}_step{step}.model.pth" - output_basedir = model_dir / f"production/epoch{epoch}_step{step}/" - output_basedir.mkdir(parents=True, exist_ok=True) - test_bodyseg_model( - input_basedir=input_basedir, - model_dir=model_dir, - 
model_checkpoint_path=model_checkpoint_path, - output_basedir=output_basedir, - batch_size=batch_size, - n_workers=n_workers, - inference_image_size=inference_image_size, - output_buffer_log_interval=output_buffer_log_interval, - ) diff --git a/src/poseforge/pose/bodyseg/scripts/run_inference_on_spotlight_data.py b/src/poseforge/pose/bodyseg/scripts/run_inference_on_spotlight_data.py new file mode 100644 index 0000000..fa593c5 --- /dev/null +++ b/src/poseforge/pose/bodyseg/scripts/run_inference_on_spotlight_data.py @@ -0,0 +1,151 @@ +import torch +import h5py +from loguru import logger +from pathlib import Path +from torchvision.transforms import Resize +from torchsummary import summary +from tqdm import tqdm +from pvio.torch_tools import SimpleVideoCollectionLoader + +import poseforge.pose.bodyseg.config as config +from poseforge.pose.bodyseg import BodySegmentationModel, BodySegmentationPipeline +from poseforge.util.sys import get_hardware_availability + + +def test_bodyseg_model( + *, + input_dir: Path, + model_dir: Path, + model_checkpoint_path: Path, + save_prob: bool = False, + output_basedir: Path | None = None, + batch_size: int = 512, + n_workers: int = 16, + inference_image_size: tuple[int, int] = (256, 256), +): + # System setup + hardware_avail = get_hardware_availability(check_gpu=True, print_results=True) + if len(hardware_avail["gpus"]) == 0: + raise RuntimeError("No GPU available for testing") + torch.backends.cudnn.benchmark = True + + # Create dataset and dataloader + input_img_dir = input_dir / "model_prediction/not_flipped/" + if len(list(input_img_dir.iterdir())) == 0: + logger.warning(f"Trial {input_img_dir} is empty - skipping") + return + transform = Resize(inference_image_size) + dataloader = SimpleVideoCollectionLoader( + [input_img_dir], + transform=transform, + batch_size=batch_size, + num_workers=n_workers, + ) + n_frames = len(dataloader.dataset) + logger.info(f"Found {n_frames} frames to process") + logger.info( + f"Using batch size 
{dataloader.batch_size} with {dataloader.num_workers} " + f"workers. This will generate {len(dataloader)} batches." + ) + + # Create model and learning pipeline + architecture_config_path = model_dir / "configs/model_architecture_config.yaml" + model_weights = config.ModelWeightsConfig(model_weights=model_checkpoint_path) + model = BodySegmentationModel.create_architecture_from_config( + architecture_config_path + ).cuda() + model.load_weights_from_config(model_weights) + summary(model, (3, *inference_image_size)) + pipeline = BodySegmentationPipeline(model, device="cuda", use_float16=True) + + # Set up output H5 files + exp_trial_name = input_img_dir.parts[-3] + out_dir = output_basedir / exp_trial_name + out_dir.mkdir(parents=True, exist_ok=True) + f_pred = h5py.File(out_dir / f"bodyseg_pred.h5", "w") + f_prob = h5py.File(out_dir / f"bodyseg_prob.h5", "w") if save_prob else None + ds_segmap = f_pred.create_dataset( + "pred_segmap", + shape=(n_frames, inference_image_size[0], inference_image_size[1]), + dtype="uint8", + compression="gzip", + shuffle=True, + ) + ds_segmap.attrs["class_labels"] = pipeline.class_labels + ds_conf = f_pred.create_dataset( + "pred_confidence", + shape=(n_frames, inference_image_size[0], inference_image_size[1]), + dtype="uint8", + compression="gzip", + shuffle=True, + ) + ds_conf.attrs["scale"] = 100 # transform from 0-1 to 0-100 for uint8 storage + ds_conf.attrs["method"] = model.confidence_method + if save_prob: + ds_probs = f_prob.create_dataset( + "pred_probabilities", + shape=( + n_frames, + model.n_classes, + inference_image_size[0], + inference_image_size[1], + ), + dtype="uint8", + compression="gzip", + shuffle=True, + ) + ds_probs.attrs["scale"] = 100 # transform from 0-1 to 0-100 for uint8 storage + ds_probs.attrs["class_labels"] = pipeline.class_labels + + # Inference loop + log_interval = max(len(dataloader) // 10, 1) + for batch_idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)): + # Forward pass (no need to 
move data to and from the GPU; pipeline will do that) + pred_dict = pipeline.inference(batch["frames"]) + logits = pred_dict["logits"] + + # Save outputs + frame_ids = batch["frame_indices"] + pred_seg = torch.argmax(logits, dim=1).to(torch.uint8).detach().cpu() + conf = (pred_dict["confidence"] * 100).to(torch.uint8).detach().cpu() + ds_segmap[frame_ids, :, :] = pred_seg.numpy() + ds_conf[frame_ids, :, :] = conf.numpy() + if save_prob: + prob = (torch.softmax(logits, dim=1) * 100).to(torch.uint8).detach().cpu() + ds_probs[frame_ids, :, :, :] = prob.numpy() + + if (batch_idx + 1) % log_interval == 0 or (batch_idx + 1) == len(dataloader): + logger.info(f"Processed batch {batch_idx + 1} / {len(dataloader)}") + + logger.info("Inference complete") + + +if __name__ == "__main__": + from poseforge.util.sys import set_loguru_level + + set_loguru_level("INFO") + + input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") + model_dir = Path("bulk_data/pose_estimation/bodyseg/trial_20251127a") + batch_size = 192 + n_workers = 16 + inference_image_size = (256, 256) + output_buffer_log_interval = 10 + epoch = 8 # chosen by validation performance and visual inspection + step = 18335 # last step of each epoch + + model_checkpoint_path = model_dir / f"checkpoints/epoch{epoch}_step{step}.model.pth" + output_basedir = model_dir / f"production/epoch{epoch}_step{step}/" + output_basedir.mkdir(parents=True, exist_ok=True) + for input_dir in sorted(input_basedir.iterdir()): + print(f"Processing {input_dir}") + test_bodyseg_model( + input_dir=input_dir, + model_dir=model_dir, + model_checkpoint_path=model_checkpoint_path, + output_basedir=output_basedir, + batch_size=batch_size, + n_workers=n_workers, + inference_image_size=inference_image_size, + save_prob=True, + ) diff --git a/src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py b/src/poseforge/pose/bodyseg/scripts/run_training.py similarity index 100% rename from 
src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py rename to src/poseforge/pose/bodyseg/scripts/run_training.py diff --git a/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py b/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py index 02b1cb7..8c14d7a 100644 --- a/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py +++ b/src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py @@ -223,13 +223,13 @@ def visualize_bodyseg_prediction( label_alpha = 0.3 n_workers = -1 - epoch = 14 # chosen by validation performance and visual inspection - step = 12000 # last step of each epoch + epoch = 8 # chosen by validation performance and visual inspection + step = 18335 # last step of each epoch pred_basedir = Path( - f"bulk_data/pose_estimation/bodyseg/trial_20251118a/production/epoch{epoch}_step{step}/" + f"bulk_data/pose_estimation/bodyseg/trial_20251127a/production/epoch{epoch}_step{step}/" ) - pred_path = pred_basedir / f"{trial}_model_prediction_not_flipped/bodyseg_pred.h5" - output_path = pred_basedir / f"{trial}_model_prediction_not_flipped/viz.mp4" + pred_path = pred_basedir / f"{trial}/bodyseg_pred.h5" + output_path = pred_basedir / f"{trial}/viz.mp4" visualize_bodyseg_prediction( recording_dir=recording_dir, diff --git a/src/poseforge/pose/common.py b/src/poseforge/pose/common.py index 5663f6b..d3a1a21 100644 --- a/src/poseforge/pose/common.py +++ b/src/poseforge/pose/common.py @@ -88,35 +88,35 @@ def _apply_imagenet_normalization( x_normalized = (x - mean) / std return x_normalized - def forward(self, x, return_intermediates: bool = False): - """ + def forward_with_intermediates( + self, x: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Run forward pass, returning intermediate feature maps as well as + the final features. + Args: x (torch.Tensor): Input image tensor of shape (batch_size, 3, height, width), with pixel values in [0, 1]. 
- return_intermediates (bool): Whether to return intermediate - feature maps from various layers. Default False. Returns: - If return_intermediates is False: - features (torch.Tensor): Extracted features. The shape is - (batch_size, out_channels, *output_feature_map_size) - where output_feature_map_size depends on the input - image size. - If return_intermediates is True: - A tuple of 5 torch.Tensors: - - Features after initial Conv-BN-ReLU but before maxpool: - tensor of shape (batch_size, 64, 128, 128) - - Features after layer1: - tensor of shape (batch_size, 64, 64, 64) - - Features after layer2: - tensor of shape (batch_size, 128, 32, 32) - - Features after layer3: - tensor of shape (batch_size, 256, 16, 16) - - Features after layer4: - tensor of shape (batch_size, 512, 8, 8) - This is the same as the single output returned if - return_intermediates is False. + A tuple of 5 torch.Tensors: + - Features after initial Conv-BN-ReLU but before maxpool: + tensor of shape (batch_size, 64, 128, 128) + - Features after layer1: + tensor of shape (batch_size, 64, 64, 64) + - Features after layer2: + tensor of shape (batch_size, 128, 32, 32) + - Features after layer3: + tensor of shape (batch_size, 256, 16, 16) + - Features after layer4: + tensor of shape (batch_size, 512, 8, 8) """ + if x.shape[-2:] != self.input_size: + raise NotImplementedError( + f"Input image size {x.shape[-2:]} not supported; " + f"expected {self.input_size}" + ) + x_norm = self._apply_imagenet_normalization(x) # Remove the final classification head (avgpool + fc) @@ -149,10 +149,28 @@ def forward(self, x, return_intermediates: bool = False): assert x4.shape == (batch_size, 512, 8, 8) self._first_time_forward = False - if return_intermediates: - return x0, x1, x2, x3, x4 - else: - return x4 + return x0, x1, x2, x3, x4 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run forward pass through the ResNet-18 backbone and return the + final extracted features. 
+ + See also `.forward_with_intermediates`, which returns intermediate + feature maps as well (useful for identity connections in U-Net-like + architectures). + + Args: + x (torch.Tensor): Input image tensor of shape (batch_size, 3, + height, width), with pixel values in [0, 1]. + + Returns: + features (torch.Tensor): Extracted features. The shape is + (batch_size, out_channels, *output_feature_map_size) + where output_feature_map_size depends on the input + image size. + """ + x0, x1, x2, x3, x4 = self.forward_with_intermediates(x) + return x4 class DecoderBlock(nn.Module): diff --git a/src/poseforge/pose/data/scripts/_attach_metadata.py b/src/poseforge/pose/data/scripts/_attach_metadata.py new file mode 100644 index 0000000..31db7c7 --- /dev/null +++ b/src/poseforge/pose/data/scripts/_attach_metadata.py @@ -0,0 +1,187 @@ +import h5py +from pathlib import Path +from tqdm import tqdm + +KEYS = { + "body_seg_maps": [ + "Background", + "OtherSegments", + "Thorax", + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus", + "LAntenna", + "RAntenna", + ], + "dof_angles": [ + "LFThC_pitch", + "LFThC_roll", + "LFThC_yaw", + "LFCTr_pitch", + "LFCTr_roll", + "LFFTi_pitch", + "LFTiTa_pitch", + "LMThC_pitch", + "LMThC_roll", + "LMThC_yaw", + "LMCTr_pitch", + "LMCTr_roll", + "LMFTi_pitch", + "LMTiTa_pitch", + "LHThC_pitch", + "LHThC_roll", + "LHThC_yaw", + "LHCTr_pitch", + "LHCTr_roll", + "LHFTi_pitch", + "LHTiTa_pitch", + "RFThC_pitch", + "RFThC_roll", + "RFThC_yaw", + "RFCTr_pitch", + "RFCTr_roll", + "RFFTi_pitch", + "RFTiTa_pitch", + "RMThC_pitch", + "RMThC_roll", + "RMThC_yaw", + "RMCTr_pitch", + "RMCTr_roll", + "RMFTi_pitch", + "RMTiTa_pitch", + "RHThC_pitch", + "RHThC_roll", + "RHThC_yaw", + "RHCTr_pitch", + 
"RHCTr_roll", + "RHFTi_pitch", + "RHTiTa_pitch", + ], + "keypoint_pos": [ + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus1", + "LFTarsus5", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus1", + "LMTarsus5", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus1", + "LHTarsus5", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus1", + "RFTarsus5", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus1", + "RMTarsus5", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus1", + "RHTarsus5", + "LPedicel", + "RPedicel", + ], + "mesh_pos": [ + "Thorax", + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus1", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus1", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus1", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus1", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus1", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus1", + ], + "mesh_quat": [ + "Thorax", + "LFCoxa", + "LFFemur", + "LFTibia", + "LFTarsus1", + "LHCoxa", + "LHFemur", + "LHTibia", + "LHTarsus1", + "LMCoxa", + "LMFemur", + "LMTibia", + "LMTarsus1", + "RHCoxa", + "RHFemur", + "RHTibia", + "RHTarsus1", + "RMCoxa", + "RMFemur", + "RMTibia", + "RMTarsus1", + "RFCoxa", + "RFFemur", + "RFTibia", + "RFTarsus1", + ], +} + + +def attach_metadata(path: Path): + with h5py.File(path, "a") as f: + for k, keys in KEYS.items(): + f[k].attrs["keys"] = keys + + +if __name__ == "__main__": + basedir = Path(".") + all_files = sorted(basedir.glob("**/*_labels.h5")) + print(f"Found {len(all_files)} files. 
Continue?") + resp = input("Type 'yes' to continue: ") + if resp.lower() != "yes": + print("Aborting.") + exit(0) + for file_path in tqdm(all_files): + attach_metadata(file_path) diff --git a/src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py b/src/poseforge/pose/data/scripts/preextract_atomic_batches.py similarity index 95% rename from src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py rename to src/poseforge/pose/data/scripts/preextract_atomic_batches.py index 05979fe..91f8c88 100644 --- a/src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py +++ b/src/poseforge/pose/data/scripts/preextract_atomic_batches.py @@ -172,11 +172,11 @@ def process_batch(batch_idx: int): AtomicBatchDataset.save_atomic_batch_frames( atomic_batch_frames, output_dir / f"{filename_stem}_frames.mp4", - fps=sampler.fps, - spacing=video_spacing, + sampler.fps, + video_spacing, ) AtomicBatchDataset.save_atomic_batch_sim_data( - labels, output_dir / f"{filename_stem}_labels.h5" + output_dir / f"{filename_stem}_labels.h5", labels, sampler.label_keys ) if batch_idx % logging_interval == 0: elapsed = time() - start_time @@ -222,13 +222,14 @@ def process_batch(batch_idx: int): # --output-dir bulk_data/pose_estimation/atomic_batches/BO_Gal4_fly1_trial001 # # Example usage not via CLI + # from poseforge.util import bulk_data_dir # extract_atomic_batches( # atomic_batch_nframes=32, # atomic_batch_nvariants_max=4, # minimum_time_diff_frames=60, - # input_basedir="bulk_data/style_transfer/production/translated_videos", - # nmf_sim_rendering_basedir="bulk_data/nmf_rendering/", - # output_dir="bulk_data/pose_estimation/atomic_batches_test", + # input_basedir=bulk_data_dir / "style_transfer/production/translated_videos", + # nmf_sim_rendering_basedir=bulk_data_dir / "nmf_rendering/", + # output_dir=bulk_data_dir / "pose_estimation/atomic_batches_test", # original_image_size=(464, 464), # n_jobs=-1, # logging_interval=100, diff --git 
a/src/poseforge/pose/data/synthetic/atomic_batch.py b/src/poseforge/pose/data/synthetic/atomic_batch.py index d447a3f..b2bbba5 100644 --- a/src/poseforge/pose/data/synthetic/atomic_batch.py +++ b/src/poseforge/pose/data/synthetic/atomic_batch.py @@ -2,6 +2,7 @@ import numpy as np import h5py import logging +from typing import Callable from torch.utils.data import Dataset, DataLoader from pathlib import Path from pvio.io import read_frames_from_video, write_frames_to_video @@ -20,7 +21,11 @@ def __init__( load_dof_angles: bool = False, load_keypoint_positions: bool = False, load_body_segment_maps: bool = False, + load_mesh_pose6d: bool = False, + transform: Callable | None = None, ): + self.transform = transform + # Find all .h5 and .mp4 files in the provided directories all_h5_files = set() all_mp4_files = set() @@ -70,6 +75,9 @@ def __init__( self.label_keys.append("keypoint_pos") if load_body_segment_maps: self.label_keys.append("body_seg_maps") + if load_mesh_pose6d: + self.label_keys.append("mesh_pos") + self.label_keys.append("mesh_quat") def __len__(self): return len(self.atomic_batches) @@ -91,6 +99,10 @@ def __getitem__(self, idx): # Load labels data sim_data = self.load_atomic_batch_sim_data(h5_path, self.label_keys) + # Apply transform if provided + if self.transform is not None: + frames, sim_data = self.transform(frames, sim_data) + return frames, sim_data @staticmethod @@ -196,9 +208,9 @@ def load_atomic_batch_frames( @staticmethod def save_atomic_batch_sim_data( - sim_data: dict[str, np.ndarray], output_path: Path, - metadata: dict | None = None, + sim_data: dict[str, np.ndarray], + label_keys: dict[str, list[str]], ): """Save simulation data for an atomic batch to an HDF5 file. 
@@ -214,11 +226,9 @@ def save_atomic_batch_sim_data( # these files will be accessed very frequently during training, so we # use lzf (faster than gzip) and no shuffling to optimize for speed compression = "lzf" if key == "body_seg_maps" else None - f.create_dataset(key, data=value, compression=compression) + ds = f.create_dataset(key, data=value, compression=compression) + ds.attrs["keys"] = label_keys[key] f.attrs["n_frames"] = next(iter(sim_data.values())).shape[0] - if metadata is not None: - for key, value in metadata.items(): - f.attrs[key] = value @staticmethod def load_atomic_batch_sim_data( @@ -358,12 +368,15 @@ def init_atomic_dataset_and_dataloader( load_dof_angles: bool = False, load_keypoint_positions: bool = False, load_body_segment_maps: bool = False, + load_mesh_pose6d: bool = False, shuffle: bool = False, n_workers: int | None = None, n_channels: int = 3, pin_memory: bool = True, drop_last: bool = True, prefetch_factor: int | None = None, + transform: Callable | None = None, + return_index: bool = False, ): """ Initializes an AtomicBatchDataset and a corresponding DataLoader for @@ -385,6 +398,8 @@ def init_atomic_dataset_and_dataloader( positions. Defaults to False. load_body_segment_maps (bool, optional): Whether to load body segment maps. Defaults to False. + load_mesh_pose6d (bool, optional): Whether to load mesh 6D pose. + Defaults to False. shuffle (bool, optional): Whether to shuffle the data. Defaults to False. n_workers (int | None, optional): Number of worker threads for @@ -396,6 +411,8 @@ def init_atomic_dataset_and_dataloader( batch. Defaults to True. prefetch_factor (int | None, optional): Number of samples to load in advance by each worker. If None, uses PyTorch default. + transform (Callable | None, optional): Optional transform to apply + to each atomic batch when loading. Defaults to None. 
Returns: dataset (AtomicBatchDataset): @@ -412,6 +429,8 @@ def init_atomic_dataset_and_dataloader( load_dof_angles=load_dof_angles, load_keypoint_positions=load_keypoint_positions, load_body_segment_maps=load_body_segment_maps, + load_mesh_pose6d=load_mesh_pose6d, + transform=transform, ) # Check if batch size is valid diff --git a/src/poseforge/pose/data/synthetic/sampler.py b/src/poseforge/pose/data/synthetic/sampler.py index d0e3daf..3beb873 100644 --- a/src/poseforge/pose/data/synthetic/sampler.py +++ b/src/poseforge/pose/data/synthetic/sampler.py @@ -105,6 +105,19 @@ def __init__( assert self.sampling_stride >= 1, "`sampling_stride` must be >= 1" assert len(simulated_data_sequences) > 0, "At least 1 simulation required" + # Check if all simulations have the same metadata + metadata_li = [seq.get_sim_data_metadata() for seq in simulated_data_sequences] + md_ref = metadata_li[0] + for metadata in metadata_li[1:]: + assert metadata == md_ref, "All simulations must have the same metadata" + self.label_keys = { + "dof_angles": md_ref["dof_angles"]["keys"], + "keypoint_pos": md_ref["keypoint_pos"]["keys"], + "mesh_pos": md_ref["mesh_pose6d"]["keys"], + "mesh_quat": md_ref["mesh_pose6d"]["keys"], + "body_seg_maps": md_ref["segmentation_labels"]["keys"], + } + # Check number of variants per frame assert ( len(set([seq.n_variants for seq in simulated_data_sequences])) == 1 @@ -113,14 +126,14 @@ def __init__( # Check image size and FPS _image_sizes = set() - _fpss = set() + _fps = set() for seq in simulated_data_sequences: _image_sizes.add(seq.frame_size) - _fpss.add(seq.fps) + _fps.add(seq.fps) assert len(_image_sizes) == 1, "All simulations must have the same image size" - assert len(_fpss) == 1, "All simulations must have the same FPS" + assert len(_fps) == 1, "All simulations must have the same FPS" self.frame_size = _image_sizes.pop() - self.fps = _fpss.pop() + self.fps = _fps.pop() # Check number of frames per simulation and in total self.n_frames_per_sim = 
np.array( diff --git a/src/poseforge/pose/data/synthetic/sim_data_seq.py b/src/poseforge/pose/data/synthetic/sim_data_seq.py index 837cc90..9c33d02 100644 --- a/src/poseforge/pose/data/synthetic/sim_data_seq.py +++ b/src/poseforge/pose/data/synthetic/sim_data_seq.py @@ -6,6 +6,8 @@ from typing import Iterator from pvio.io import get_video_metadata, read_frames_from_video +from poseforge.neuromechfly.constants import segments_for_6dpose + class SimulatedDataSequence: def __init__( @@ -54,13 +56,14 @@ def __init__( def get_sim_data_metadata(self) -> dict: metadata = {} - with h5py.File(self.simulated_labels_path, "r") as ds: - postprocessed_ds = ds["postprocessed"] - metadata["dof_angles"] = dict(postprocessed_ds["dof_angles"].attrs) - metadata["keypoint_pos"] = dict(postprocessed_ds["keypoint_pos"].attrs) - metadata["segmentation_labels"] = dict( - postprocessed_ds["segmentation_labels"].attrs - ) + with h5py.File(self.simulated_labels_path, "r") as f: + ds = f["postprocessed"] + metadata["dof_angles"] = dict(ds["dof_angles"].attrs) + metadata["keypoint_pos"] = dict(ds["keypoint_pos"].attrs) + metadata["segmentation_labels"] = dict(ds["segmentation_labels"].attrs) + metadata["mesh_pose6d"] = dict(ds["mesh_pose6d_rel_camera"].attrs) + for inner_dict in metadata.values(): + inner_dict["keys"] = list(inner_dict["keys"]) return metadata def __len__(self) -> int: @@ -119,14 +122,14 @@ def read_simulated_labels( *, load_dof_angles: bool = True, load_keypoint_pos: bool = True, - load_mesh_states: bool = False, + load_mesh_states: bool = True, load_body_seg_maps: bool = True, ) -> dict[str, np.ndarray]: self._check_frame_indices_validity(frame_indices) labels = {} - with h5py.File(self.simulated_labels_path, "r") as ds: - ds = ds["postprocessed"] + with h5py.File(self.simulated_labels_path, "r") as f: + ds = f["postprocessed"] if load_dof_angles: labels["dof_angles"] = ds["dof_angles"][frame_indices, :] @@ -143,10 +146,16 @@ def read_simulated_labels( labels["keypoint_pos"] 
= keypoint_pos if load_mesh_states: - raise NotImplementedError( - "Mesh states tracking (xyz + quat for 3D rotation) has not been " - "implemented yet" + pose6d_grp = ds["mesh_pose6d_rel_camera"] + all_avail_segments = pose6d_grp.attrs["keys"] + seg_mask = np.array( + [name in segments_for_6dpose for name in all_avail_segments] ) + # h5py only supports fancy indexing along one axis + mesh_pos = pose6d_grp["pos_rel_cam"][frame_indices, :, :] + labels["mesh_pos"] = mesh_pos[:, seg_mask, :] + mesh_quat = pose6d_grp["quat_rel_cam"][frame_indices, :, :] + labels["mesh_quat"] = mesh_quat[:, seg_mask, :] if load_body_seg_maps: seg_labels_ds = ds["segmentation_labels"] diff --git a/src/poseforge/pose/keypoints3d/model.py b/src/poseforge/pose/keypoints3d/model.py index 4a414ff..bddfa83 100644 --- a/src/poseforge/pose/keypoints3d/model.py +++ b/src/poseforge/pose/keypoints3d/model.py @@ -416,9 +416,7 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: └──(bottleneck/identity)──┘ """ # Run feature extractor - e0, e1, e2, e3, e4 = self.feature_extractor.forward( - x, return_intermediates=True - ) + e0, e1, e2, e3, e4 = self.feature_extractor.forward_with_intermediates(x) d4 = e4 # this is just the bottleneck diff --git a/src/poseforge/pose/keypoints3d/pipeline.py b/src/poseforge/pose/keypoints3d/pipeline.py index 12ef8b9..cc16d14 100644 --- a/src/poseforge/pose/keypoints3d/pipeline.py +++ b/src/poseforge/pose/keypoints3d/pipeline.py @@ -76,7 +76,7 @@ def train( # Set up optimizer optimizer = self._create_optimizer(optimizer_config) - # Set up mixed-point training + # Set up mixed-precision training grad_scaler = torch.amp.GradScaler(self.device_type, enabled=self.use_float16) self._check_amp_status_for_model_params( grad_scaler, subtitle="Model parameters before training" @@ -232,6 +232,7 @@ def validate( if max_batches <= 0: raise ValueError("max_batches must be positive or None") total_loss_dict = defaultdict(lambda: 0.0) + n_steps_iterated = 0 if 
self.loss_func is None: raise ValueError("Loss function must be provided for validation") @@ -263,10 +264,10 @@ def validate( # Accumulate losses for key, loss in loss_dict.items(): total_loss_dict[key] += loss.item() + n_steps_iterated += 1 clear_memory_cache() self.model.train() - n_steps_iterated = step_idx + 1 return {k: v / n_steps_iterated for k, v in total_loss_dict.items()} def inference(self, frames: torch.Tensor) -> dict[str, torch.Tensor]: diff --git a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py b/src/poseforge/pose/keypoints3d/scripts/run_inference_on_spotlight_data.py similarity index 99% rename from src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py rename to src/poseforge/pose/keypoints3d/scripts/run_inference_on_spotlight_data.py index 84d68a4..42a260c 100644 --- a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py +++ b/src/poseforge/pose/keypoints3d/scripts/run_inference_on_spotlight_data.py @@ -200,8 +200,8 @@ def run_keypoints3d_inference( if __name__ == "__main__": input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") - model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251118a") - epoch = 19 # chosen based on validation performance and visual inspection + model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251127a") + epoch = 15 # chosen based on validation performance and visual inspection step = 9167 # last step print(f"Running inference for epoch {epoch}") diff --git a/src/poseforge/pose/keypoints3d/scripts/test_keypoints3d_models.py b/src/poseforge/pose/keypoints3d/scripts/run_inference_on_synthetic_data.py similarity index 100% rename from src/poseforge/pose/keypoints3d/scripts/test_keypoints3d_models.py rename to src/poseforge/pose/keypoints3d/scripts/run_inference_on_synthetic_data.py diff --git a/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py b/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py 
index 418557a..b046680 100644 --- a/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py +++ b/src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py @@ -613,6 +613,7 @@ def process_all( max_n_frames=max_n_frames, n_workers=n_workers_per_dataset, debug_plots_dir=keypoints3d_output_file.parent / "ik_debug_plots/", + parallel_over_time=False, ) _save_seqikpy_output( output_path, joint_angles, forward_kinematics, frame_ids=frame_ids @@ -828,10 +829,10 @@ def _save_seqikpy_output( # --input-images-basedir bulk_data/behavior_images/spotlight_aligned_and_cropped/ # * Processing from this script directly - epoch = 19 # these must be consistent with run_keypoints3d_inference.py + epoch = 15 # these must be consistent with run_inference_on_spotlight_data.py step = 9167 # same as above production_model_basedir = Path( - f"bulk_data/pose_estimation/keypoints3d/trial_20251118a/production/epoch{epoch}_step{step}/" + f"bulk_data/pose_estimation/keypoints3d/trial_20251127a/production/epoch{epoch}_step{step}/" ) input_images_basedir = Path( "bulk_data/behavior_images/spotlight_aligned_and_cropped/" diff --git a/src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py b/src/poseforge/pose/keypoints3d/scripts/run_training.py similarity index 100% rename from src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py rename to src/poseforge/pose/keypoints3d/scripts/run_training.py diff --git a/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py b/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py index 9fdda06..c577609 100644 --- a/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py +++ b/src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py @@ -374,7 +374,8 @@ def visualize_predictions( with h5py.File(inference_output_path, "r") as f: frame_ids = f["frame_ids"][:] # 3D world coordinates (n_frames, n_kp, 3) - # Upstream inference currently writes this as 
'keypoints_world_xyz' in run_keypoints3d_inference + # Upstream inference currently writes this as 'keypoints_world_xyz' in + # run_inference_on_spotlight_data.py if "keypoints_pos" in f: keypoints_pos = f["keypoints_pos"][:] keypoints_order = list(f["keypoints_pos"].attrs["keypoints"]) @@ -507,9 +508,9 @@ def visualize_predictions( if __name__ == "__main__": input_basedir = Path("bulk_data/behavior_images/spotlight_aligned_and_cropped/") - model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251118a") + model_dir = Path("bulk_data/pose_estimation/keypoints3d/trial_20251127a") recordings = ["20250613-fly1b-002"] - epoch = 19 # these must be consistent with run_keypoints3d_inference.py + epoch = 15 # these must be consistent with run_inference_on_spotlight_data.py step = 9167 # same as above for recording in recordings: diff --git a/src/poseforge/pose/pose6d/__init__.py b/src/poseforge/pose/pose6d/__init__.py new file mode 100644 index 0000000..8b734d9 --- /dev/null +++ b/src/poseforge/pose/pose6d/__init__.py @@ -0,0 +1,12 @@ +from .model import Pose6DModel, Pose6DLoss +from .pipeline import Pose6DPipeline, ComputeBodysegProbs +from . import config + + +__all__ = [ + "Pose6DModel", + "Pose6DLoss", + "Pose6DPipeline", + "ComputeBodysegProbs", + "config", +] diff --git a/src/poseforge/pose/pose6d/config.py b/src/poseforge/pose/pose6d/config.py new file mode 100644 index 0000000..dfacd2a --- /dev/null +++ b/src/poseforge/pose/pose6d/config.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass + +from poseforge.util import SerializableDataClass + + +@dataclass(frozen=True) +class ModelArchitectureConfig(SerializableDataClass): + # Number of object segments. 
+ # Default: 6 legs * (coxa, femur, tibia, tarsus1) + thorax = 25 + n_segments: int = 25 + # Number of feature channels that are gated by a per-segment attention mechanism + n_attention_gated_feature_channels: int = 128 + # Number of feature channels that are not gated by attention + n_global_feature_channels: int = 128 + # Camera distance + camera_distance: float = 100.0 + + +@dataclass(frozen=True) +class ModelWeightsConfig(SerializableDataClass): + # Feature extractor weights. Can be a path to the (contrastively) pretrained weights + # or "IMAGENET1K_V1" + feature_extractor_weights: str | None = None + # Model weights, optional. If provided, the model will be initialized from these + # weights (in which case feature_extractor_weights is ignored). + model_weights: str | None = None + + +@dataclass(frozen=True) +class LossConfig(SerializableDataClass): + # Weight for the translation loss term + translation_weight: float = 1.0 + # Weight for the rotation loss term + rotation_weight: float = 1.0 + + +@dataclass(frozen=True) +class TrainingDataConfig(SerializableDataClass): + # Paths to training data (recursively containing atomic batches) + train_data_dirs: list[str] + # Paths to validation data (recursively containing atomic batches) + val_data_dirs: list[str] + # Frame size (height, width) + input_image_size: tuple[int, int] + # Numbers of samples (frames) in each pre-extracted atomic batch + atomic_batch_n_samples: int + # Number of variants (synthetic images made by different style transfer models) + atomic_batch_n_variants: int + # Number of different frames to include in each batch. Note that n_variants variants + # of each frame will be included, so effective batch size = + # train_batch_size * n_variants. + # This must be a multiple of `atomic_batch_n_samples` in `AtomicBatchDataset`. + train_batch_size: int + # Validation batch size. Can be much smaller than train_batch_size. 
Must be + # a multiple of `atomic_batch_n_samples` in `AtomicBatchDataset` + val_batch_size: int + # Number of workers for data loading. Use number of CPU cores if None. + n_workers: int | None = None + + +@dataclass(frozen=True) +class OptimizerConfig(SerializableDataClass): + learning_rate_encoder: float = 3e-5 + learning_rate_upsample: float = 3e-4 + learning_rate_attention: float = 3e-4 + learning_rate_pose6d_heads: float = 3e-4 + weight_decay: float = 1e-5 + + +@dataclass(frozen=True) +class TrainingArtifactsConfig(SerializableDataClass): + # Base directory to save logs and model checkpoints + output_basedir: str + # Log training metrics every N steps + logging_interval: int = 10 + # Save model checkpoint every N steps (NOT EPOCHS!) + checkpoint_interval: int = 500 + # Run validation every N steps (NOT EPOCHS!) + validation_interval: int = 500 + # Number of batches to use for each validation (useful if validation set is large) + n_batches_per_validation: int = 300 diff --git a/src/poseforge/pose/pose6d/model.py b/src/poseforge/pose/pose6d/model.py new file mode 100644 index 0000000..42f3ddc --- /dev/null +++ b/src/poseforge/pose/pose6d/model.py @@ -0,0 +1,273 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from loguru import logger +from pathlib import Path + +import poseforge.pose.pose6d.config as config +from poseforge.pose.common import ResNetFeatureExtractor, DecoderBlock + + +class Pose6DModel(nn.Module): + def __init__( + self, + n_segments: int, + feature_extractor: ResNetFeatureExtractor, + n_attention_gated_feature_channels: int, + n_global_feature_channels: int, + camera_distance: float, + ): + super(Pose6DModel, self).__init__() + self.n_segments = n_segments + self.feature_extractor = feature_extractor + self.n_attention_gated_feature_channels = n_attention_gated_feature_channels + self.n_global_feature_channels = n_global_feature_channels + self.n_feature_channels_total = ( + n_attention_gated_feature_channels + 
n_global_feature_channels + ) + self.camera_distance = camera_distance + # MuJoCo convention: in front of camera = negative z, so floor is at + # (*, *, -camera_distance) + self.z_center = -camera_distance + + # Create decoder (decoder1/2/3/4 mirror encoder layers 1/2/3/4) + # Note that when upsampling, we actually run decoder4 first, decoder1 last + # Also note that the number of channels along the decoding path is higher than + # in keypoints3d and bodyseg models. This is because here we will eventually pool + # the features globally for FC pose6d heads, so we want more channels to retain + # information. + self.dec_layer4 = DecoderBlock(512, 256, 512) + self.dec_layer3 = DecoderBlock(512, 128, 256) + self.dec_layer2 = DecoderBlock(256, 64, self.n_feature_channels_total) + + # Attention layer + if n_attention_gated_feature_channels > 0: + _attention_heads = [ + nn.Sequential( + nn.Conv2d( + in_channels=self.n_feature_channels_total + n_segments, + out_channels=128, + kernel_size=3, + padding=1, + ), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1), + nn.Sigmoid(), + ) + for _ in range(n_segments) + ] + self.attention_heads = nn.ModuleList(_attention_heads) + else: + self.attention_heads = None + + # 6D pose prediction head (one per segment) + _pose6d_heads = [ + nn.Sequential( + nn.Linear(in_features=self.n_feature_channels_total, out_features=512), + nn.ReLU(inplace=True), + nn.Dropout(p=0.3), + nn.Linear(in_features=512, out_features=256), + nn.ReLU(inplace=True), + nn.Dropout(p=0.3), + nn.Linear(in_features=256, out_features=7), + ) + for _ in range(n_segments) + ] + self.pose6d_heads = nn.ModuleList(_pose6d_heads) + + @classmethod + def create_from_config( + cls, arch_config: config.ModelArchitectureConfig | Path | str + ) -> "Pose6DModel": + # Load from file if config is a path + if isinstance(arch_config, (str, Path)): + arch_config = config.ModelArchitectureConfig.load(arch_config) + logger.info(f"Loaded model architecture 
config from {arch_config}") + + # Initialize feature extractor (WITHOUT WEIGHTS at this step!) + feature_extractor = ResNetFeatureExtractor() + + # Initialize Pose6DModel (self) from config (WITHOUT WEIGHTS at this step!) + obj = cls( + n_segments=arch_config.n_segments, + feature_extractor=feature_extractor, + n_attention_gated_feature_channels=arch_config.n_attention_gated_feature_channels, + n_global_feature_channels=arch_config.n_global_feature_channels, + camera_distance=arch_config.camera_distance, + ) + + logger.info("Initialized Pose6DModel from architecture config") + return obj + + def load_weights_from_config( + self, weights_config: config.ModelWeightsConfig | Path | str + ): + # Load from file if config is given as a path + if isinstance(weights_config, (str, Path)): + weights_config = config.ModelWeightsConfig.load(weights_config) + logger.info(f"Loaded model weights config from {weights_config}") + + # Check if config has either feature extractor weights or model weights + if ( + weights_config.feature_extractor_weights is None + and weights_config.model_weights is None + ): + logger.warning("weights_config contains nothing useful. No action taken") + return + + # If full model weights are provided, load them directly + if weights_config.model_weights is not None: + checkpoint_path = Path(weights_config.model_weights) + if not checkpoint_path.exists(): + logger.critical(f"Model weights path {checkpoint_path} does not exist") + raise FileNotFoundError(f"Model weights file {checkpoint_path} does not exist") + weights = torch.load(checkpoint_path, map_location="cpu") + self.load_state_dict(weights) + logger.info( + f"Loaded Pose6DModel weights (inc. 
feature extractor) " + f"from {checkpoint_path}" + ) + return + + # Otherwise, init feature extractor weights if provided + self.feature_extractor = ResNetFeatureExtractor( + # Path, str, or "IMAGENET1K_V1" + weights=weights_config.feature_extractor_weights + ) + logger.info( + f"Initialized feature extractor weights from " + f"{weights_config.feature_extractor_weights}" + ) + + def forward(self, image: torch.Tensor, bodyseg_prob: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + bat_size = image.shape[0] + assert image.shape == (bat_size, 3, 256, 256) + + # Run feature extractor and upsample to 64x64 + e0, e1, e2, e3, e4 = self.feature_extractor.forward_with_intermediates(image) + d4 = e4 # this is just the bottleneck + d3 = self.dec_layer4(d4, e3) + d2 = self.dec_layer3(d3, e2) + features = self.dec_layer2(d2, e1) # (B, n_feature_channels_total, 64, 64) + + # Attention using bodyseg prediction happens at 64x64 resolution (stride 4) + assert features.shape == (bat_size, self.n_feature_channels_total, 64, 64) + assert bodyseg_prob.shape == (bat_size, self.n_segments, 64, 64) + # (B, n_channels_feature_map + n_segments, H, W) + feature_map_with_seg_prob = torch.cat([features, bodyseg_prob], dim=1) + + # Process each segment separately + translation_pred_list = [] + quaternion_pred_list = [] + for seg_idx in range(self.n_segments): + # Global features + global_features_pooled = torch.mean( + features[:, : self.n_global_feature_channels, :, :], dim=[2, 3] + ) + + # Attention gated features + if self.n_attention_gated_feature_channels > 0: + attn_head = self.attention_heads[seg_idx] + attn_map = attn_head(feature_map_with_seg_prob) # (B, 1, H, W) + gated_features = ( + features[:, -self.n_attention_gated_feature_channels :, :, :] + * attn_map + ) + attn_weight_total = torch.sum(attn_map, dim=[2, 3]) + 1e-6 # (B, 1) + gated_features_pooled = ( + gated_features.sum(dim=[2, 3]) / attn_weight_total + ) # (B, C) + all_features_pooled = torch.cat( + [global_features_pooled, gated_features_pooled], 
dim=1 + ) + else: + all_features_pooled = global_features_pooled + + # FC layers for 6D pose prediction + head = self.pose6d_heads[seg_idx] + pose6d_pred = head(all_features_pooled) # (B, 7) + translation_pred = pose6d_pred[:, 0:3] # (B, 3) + quaternion_pred = pose6d_pred[:, 3:7] # (B, 4) + quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) # normalize quat + translation_pred_list.append(translation_pred) + quaternion_pred_list.append(quaternion_pred) + + # Stack predictions into single tensors of shape (B, n_segments, 3 or 4) + translation_pred_all = torch.stack(translation_pred_list, dim=1) + translation_pred_all[:, :, 2] += self.z_center + quaternion_pred_all = torch.stack(quaternion_pred_list, dim=1) + + return translation_pred_all, quaternion_pred_all + + +class Pose6DLoss(nn.Module): + def __init__(self, translation_weight: float = 1.0, rotation_weight: float = 1.0): + super(Pose6DLoss, self).__init__() + self.translation_weight = translation_weight + self.rotation_weight = rotation_weight + + def forward( + self, + translation_pred: torch.Tensor, + quaternion_pred: torch.Tensor, + translation_label: torch.Tensor, + quaternion_label: torch.Tensor, + ) -> dict[str, torch.Tensor]: + # Compute losses + translation_loss = self.translation_loss( + translation_pred.view(-1, 3), translation_label.view(-1, 3) + ) + rotation_loss = self.rotation_loss( + quaternion_pred.view(-1, 4), quaternion_label.view(-1, 4) + ) + total_loss = ( + self.translation_weight * translation_loss + + self.rotation_weight * rotation_loss + ) + + return { + "translation_loss_unweighted": translation_loss, + "rotation_loss_unweighted": rotation_loss, + "total_loss": total_loss, + } + + @classmethod + def create_from_config( + cls, loss_config: config.LossConfig | Path | str + ) -> "Pose6DLoss": + # Load from file if config is a path + if isinstance(loss_config, (str, Path)): + loss_config = config.LossConfig.load(loss_config) + logger.info(f"Loaded loss config from {loss_config}") + + obj = cls( 
+ translation_weight=loss_config.translation_weight, + rotation_weight=loss_config.rotation_weight, + ) + logger.info("Initialized Pose6DLoss from loss config") + return obj + + @staticmethod + def translation_loss( + translation_pred: torch.Tensor, + translation_label: torch.Tensor, + ) -> torch.Tensor: + """Computes L1 translation loss. Both predictions and labels should have + shape (N, 3), where N is the number of valid samples.""" + return F.l1_loss(translation_pred, translation_label) + + @staticmethod + def rotation_loss( + quaternion_pred: torch.Tensor, + quaternion_label: torch.Tensor, + ) -> torch.Tensor: + """Computes rotation loss based on quaternion dot product. Both predictions and + labels should have shape (N, 4), where N is the number of valid samples.""" + # Normalize both quaternions to unit length (note: dim 1 is quat components) + quaternion_pred = F.normalize(quaternion_pred, p=2, dim=1) + quaternion_label = F.normalize(quaternion_label, p=2, dim=1) + + # Compute absolute quaternion dot product (note: dim 1 is quat components) + dot_prod = torch.abs(torch.sum(quaternion_pred * quaternion_label, dim=1)) + + return torch.mean(1.0 - dot_prod) diff --git a/src/poseforge/pose/pose6d/pipeline.py b/src/poseforge/pose/pose6d/pipeline.py new file mode 100644 index 0000000..9fa10e4 --- /dev/null +++ b/src/poseforge/pose/pose6d/pipeline.py @@ -0,0 +1,528 @@ +import torch +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +from collections import defaultdict +from time import time +from datetime import datetime +from pathlib import Path +from itertools import chain +from tqdm import tqdm +from loguru import logger + +import poseforge.util as util +import poseforge.pose.data.synthetic as synth_data +import poseforge.pose.pose6d.config as config +from poseforge.pose.pose6d.model import Pose6DModel, Pose6DLoss +from poseforge.pose.bodyseg import BodySegmentationPipeline +from poseforge.neuromechfly.constants 
import segments_for_6dpose + + +class ComputeBodysegProbs(torch.nn.Module): + def __init__(self, scale_factor: int): + super(ComputeBodysegProbs, self).__init__() + self.scale_factor = scale_factor + self.pool = torch.nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor) + + self.pose6d_idx_to_bodyseg_idx = [] + for segment in segments_for_6dpose: + if segment.endswith("Tarsus1"): + segment = segment.replace("Tarsus1", "Tarsus") + bodyseg_class_labels = list(BodySegmentationPipeline.class_labels) + if segment not in bodyseg_class_labels: + logger.critical( + f"Segment {segment} in Pose6D model has no matching segment in " + "body segmentation maps." + ) + raise ValueError("Invalid segment name.") + self.pose6d_idx_to_bodyseg_idx.append(bodyseg_class_labels.index(segment)) + + def forward( + self, atomic_batch_frames: torch.Tensor, sim_data: dict[str, torch.Tensor] + ) -> torch.Tensor: + segmaps = sim_data["body_seg_maps"] # (B, H, W), value = seg idx in sim data + bat_size, nrows, ncols = segmaps.size() + + # Convert to one-hot (B, len(pose6d_segments), H, W) + segmaps_onehot = torch.zeros( + (bat_size, len(segments_for_6dpose), nrows, ncols), dtype=torch.float32 + ) + for i, bodyseg_idx in enumerate(self.pose6d_idx_to_bodyseg_idx): + segmaps_onehot[:, i, :, :] = (segmaps == bodyseg_idx).float() + + # Downsample using max pooling + segmaps_onehot_downsampled = self.pool(segmaps_onehot) + sim_data["segmap_probs_label"] = segmaps_onehot_downsampled.to(segmaps.device) + sim_data.pop("body_seg_maps") # remove original segmap + + return atomic_batch_frames, sim_data + + +class Pose6DPipeline: + def __init__( + self, + model: Pose6DModel, + loss_func: Pose6DLoss | None = None, + device: torch.device | str = "cuda", + use_float16: bool = True, + ): + self.model = model.to(device) + self.device = device + self.loss_func = loss_func + self.use_float16 = use_float16 + if torch.cuda.is_available() and "cuda" in str(device): + self.device_type = "cuda" + else: + 
self.device_type = "cpu" + + # Body segmentation maps are at full working resolution (256x256). The Pose6D + # model performs soft global pooling weighted by segmentation map probabilities + # at 64x64 resolution, so we need to downsample the segmentation maps. Use max + # pooling because it's better to slightly overestimate the presence of segments. + self._compute_bodyseg_probs_transform = ComputeBodysegProbs(scale_factor=4) + + def train( + self, + n_epochs: int, + data_config: config.TrainingDataConfig, + optimizer_config: config.OptimizerConfig, + artifacts_config: config.TrainingArtifactsConfig, + seed: int = 42, + half_batch_size_for_debugging: bool = False, + ): + # Set random seed for reproducibility + util.set_random_seed(seed) + + # If half_batch_size_for_debugging, cut batch sizes in half to save memory + self.half_batch_size_for_debugging = half_batch_size_for_debugging + if self.half_batch_size_for_debugging: + logger.warning( + "Debug mode: using half batch sizes for training and validation in " + "order to fit the model in memory for a GeForce RTX 3080 Ti." 
+ ) + + # Set up training and validation data + train_ds, train_loader = self._init_training_dataset_and_dataloader(data_config) + val_ds, val_loader = self._init_validation_dataset_and_dataloader(data_config) + n_batches_per_epoch = len(train_loader) + + # Set up logging dir and logger + log_dir = Path(artifacts_config.output_basedir) / "logs" + log_dir.mkdir(parents=True, exist_ok=True) + writer = SummaryWriter(log_dir=log_dir) + + # Set up checkpoint dir + checkpoint_dir = Path(artifacts_config.output_basedir) / "checkpoints" + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Set up optimizer + optimizer = self._create_optimizer(optimizer_config) + + # Set up mixed-precision training + grad_scaler = torch.amp.GradScaler(self.device_type, enabled=self.use_float16) + self._check_amp_status_for_model_params( + grad_scaler, subtitle="Model parameters before training" + ) + + # Check if loss function is provided + if self.loss_func is None: + logger.critical("Loss function must be provided for training.") + raise ValueError("Loss function must be provided for training.") + + # Training loop + self.model.train() + for epoch_idx in range(n_epochs): + logger.info( + f"Starting epoch {epoch_idx} out of {n_epochs} at {datetime.now()}..." 
+ ) + running_loss_dict = defaultdict(lambda: 0.0) + epoch_start_time = time() + running_start_time = time() + + for step_idx, atomic_batches in enumerate(train_loader): + # Format data + atomic_batches_frames, atomic_batches_sim_data = atomic_batches + frames, sim_data = synth_data.atomic_batches_to_simple_batch( + atomic_batches_frames, atomic_batches_sim_data, device=self.device + ) + if self.half_batch_size_for_debugging: + frames, sim_data = self._get_half_batch(frames, sim_data) + bodyseg_probs_label = sim_data["segmap_probs_label"] + + # Forward pass with mixed precision + with torch.amp.autocast(self.device_type, enabled=self.use_float16): + pred_pos, pred_quat = self.model(frames, bodyseg_probs_label) + loss_dict = self.loss_func( + pred_pos, pred_quat, sim_data["mesh_pos"], sim_data["mesh_quat"] + ) + + if epoch_idx == 0 and step_idx == 0: + # Check if float16 is used + self._check_amp_status_for_model_params( + grad_scaler, + subtitle="Model parameters at start of training", + ) + self._check_amp_status_during_training( + frames, + pred_pos, + pred_quat, + sim_data["mesh_pos"], + sim_data["mesh_quat"], + grad_scaler, + subtitle="Variables at start of training", + ) + + # Check magnitude of predictions + for i, seg_name in enumerate(segments_for_6dpose): + _pos = pred_pos[0, i, :].detach().cpu().tolist() + _quat = pred_quat[0, i, :].detach().cpu().tolist() + logger.info( + f"Sample {i}, {seg_name}: " + f"pred_pos={[round(x, 3) for x in _pos]}, " + f"pred_quat={[round(x, 3) for x in _quat]}" + ) + + # Backpropagate and optimize + optimizer.zero_grad(set_to_none=True) + grad_scaler.scale(loss_dict["total_loss"]).backward() + grad_scaler.step(optimizer) + grad_scaler.update() + + # Logging + for key, value in loss_dict.items(): + running_loss_dict[key] += value.item() + if step_idx % artifacts_config.logging_interval == 0 and step_idx > 0: + avg_loss_dict = { + k: x / artifacts_config.logging_interval + for k, x in running_loss_dict.items() + } + time_now = 
time() + time_elapsed = time_now - running_start_time + throughput = artifacts_config.logging_interval / time_elapsed + + running_loss_dict = defaultdict(lambda: 0.0) + running_start_time = time_now + self._update_logs_training( + writer, + epoch_index=epoch_idx, + within_epoch_step_idx=step_idx, + n_batches_per_epoch=n_batches_per_epoch, + avg_loss_dict=avg_loss_dict, + throughput=throughput, + ) + + # Run validation + if ( + step_idx % artifacts_config.validation_interval == 0 + and step_idx > 0 + ): + del ( + atomic_batches_frames, + atomic_batches_sim_data, + frames, + sim_data, + bodyseg_probs_label, + pred_pos, + pred_quat, + ) + util.clear_memory_cache() + val_loss_dict = self.validate( + val_loader, + max_batches=artifacts_config.n_batches_per_validation, + ) + self._update_logs_validation( + writer, + epoch_idx=epoch_idx, + within_epoch_step_idx=step_idx, + n_batches_per_epoch=n_batches_per_epoch, + val_loss_dict=val_loss_dict, + ) + + # Save checkpoint + # (every log_interval steps and last step of each epoch) + if ( + step_idx % artifacts_config.checkpoint_interval == 0 + and step_idx > 0 + ) or (step_idx == n_batches_per_epoch - 1): + checkpoint_path_stem = ( + checkpoint_dir / f"epoch{epoch_idx}_step{step_idx}" + ) + self._save_checkpoint( + checkpoint_path_stem, + model=self.model, + loss=self.loss_func, + optimizer=optimizer, + grad_scaler=grad_scaler, + ) + logger.info(f"Saved checkpoint to {checkpoint_path_stem}.*.pth") + + epoch_wall_time = time() - epoch_start_time + logger.info(f"Finished epoch {epoch_idx} in {epoch_wall_time:.2f} seconds.") + + writer.close() + + def validate( + self, validation_data_loader: DataLoader, max_batches: int | None = None + ): + if max_batches is None: + max_batches = len(validation_data_loader) + if max_batches <= 0: + raise ValueError("max_batches must be positive or None") + if self.loss_func is None: + raise ValueError("Loss function must be provided for validation") + + total_loss_dict = defaultdict(lambda: 
0.0) + n_steps_iterated = 0 + self.model.eval() + with torch.no_grad(): + for step_idx, (atomic_batches_frames, atomic_batches_sim_data) in enumerate( + tqdm(validation_data_loader, desc="Validation", disable=None) + ): + if step_idx >= max_batches: + break + + # Format data + frames, sim_data = synth_data.atomic_batches_to_simple_batch( + atomic_batches_frames, atomic_batches_sim_data, device=self.device + ) + if self.half_batch_size_for_debugging: + frames, sim_data = self._get_half_batch(frames, sim_data) + bodyseg_probs_label = sim_data["segmap_probs_label"] + + # Run model + with torch.amp.autocast(self.device_type, enabled=self.use_float16): + pred_pos, pred_quat = self.model(frames, bodyseg_probs_label) + loss_dict = self.loss_func( + pred_pos, pred_quat, sim_data["mesh_pos"], sim_data["mesh_quat"] + ) + + # Accumulate losses + for key, loss in loss_dict.items(): + total_loss_dict[key] += loss.item() + n_steps_iterated += 1 + + del ( + atomic_batches_frames, + atomic_batches_sim_data, + frames, + sim_data, + bodyseg_probs_label, + pred_pos, + pred_quat, + ) + util.clear_memory_cache() + self.model.train() + return {k: v / n_steps_iterated for k, v in total_loss_dict.items()} + + def inference( + self, frames: torch.Tensor, bodyseg_probs: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + input_device = frames.device + self.model.eval() + with torch.no_grad(): + frames = frames.to(self.device) + bodyseg_probs = bodyseg_probs.to(self.device) + with torch.amp.autocast(self.device_type, enabled=self.use_float16): + pred_pos, pred_quat = self.model(frames, bodyseg_probs) + self.model.train() + return pred_pos.to(input_device), pred_quat.to(input_device) + + def _init_training_dataset_and_dataloader( + self, data_config: config.TrainingDataConfig + ): + return synth_data.init_atomic_dataset_and_dataloader( + data_dirs=data_config.train_data_dirs, + atomic_batch_n_samples=data_config.atomic_batch_n_samples, + 
atomic_batch_n_variants=data_config.atomic_batch_n_variants, + input_image_size=data_config.input_image_size, + batch_size=data_config.train_batch_size, + load_dof_angles=False, + load_keypoint_positions=False, + load_body_segment_maps=True, + load_mesh_pose6d=True, + shuffle=True, + n_workers=data_config.n_workers, + n_channels=3, + pin_memory=True, + drop_last=True, + transform=self._compute_bodyseg_probs_transform, + ) + + def _init_validation_dataset_and_dataloader( + self, data_config: config.TrainingDataConfig + ): + return synth_data.init_atomic_dataset_and_dataloader( + data_dirs=data_config.val_data_dirs, + atomic_batch_n_samples=data_config.atomic_batch_n_samples, + atomic_batch_n_variants=data_config.atomic_batch_n_variants, + input_image_size=data_config.input_image_size, + batch_size=data_config.val_batch_size, + load_dof_angles=False, + load_keypoint_positions=False, + load_body_segment_maps=True, + load_mesh_pose6d=True, + shuffle=False, + n_workers=data_config.n_workers, + n_channels=3, + pin_memory=True, + drop_last=True, + transform=self._compute_bodyseg_probs_transform, + ) + + def _create_optimizer(self, optimizer_config: config.OptimizerConfig): + params = [ + { + "params": self.model.feature_extractor.parameters(), + "lr": optimizer_config.learning_rate_encoder, + }, + { + "params": list( + chain( + # no dec_layer1 because we don't upsample all the way to 128 + self.model.dec_layer2.parameters(), + self.model.dec_layer3.parameters(), + self.model.dec_layer4.parameters(), + ) + ), + "lr": optimizer_config.learning_rate_upsample, + }, + { + "params": self.model.pose6d_heads.parameters(), + "lr": optimizer_config.learning_rate_pose6d_heads, + }, + ] + if self.model.n_attention_gated_feature_channels > 0: + params.append( + { + "params": self.model.attention_heads.parameters(), + "lr": optimizer_config.learning_rate_attention, + } + ) + + optimizer = torch.optim.AdamW( + params, weight_decay=optimizer_config.weight_decay + ) + + # Check if all 
parameters are covered + n_params_optimizer = util.count_optimizer_parameters(optimizer) + n_params_model = util.count_module_parameters(self.model) + assert n_params_optimizer == n_params_model, ( + f"Number of parameters in optimizer ({n_params_optimizer}) does not match " + f"number of parameters in model ({n_params_model})." + ) + + return optimizer + + def _check_amp_status_for_model_params( + self, grad_scaler: torch.amp.GradScaler, subtitle: str = "Model parameters" + ): + return util.check_mixed_precision_status( + self.use_float16, + self.device, + print_results=True, + tensors={ + "feature_extractor_params": self.model.feature_extractor.parameters(), + "decoder_params": chain( + self.model.dec_layer2.parameters(), + self.model.dec_layer3.parameters(), + self.model.dec_layer4.parameters(), + ), + "pose6d_heads_params": self.model.pose6d_heads.parameters(), + }, + grad_scaler=grad_scaler, + subtitle=subtitle, + ) + + def _check_amp_status_during_training( + self, + input_images: torch.Tensor, + pred_pos: torch.Tensor, + pred_quat: torch.Tensor, + target_pos: torch.Tensor, + target_quat: torch.Tensor, + grad_scaler: torch.amp.GradScaler, + subtitle: str = "Variables during training", + ): + return util.check_mixed_precision_status( + self.use_float16, + self.device, + print_results=True, + tensors={ + "input_images": input_images, + "target_pos": target_pos, + "target_quat": target_quat, + "pred_pos": pred_pos, + "pred_quat": pred_quat, + }, + grad_scaler=grad_scaler, + subtitle=subtitle, + ) + + @staticmethod + def _save_checkpoint( + checkpoint_path_stem: Path, + model: Pose6DModel, + loss: Pose6DLoss | None = None, + optimizer: torch.optim.Optimizer | None = None, + grad_scaler: torch.amp.GradScaler | None = None, + ) -> None: + path = checkpoint_path_stem.with_suffix(".model.pth") + torch.save(model.state_dict(), path) + if loss is not None: + path = checkpoint_path_stem.with_suffix(".loss.pth") + torch.save(loss.state_dict(), path) + if optimizer is not 
None: + path = checkpoint_path_stem.with_suffix(".optimizer.pth") + torch.save(optimizer.state_dict(), path) + if grad_scaler is not None: + path = checkpoint_path_stem.with_suffix(".grad_scaler.pth") + torch.save(grad_scaler.state_dict(), path) + + def _update_logs_training( + self, + writer: SummaryWriter, + *, + epoch_index: int, + within_epoch_step_idx: int, + n_batches_per_epoch: int, + avg_loss_dict: dict[str, float], + throughput: float, + ) -> None: + global_step_idx = epoch_index * n_batches_per_epoch + within_epoch_step_idx + writer.add_scalar("train/epoch", epoch_index, global_step_idx) + log_str = ( + f"Epoch {epoch_index}, step {within_epoch_step_idx}/{n_batches_per_epoch}, " + ) + for key, value in avg_loss_dict.items(): + log_str += f"{key}: {value:.4f}, " + writer.add_scalar(f"train/loss/{key}", value, global_step_idx) + log_str += f"throughput: {throughput:.2f} batches/sec" + logger.info(log_str) + writer.add_scalar("train/sys/throughput", throughput, global_step_idx) + + def _update_logs_validation( + self, + writer: SummaryWriter, + *, + epoch_idx: int, + within_epoch_step_idx: int, + n_batches_per_epoch: int, + val_loss_dict: dict[str, float], + ) -> None: + global_step_idx = epoch_idx * n_batches_per_epoch + within_epoch_step_idx + log_str = ( + f"Validation at epoch {epoch_idx}, " + f"step {within_epoch_step_idx}/{n_batches_per_epoch}, " + ) + for key, value in val_loss_dict.items(): + log_str += f"{key}: {value:.4f}, " + writer.add_scalar(f"val/loss/{key}", value, global_step_idx) + logger.info(log_str) + + def _get_half_batch(self, frames_batch, sim_data_batch): + """Return half of the batch to save memory (for debugging only).""" + half_batch_size = frames_batch.shape[0] // 2 + frames_batch = frames_batch[:half_batch_size, ...] + sim_data_batch = { + k: v[:half_batch_size, ...] 
for k, v in sim_data_batch.items() + } + return frames_batch, sim_data_batch diff --git a/src/poseforge/pose/pose6d/scripts/run_training.py b/src/poseforge/pose/pose6d/scripts/run_training.py new file mode 100644 index 0000000..fe7b3e2 --- /dev/null +++ b/src/poseforge/pose/pose6d/scripts/run_training.py @@ -0,0 +1,169 @@ +import torch +from pathlib import Path +from torchsummary import summary +from loguru import logger + +import poseforge.pose.pose6d.config as config +from poseforge.pose.pose6d import Pose6DModel, Pose6DLoss, Pose6DPipeline +from poseforge.util import get_hardware_availability +from poseforge.neuromechfly.constants import segments_for_6dpose + + +def setup_model( + architecture_config: config.ModelArchitectureConfig, + weights_config: config.ModelWeightsConfig | None, +) -> Pose6DModel: + model = Pose6DModel.create_from_config(architecture_config) + logger.info( + f"Model architecture set up based on architecture config {architecture_config}" + ) + if weights_config is not None: + model.load_weights_from_config(weights_config) + logger.info(f"Loaded weights into model based on weights config {weights_config}") + return model + + +def setup_loss_func(loss_config: config.LossConfig) -> Pose6DLoss: + loss_func = Pose6DLoss.create_from_config(loss_config) + logger.info("Set up Pose6D loss function") + return loss_func + + +def save_configs( + configs_dir: Path, + *, + model_architecture_config: config.ModelArchitectureConfig, + loss_config: config.LossConfig, + training_data_config: config.TrainingDataConfig, + optimizer_config: config.OptimizerConfig, + training_artifacts_config: config.TrainingArtifactsConfig, + model_weights_config: config.ModelWeightsConfig | None = None, +) -> None: + configs_dir.mkdir(parents=True, exist_ok=True) + model_architecture_config.save(configs_dir / "model_architecture_config.yaml") + loss_config.save(configs_dir / "loss_config.yaml") + training_data_config.save(configs_dir / "data_config.yaml") + 
optimizer_config.save(configs_dir / "optimizer_config.yaml") + training_artifacts_config.save(configs_dir / "artifacts_config.yaml") + if model_weights_config is not None: + model_weights_config.save(configs_dir / "model_weights_config.yaml") + + +def train_mesh6d_model( + n_epochs: int, + model_architecture_config: config.ModelArchitectureConfig, + model_weights_config: config.ModelWeightsConfig, + loss_config: config.LossConfig, + training_data_config: config.TrainingDataConfig, + optimizer_config: config.OptimizerConfig, + training_artifacts_config: config.TrainingArtifactsConfig, + seed: int = 42, + half_batch_size_for_debugging: bool = False, +) -> None: + # System setup + hardware_avail = get_hardware_availability(check_gpu=True, print_results=True) + if len(hardware_avail["gpus"]) == 0: + raise RuntimeError("No GPU available for training") + torch.backends.cudnn.benchmark = True + + # Save configs + save_configs( + Path(training_artifacts_config.output_basedir) / "configs", + model_architecture_config=model_architecture_config, + model_weights_config=model_weights_config, + loss_config=loss_config, + training_data_config=training_data_config, + optimizer_config=optimizer_config, + training_artifacts_config=training_artifacts_config, + ) + + # Initialize model and loss function + model = setup_model(model_architecture_config, model_weights_config) + criterion = setup_loss_func(loss_config) + + # Print model summary + print_model_summary(training_data_config, model) + + # Set up loss function + loss_func = setup_loss_func(loss_config) + + # Set up training pipeline + pipeline = Pose6DPipeline( + model=model, loss_func=loss_func, device="cuda", use_float16=True + ) + logger.info("Training pipeline set up") + + # Start training + pipeline.train( + n_epochs=n_epochs, + data_config=training_data_config, + optimizer_config=optimizer_config, + artifacts_config=training_artifacts_config, + seed=seed, + half_batch_size_for_debugging=half_batch_size_for_debugging, + ) 
+ logger.info("Training completed") + + +def print_model_summary(training_data_config, model): + input_image_size = (3, *training_data_config.input_image_size) + attention_feat_map_size = (len(segments_for_6dpose), 64, 64) # hardcoded for now + in_dim = [input_image_size, attention_feat_map_size] + print("============== Full Model Summary ===============") + summary(model, in_dim, device="cpu") + print("=========== Feature Extractor Summary ===========") + summary(model.feature_extractor, input_image_size, device="cpu") + + +if __name__ == "__main__": + import tyro + + tyro.cli( + train_mesh6d_model, + prog=f"python {Path(__file__).name}", + description="Train a 6D mesh pose model using pretrained feature extractor.", + ) + + # # Example using native Python function calls: + # model_architecture_config = config.ModelArchitectureConfig() + # model_weights_config = config.ModelWeightsConfig( + # feature_extractor_weights="bulk_data/pose_estimation/contrastive_pretraining/trial_20251125a_lowlr/checkpoints/checkpoint_epoch009_step003055.feature_extractor.pth", + # model_weights=None, + # ) + # loss_config = config.LossConfig(translation_weight=1.0, quaternion_weight=2.0) + # data_basedir = Path("bulk_data/pose_estimation/atomic_batches/4variants") + # train_data_dirs = [ + # data_basedir / f"BO_Gal4_fly{fly}_trial{trial:03d}" + # for fly in range(1, 5) # flies 1-4 + # for trial in range(1, 6) # trials 1-5 + # ] + # val_data_dirs = [data_basedir / f"BO_Gal4_fly5_trial001"] + # training_data_config = config.TrainingDataConfig( + # train_data_dirs=[str(path) for path in train_data_dirs], + # val_data_dirs=[str(path) for path in val_data_dirs], + # input_image_size=(256, 256), + # atomic_batch_n_samples=32, + # atomic_batch_n_variants=4, + # train_batch_size=32, + # val_batch_size=32, + # n_workers=8, + # ) + # optimizer_config = config.OptimizerConfig() + # training_artifacts_config = config.TrainingArtifactsConfig( + # 
output_basedir="bulk_data/pose_estimation/mesh6d/trial_20251130z2/", + # logging_interval=10, # 1000 + # checkpoint_interval=100, # 1000 + # validation_interval=100, # 1000 + # n_batches_per_validation=30, # 300 + # ) + # train_mesh6d_model( + # n_epochs=10, + # model_architecture_config=model_architecture_config, + # model_weights_config=model_weights_config, + # loss_config=loss_config, + # training_data_config=training_data_config, + # optimizer_config=optimizer_config, + # training_artifacts_config=training_artifacts_config, + # seed=42, + # half_batch_size_for_debugging=True, + # ) diff --git a/src/poseforge/util/__init__.py b/src/poseforge/util/__init__.py index 83730ca..ceb54a2 100644 --- a/src/poseforge/util/__init__.py +++ b/src/poseforge/util/__init__.py @@ -5,7 +5,7 @@ check_mixed_precision_status, ) from .plot import configure_matplotlib_style -from .data import SerializableDataClass, OutputBuffer +from .data import SerializableDataClass, OutputBuffer, bulk_data_dir from .ml import count_optimizer_parameters, count_module_parameters @@ -17,6 +17,7 @@ "configure_matplotlib_style", "SerializableDataClass", "OutputBuffer", + "bulk_data_dir", "count_optimizer_parameters", "count_module_parameters", ] diff --git a/src/poseforge/util/data.py b/src/poseforge/util/data.py index 464b0dc..d618ff6 100644 --- a/src/poseforge/util/data.py +++ b/src/poseforge/util/data.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Hashable, Callable, Any +import poseforge + @dataclasses.dataclass(frozen=True) class SerializableDataClass: @@ -143,3 +145,7 @@ def n_open_buckets(self) -> int: @property def n_data_total(self) -> int: return sum(len(buf) for buf in self.buffers.values()) + + +assert len(poseforge.__path__) == 1, "poseforge.__path__ contains multiple paths" +bulk_data_dir = Path(poseforge.__path__[0]).parent.parent / "bulk_data"