Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha
* 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research.
* 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*).
* 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required.
* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**.
* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **Cox-based survival analysis**.
* 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting.
* 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures.
* 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility.
Expand Down
4 changes: 2 additions & 2 deletions src/stamp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

import yaml

from stamp.config import StampConfig
from stamp.modeling.config import (
AdvancedConfig,
MlpModelParams,
ModelParams,
VitModelParams,
)
from stamp.seed import Seed
from stamp.utils.config import StampConfig
from stamp.utils.seed import Seed

STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml")

Expand Down
24 changes: 22 additions & 2 deletions src/stamp/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ preprocessing:
# Extractor to use for feature extraction. Possible options are "ctranspath",
# "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom",
# "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow",
# "virchow-full", "musk", "mstar", "plip"
# "virchow-full", "musk", "mstar", "plip", "ticon"
# Some of them require requesting access to the respective authors beforehand.
extractor: "chief-ctranspath"

Expand Down Expand Up @@ -76,6 +76,8 @@ crossval:

# Name of the column from the clini table to train on.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
Expand Down Expand Up @@ -133,6 +135,8 @@ training:

# Name of the column from the clini table to train on.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
Expand Down Expand Up @@ -175,6 +179,8 @@ deployment:

# Name of the column from the clini to compare predictions to.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# For survival (should be status and follow-up days columns in clini table)
# status_label: "event"
Expand All @@ -200,6 +206,8 @@ statistics:

# Name of the target label.
ground_truth_label: "KRAS"
# For multi-target classification you may specify a list of columns,
# e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"]

# A lot of the statistics are computed "one-vs-all", i.e. there needs to be
# a positive class to calculate the statistics for.
Expand Down Expand Up @@ -319,7 +327,7 @@ advanced_config:
max_lr: 1e-4
div_factor: 25.
# Select a model regardless of task
model_name: "vit" # or mlp, trans_mil
model_name: "vit" # or mlp, trans_mil, barspoon

model_params:
vit: # Vision Transformer
Expand All @@ -338,3 +346,15 @@ advanced_config:
dim_hidden: 512
num_layers: 2
dropout: 0.25

# NOTE: Only the `barspoon` model supports multi-target classification
# (i.e. `ground_truth_label` can be a list of column names). Other
# models expect a single target column.
barspoon: # Encoder-Decoder Transformer for multi-target classification
d_model: 512
num_encoder_heads: 8
num_decoder_heads: 8
num_encoder_layers: 2
num_decoder_layers: 2
dim_feedforward: 2048
positional_encoding: true
2 changes: 1 addition & 1 deletion src/stamp/encoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def init_slide_encoder_(
selected_encoder = encoder

case _ as unreachable:
assert_never(unreachable) # type: ignore
assert_never(unreachable)

selected_encoder.encode_slides_(
output_dir=output_dir,
Expand Down
6 changes: 4 additions & 2 deletions src/stamp/encoding/encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import cast

import h5py
import numpy as np
Expand All @@ -12,11 +13,11 @@
from tqdm import tqdm

import stamp
from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.modeling.data import CoordsInfo, get_coords, read_table
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down Expand Up @@ -183,7 +184,8 @@ def _read_h5(
elif not h5_path.endswith(".h5"):
raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}")
with h5py.File(h5_path, "r") as f:
feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore
feats_ds = cast(h5py.Dataset, f["feats"])
feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision)
coords: CoordsInfo = get_coords(f)
extractor: str = f.attrs.get("extractor", "")
if extractor == "":
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/chief.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from numpy import ndarray
from tqdm import tqdm

from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel
from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from torch import Tensor
from tqdm import tqdm

from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.encoding.encoder.chief import CHIEF
from stamp.modeling.data import CoordsInfo
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, PandasLabel
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/gigapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from gigapath import slide_encoder
from tqdm import tqdm

from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.modeling.data import CoordsInfo
from stamp.preprocessing.config import ExtractorName
from stamp.types import PandasLabel, SlideMPP
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/madeleine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from numpy import ndarray

from stamp.cache import STAMP_CACHE_DIR
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.preprocessing.config import ExtractorName
from stamp.utils.cache import STAMP_CACHE_DIR

try:
from madeleine.models.factory import create_model_from_pretrained
Expand Down
2 changes: 1 addition & 1 deletion src/stamp/encoding/encoder/titan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from tqdm import tqdm
from transformers import AutoModel

from stamp.cache import get_processing_code_hash
from stamp.encoding.config import EncoderName
from stamp.encoding.encoder import Encoder
from stamp.modeling.data import CoordsInfo
from stamp.preprocessing.config import ExtractorName
from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP
from stamp.utils.cache import get_processing_code_hash

__author__ = "Juan Pablo Ricapito"
__copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito"
Expand Down
92 changes: 50 additions & 42 deletions src/stamp/heatmaps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
from collections.abc import Collection, Iterable
from pathlib import Path
from typing import cast, no_type_check
from typing import cast

import h5py
import matplotlib.pyplot as plt
Expand All @@ -19,7 +19,7 @@
from packaging.version import Version
from PIL import Image
from torch import Tensor
from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage]
from torch.func import jacrev

from stamp.modeling.data import get_coords, get_stride
from stamp.modeling.deploy import load_model_from_ckpt
Expand All @@ -29,6 +29,8 @@

_logger = logging.getLogger("stamp")

_SlideLike = openslide.OpenSlide | openslide.ImageSlide


def _gradcam_per_category(
model: torch.nn.Module,
Expand All @@ -37,23 +39,19 @@ def _gradcam_per_category(
) -> Float[Tensor, "tile category"]:
feat_dim = -1

cam = (
(
feats
* jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze(0)
)(feats)
)
.mean(feat_dim) # type: ignore
.abs()
jac = cast(
Tensor,
jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze(0)
)(feats),
)

cam = (feats * jac).mean(feat_dim).abs()
cam = torch.softmax(cam, dim=-1)

return cam.permute(-1, -2)


Expand All @@ -79,12 +77,19 @@ def _attention_rollout_single(

# --- 2. Rollout computation ---
attn_rollout: torch.Tensor | None = None
for layer in model.transformer.layers: # type: ignore
attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights
transformer = getattr(model, "transformer", None)
if transformer is None:
raise RuntimeError("Model does not have a transformer attribute")
for layer in transformer.layers:
attn = getattr(layer, "attn_weights", None)
if attn is None:
first_child = next(iter(layer.children()), None)
if first_child is not None:
attn = getattr(first_child, "attn_weights", None)
if attn is None:
raise RuntimeError(
"SelfAttention.attn_weights not found. "
"Make sure SelfAttention stores them."
"Make sure SelfAttention stores them on the layer or its first child."
)

# attn: [heads, seq, seq]
Expand Down Expand Up @@ -117,15 +122,18 @@ def _gradcam_single(
"""
feat_dim = -1

jac = jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze()
)(feats)
jac = cast(
Tensor,
jacrev(
lambda bags: model.forward(
bags.unsqueeze(0),
coords=coords.unsqueeze(0),
mask=None,
).squeeze()
)(feats),
)

cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile]
cam = (feats * jac).mean(feat_dim).abs() # [tile]

return cam

Expand All @@ -148,17 +156,21 @@ def _vals_to_im(


def _show_thumb(
slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None
slide: _SlideLike,
thumb_ax: Axes,
attention: Tensor,
default_slide_mpp: SlideMPP | None,
) -> np.ndarray:
mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp)
dims_um = np.array(slide.dimensions) * mpp
thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))
thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist())
thumb = slide.get_thumbnail(thumb_size)
thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8])
return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]


def _get_thumb_array(
slide,
slide: _SlideLike,
attention: torch.Tensor,
default_slide_mpp: SlideMPP | None,
) -> np.ndarray:
Expand All @@ -168,12 +180,12 @@ def _get_thumb_array(
"""
mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp)
dims_um = np.array(slide.dimensions) * mpp
thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)))
thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist())
thumb = np.array(slide.get_thumbnail(thumb_size))
thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8]
return thumb_crop


@no_type_check # beartype<=0.19.0 breaks here for some reason
def _show_class_map(
class_ax: Axes,
top_score_indices: Integer[Tensor, "width height"],
Expand Down Expand Up @@ -298,13 +310,8 @@ def heatmaps_(
raise ValueError(
f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only."
)
feats = (
torch.tensor(
h5["feats"][:] # pyright: ignore[reportIndexIssue]
)
.float()
.to(device)
)
feats_np = np.asarray(h5["feats"])
feats = torch.from_numpy(feats_np).float().to(device)
coords_info = get_coords(h5)
coords_um = torch.from_numpy(coords_info.coords_um).float()
stride_um = Microns(get_stride(coords_um))
Expand All @@ -322,9 +329,10 @@ def heatmaps_(
model = load_model_from_ckpt(checkpoint_path).eval()

# TODO: Update version when a newer model logic breaks heatmaps.
if Version(model.stamp_version) < Version("2.4.0"):
stamp_version = str(getattr(model, "stamp_version", ""))
if Version(stamp_version) < Version("2.4.0"):
raise ValueError(
f"model has been built with stamp version {model.stamp_version} "
f"model has been built with stamp version {stamp_version} "
f"which is incompatible with the current version."
)

Expand Down Expand Up @@ -356,7 +364,7 @@ def heatmaps_(

with torch.no_grad():
scores = torch.softmax(
model.model.forward(
model.model(
feats.unsqueeze(-2),
coords=coords_um.unsqueeze(-2),
mask=torch.zeros(
Expand Down
Loading
Loading