A simplified PyTorch implementation of InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning, ICML 2024.
InfoBatch is a data pruning technique that speeds up training without accuracy loss by dynamically selecting informative samples during training. This implementation uses ResNet-18 on CIFAR-10/100 datasets.
Key Features:
- Dynamic data pruning based on per-sample loss scores
- Unbiased gradient estimation using importance sampling
- Multiple pruning strategies (standard, moving average, reverse)
- MLflow integration for experiment tracking
- Docker with GPU support
- NVIDIA GPU with CUDA 11.7
- Base image:
pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
sh run_container.shNote: If the volume mount path is incorrect, modify the -v parameter in run_container.sh:
docker run -it -d \
-v ~/info_batch:/sources \
-v ~/info_batch/dataset/:/sources/dataset \
--ipc=host --gpus=all \
--name info_batch \
pytorch/pytorch:2.0.0-cuda11.7-cudnn8-develAfter running the container:
docker exec -it info_batch bash
pip install mlflowInside the container:
cd /sources/src
python res18_cifar10_ib.pyInfoBatch-Pytorch/
├── src/ # Source code
│ ├── res18_cifar10_whole.py # Baseline training (full dataset)
│ ├── res18_cifar10_ib.py # InfoBatch with standard policy
│ ├── res18_cifar10_ib_ma.py # InfoBatch with moving average threshold
│ ├── res18_cifar10_ib_rev.py # InfoBatch with reverse policy (train pruned samples)
│ ├── res18_cifar100_*.py # CIFAR-100 variants
│ └── utils/
│ ├── dataset.py # Custom dataset classes (InfoCIFAR10/100)
│ ├── policy.py # Standard pruning policy
│ ├── ma_policy.py # Moving average policy
│ └── reverse_policy.py # Reverse pruning policy
├── config/ # Training configurations
│ ├── res18_cifar10_whole.py # Hyperparameters for CIFAR-10
│ └── res18_cifar100_whole.py # Hyperparameters for CIFAR-100
└── run_container.sh # Docker container launch script
| File | Description | Use Case |
|---|---|---|
res18_cifarXX_whole.py |
Baseline training on full dataset | Comparison baseline |
res18_cifarXX_ib.py |
Standard InfoBatch with threshold-based pruning | Main method |
res18_cifarXX_ib_ma.py |
InfoBatch with moving average threshold | Smoother threshold updates |
res18_cifarXX_ib_rev.py |
Trains on pruned samples instead of kept samples | Ablation study |
Default settings in config/res18_cifar10_whole.py:
num_epochs = 100
batch_size = 128
max_lr = 0.1
momentum = 0.9
pct_start = 0.3 # OneCycleLR parameter
weight_decay = 5e-4
label_smooth = 0.0Adjustable in src/utils/policy.py:
PruningPolicy(
data_size, # Total dataset size
total_epoch, # Total training epochs
prob=0.5, # Pruning probability for low-score samples
anneal=0.875 # Fraction of epochs to apply pruning (87.5%)
)How it works:
- Computes per-sample loss scores each iteration
- At epoch end, calculates threshold as mean score of used samples
- Divides data into:
- D2: High-score samples (≥ threshold) → always used
- D1: Low-score samples (< threshold) → probabilistically pruned
- Applies importance sampling weights:
1/(1-prob)for kept D1 samples - After annealing period, uses full dataset
All experiments log metrics to MLflow:
- Training/test loss and accuracy
- Learning rate schedule
- Epoch time and cumulative duration
- Dynamic threshold values
- Iteration count per epoch
mlflow uiThen open http://localhost:5000 in your browser.
Logged Metrics:
train_loss,test_loss,test_accuracyepoch_time,train_durationthreshold(dynamic pruning threshold)iter_count(number of batches per epoch)
InfoCIFAR10/100 extends torchvision's CIFAR datasets to return:
img: Image tensortarget: Class labelindex: Sample index (for score tracking)scaler: Importance sampling weight
PruningPolicy implements PyTorch's Sampler interface:
- Tracks per-sample scores
- Dynamically updates pruning decisions each epoch
- Returns
(index, scaler)tuples for dataloader
for idx, (x, y, sample_idx, scaler) in train_loader:
logits = model(x)
loss = criterion(logits, y) # Per-sample loss
policy.update_scores(sample_idx, loss.detach()) # Track scores
loss = (loss * scaler).mean() # Apply importance weights
loss.backward()
optimizer.step()
policy.update_policy(epoch+1) # Update pruning at epoch end- GPU not detected: Ensure
nvidia-dockeris installed and--gpus=allflag is set - Permission denied: Run
chmod +x run_container.sh
- Out of memory: Reduce
batch_sizein config files - Dataset not found: Datasets auto-download to
/sources/dataset/on first run - MLflow not found: Install inside container:
pip install mlflow
- All experiments expect
/sourcesas the working directory - Modify
sys.path.append('/sources')if using different mount paths
If you use this code, please cite:
@inproceedings{qin2024infobatch,
title={InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning},
booktitle={International Conference on Machine Learning (ICML)},
year={2024}
}This is a research implementation. Please refer to the original paper for academic use.