This repository contains a deep learning pipeline for the automated classification of intervertebral disc degeneration using the Pfirrmann grading system (Grades 1 to 5).
The core of this project consists of two main scripts:
- `densenet_training.py`: Trains a 3D DenseNet121 model on MRI NIfTI volumes.
- `model_evaluation.py`: Evaluates the trained model on the validation set and generates a detailed confusion matrix.
The training script (`densenet_training.py`) is tailored to medical imaging data and combines several techniques aimed at robust, generalizable performance:
- 3D DenseNet Architecture: Uses MONAI's `DenseNet121` configured for three spatial dimensions to capture volumetric disc features.
- Data Leakage Prevention: Uses scikit-learn's `GroupShuffleSplit` to split the dataset by `patient_id`. All discs from the same patient are kept in the same split (train or validation), so the model cannot score well on validation by "memorizing" patient-specific anatomy it saw during training.
- Dynamic Class Weighting: Computes class weights from the training-set label distribution and passes them to the `CrossEntropyLoss` function. This counteracts class imbalance (e.g., rare Grade 1 and Grade 5 cases vs. common Grade 3 and 4 cases).
- Regularization against Overfitting: Applies 40% dropout (`dropout_prob=0.4`) inside the DenseNet and weight decay (`weight_decay=1e-3`) in the Adam optimizer, pushing the model to generalize rather than memorize.
- High-Speed Caching: Uses MONAI's `CacheDataset` to run the deterministic transforms once and keep the results in RAM, reducing disk I/O bottlenecks and shortening training epochs.
- Comprehensive Metrics Tracking: Calculates and logs validation loss, accuracy, AUC, precision, sensitivity (recall), specificity, and macro F1-score.
- Automated Visualization: Generates a dual-axis `progress.png` chart at every validation step, tracking train/val loss and train/val accuracy over time.
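The leakage-safe split and the class weighting can be sketched as follows. This is a minimal sketch, assuming the labels have already been shifted to 0-4 and stored in a hypothetical `label` column. The inverse-frequency formula `n_samples / (n_classes * count)` is also an assumption, since the exact weighting used by `densenet_training.py` is not documented here. The helper names `split_by_patient` and `inverse_frequency_weights` are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit


def split_by_patient(df, test_size=0.2, seed=42):
    """Hypothetical helper: split rows so no patient_id appears in both subsets."""
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    # split() yields (train_indices, test_indices); rows sharing a group stay together
    train_idx, val_idx = next(splitter.split(df, groups=df["patient_id"]))
    return df.iloc[train_idx].reset_index(drop=True), df.iloc[val_idx].reset_index(drop=True)


def inverse_frequency_weights(labels, num_classes=5):
    """Hypothetical helper: weight class c by n_samples / (num_classes * count_c).

    Classes absent from the training labels get weight 0 (assumption).
    """
    labels = np.asarray(labels, dtype=np.int64)
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    weights = np.zeros(num_classes, dtype=np.float64)
    present = counts > 0
    weights[present] = len(labels) / (num_classes * counts[present])
    return weights


# Toy data: 10 patients with 2 discs each, labels already shifted to 0-4.
df = pd.DataFrame({
    "patient_id": [f"Pat_{i // 2:03d}" for i in range(20)],
    "label": [0, 1, 2, 2, 2, 3, 3, 3, 4, 2, 1, 2, 3, 3, 2, 3, 4, 2, 3, 0],
})
train_df, val_df = split_by_patient(df)
weights = inverse_frequency_weights(train_df["label"], num_classes=5)
print(len(train_df), len(val_df), weights.round(3))
```

To use the weights in training, they would be converted to a float tensor and passed as `torch.nn.CrossEntropyLoss(weight=torch.as_tensor(weights, dtype=torch.float32))`. Because `GroupShuffleSplit` applies `test_size` to the number of patients rather than the number of discs, the validation fraction of discs can differ from `test_size` when patients contribute different numbers of discs.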
Once a model is trained, the evaluation script (`model_evaluation.py`) loads the best saved weights and performs a detailed evaluation on the validation cohort.
- Confusion Matrix Generation: Computes and plots a multi-class confusion matrix using scikit-learn and Matplotlib, and saves it directly to the experiment's run folder for clinical analysis.
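A confusion matrix with this layout can be produced with scikit-learn's `confusion_matrix` and `ConfusionMatrixDisplay`. This is a minimal sketch, not the code in `model_evaluation.py`: the helper `save_confusion_matrix`, the file name `confusion_matrix.png`, and the color map are assumptions, and the predictions are toy values.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render straight to a file
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

GRADE_NAMES = [f"Grade {g}" for g in range(1, 6)]  # class index 0-4 -> Pfirrmann 1-5


def save_confusion_matrix(y_true, y_pred, out_dir, filename="confusion_matrix.png"):
    """Hypothetical helper: plot a 5x5 matrix (rows = true grade, cols = predicted)."""
    # Fixing labels keeps the matrix 5x5 even if a grade never occurs in this cohort.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(5)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=GRADE_NAMES)
    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(ax=ax, cmap="Blues")
    ax.set_title("Pfirrmann grading: validation confusion matrix")
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, filename)
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return cm, path


# Toy predictions on 0-4 class indices (not real results).
y_true = np.array([0, 1, 2, 2, 3, 4, 3, 2])
y_pred = np.array([0, 2, 2, 2, 3, 3, 3, 1])
cm, path = save_confusion_matrix(y_true, y_pred, tempfile.mkdtemp())
print(cm)
```

Passing `normalize="true"` to `confusion_matrix` row-normalizes the counts, so the diagonal shows per-grade sensitivity. This is often easier to read than raw counts when grades are imbalanced.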
The pipeline requires a CSV file and a directory of 3D NIfTI (`.nii` or `.nii.gz`) images.
By default, the script expects a CSV file containing at least the following columns:
- `disc_path`: The absolute path to the cropped 3D NIfTI file of the intervertebral disc.
- `Pfirrmann`: The clinical ground-truth label (integer from 1 to 5). Note: the script automatically shifts these to 0-4, because PyTorch's `CrossEntropyLoss` expects class indices starting at 0.
- `patient_id`: A unique identifier for the patient, used to split the data without leakage.
Example CSV structure:
patient_id,disc_path,Pfirrmann
Pat_001,/data/images/Pat_001_L1_L2.nii.gz,2
Pat_001,/data/images/Pat_001_L2_L3.nii.gz,3
Pat_002,/data/images/Pat_002_L4_L5.nii.gz,5
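Reading the CSV and applying the 1-5 to 0-4 label shift might look like the following. This is a minimal sketch using pandas. The helper `load_disc_table`, the derived `label` column, and the validation checks are hypothetical and are not taken from the scripts. It parses the example rows above from an in-memory string.

```python
import io

import pandas as pd


def load_disc_table(csv_source, img_col="disc_path", label_col="Pfirrmann"):
    """Hypothetical helper: read the CSV, check columns, and shift grades 1-5 to 0-4."""
    df = pd.read_csv(csv_source)
    required = ["patient_id", img_col, label_col]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"CSV is missing required columns: {missing}")
    grades = df[label_col].astype(int)
    if not grades.between(1, 5).all():
        raise ValueError("Pfirrmann grades must be integers from 1 to 5")
    df["label"] = grades - 1  # CrossEntropyLoss expects class indices 0..num_classes-1
    return df


example_csv = io.StringIO(
    "patient_id,disc_path,Pfirrmann\n"
    "Pat_001,/data/images/Pat_001_L1_L2.nii.gz,2\n"
    "Pat_001,/data/images/Pat_001_L2_L3.nii.gz,3\n"
    "Pat_002,/data/images/Pat_002_L4_L5.nii.gz,5\n"
)
df = load_disc_table(example_csv)
print(df[["patient_id", "label"]])
```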
The images should be 3D MRI crops of individual intervertebral discs. The script automatically applies the following MONAI preprocessing pipeline to standardize them:
- Orientation: Reoriented to the `RAS` coordinate system.
- Spacing: Resampled to an isotropic voxel size of 1.0 x 1.0 x 1.0 mm.
- Intensity: Scaled/normalized using `ScaleIntensityd`.
- Augmentation (train only): Random 3D rotations (`RandRotated`) with a 20% probability.
- Resizing: Standardized to a spatial size of 96 x 64 x 32 voxels to fit the DenseNet architecture and optimize GPU memory usage.
Ensure you have the required libraries installed in your Python environment:
pip install torch monai numpy pandas scikit-learn nibabel matplotlib tensorboard
You can run the script via the command line. It accepts several arguments to customize the run:
python densenet_training.py \
--csv "/path/to/your/dataset.csv" \
--img_col "disc_path" \
--label_col "Pfirrmann" \
--epochs 150
Arguments:
- `--csv`: Path to your input CSV file.
- `--img_col`: Name of the column containing image paths (default: `disc_path`).
- `--label_col`: Name of the column containing the target labels (default: `Pfirrmann`).
- `--epochs`: Total number of training epochs (default: `150`).
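For readers adapting the script, an `argparse` definition matching these flags and defaults might look like this. It is a hypothetical sketch: the actual parser in `densenet_training.py` may differ. Making `--csv` required is an assumption, because no default is documented for it.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the documented command-line flags and defaults."""
    parser = argparse.ArgumentParser(
        description="Train a 3D DenseNet121 for Pfirrmann grading."
    )
    # No default is documented for --csv, so this sketch makes it required.
    parser.add_argument("--csv", required=True, help="Path to the input CSV file.")
    parser.add_argument("--img_col", default="disc_path",
                        help="Column containing image paths.")
    parser.add_argument("--label_col", default="Pfirrmann",
                        help="Column containing the target labels.")
    parser.add_argument("--epochs", type=int, default=150,
                        help="Total number of training epochs.")
    return parser.parse_args(argv)


args = parse_args(["--csv", "/path/to/your/dataset.csv", "--epochs", "20"])
print(args.csv, args.img_col, args.label_col, args.epochs)
```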
During and after training, the script creates a `runs/` directory with a timestamped folder for the current experiment (e.g., `runs/Jan01_12-00-00/`). Inside this folder, you will find:
- `best_metric_model.pt`: The saved PyTorch model weights from the epoch that achieved the highest validation accuracy.
- `progress.png`: An automatically updated Matplotlib chart showing:
  - Train Loss (blue solid line) vs. Val Loss (cyan dashed line).
  - Train Accuracy (red solid line) vs. Val Accuracy (orange dashed line).
  - A gold star marking the epoch where the best model was saved.
- TensorBoard Logs: Event files tracking loss, accuracy, and AUC. You can visualize them by running:
tensorboard --logdir=runs/
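The dual-axis `progress.png` layout can be approximated with Matplotlib's `twinx`, which adds a second y-axis sharing the epoch axis. This is a hypothetical sketch. The `plot_progress` helper, the placement of the star on the validation-accuracy curve, and the loss-left/accuracy-right axis assignment are assumptions, and the curves are toy values.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render to a file without a display
import matplotlib.pyplot as plt


def plot_progress(epochs, train_loss, val_loss, train_acc, val_acc, best_epoch, out_path):
    """Hypothetical helper: losses on the left y-axis, accuracies on the right."""
    fig, ax_loss = plt.subplots(figsize=(8, 5))
    ax_acc = ax_loss.twinx()  # second y-axis sharing the epoch x-axis
    ax_loss.plot(epochs, train_loss, color="blue", linestyle="-", label="Train Loss")
    ax_loss.plot(epochs, val_loss, color="cyan", linestyle="--", label="Val Loss")
    ax_acc.plot(epochs, train_acc, color="red", linestyle="-", label="Train Acc")
    ax_acc.plot(epochs, val_acc, color="orange", linestyle="--", label="Val Acc")
    # Gold star on the validation-accuracy curve at the epoch whose weights were saved.
    best_idx = list(epochs).index(best_epoch)
    ax_acc.scatter([best_epoch], [val_acc[best_idx]], marker="*", s=250,
                   color="gold", edgecolors="black", zorder=5, label="Best model")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_acc.set_ylabel("Accuracy")
    # Merge the legend entries from both axes into a single legend.
    h1, l1 = ax_loss.get_legend_handles_labels()
    h2, l2 = ax_acc.get_legend_handles_labels()
    ax_loss.legend(h1 + h2, l1 + l2, loc="center right")
    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    plt.close(fig)


# Toy curves (not real results), validated every epoch.
epochs = [1, 2, 3, 4, 5]
out_path = os.path.join(tempfile.mkdtemp(), "progress.png")
plot_progress(epochs, [1.6, 1.3, 1.1, 1.0, 0.9], [1.7, 1.5, 1.4, 1.45, 1.5],
              [0.25, 0.40, 0.50, 0.55, 0.60], [0.20, 0.35, 0.42, 0.40, 0.38],
              best_epoch=3, out_path=out_path)
```

Overwriting the same file at each validation step, as sketched here, keeps a single up-to-date chart in the run folder rather than one image per epoch.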