diff --git a/README.md b/README.md index 3069dddf..9a0ed68b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha * 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research. * 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*). * 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required. -* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**. +* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **cox-based survival analysis**. * 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting. * 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures. * 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility. diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..ffa98bae 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,14 +6,14 @@ import yaml -from stamp.config import StampConfig from stamp.modeling.config import ( AdvancedConfig, MlpModelParams, ModelParams, VitModelParams, ) -from stamp.seed import Seed +from stamp.utils.config import StampConfig +from stamp.utils.seed import Seed STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 8440560b..71c54b74 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -76,6 +76,8 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -133,6 +135,8 @@ training: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -175,6 +179,8 @@ deployment: # Name of the column from the clini to compare predictions to. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -200,6 +206,8 @@ statistics: # Name of the target label. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # A lot of the statistics are computed "one-vs-all", i.e. there needs to be # a positive class to calculate the statistics for. @@ -319,7 +327,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil + model_name: "vit" # or mlp, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -338,3 +346,15 @@ advanced_config: dim_hidden: 512 num_layers: 2 dropout: 0.25 + + # NOTE: Only the `barspoon` model supports multi-target classification + # (i.e. `ground_truth_label` can be a list of column names). Other + # models expect a single target column. + barspoon: # Encoder-Decoder Transformer for multi-target classification + d_model: 512 + num_encoder_heads: 8 + num_decoder_heads: 8 + num_encoder_layers: 2 + num_decoder_layers: 2 + dim_feedforward: 2048 + positional_encoding: true \ No newline at end of file diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..02d46594 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -73,7 +73,7 @@ def init_slide_encoder_( selected_encoder = encoder case _ as unreachable: - assert_never(unreachable) # type: ignore + assert_never(unreachable) selected_encoder.encode_slides_( output_dir=output_dir, diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 5827e884..4720ef9b 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from pathlib import Path from tempfile import NamedTemporaryFile +from typing import cast import h5py import numpy as np @@ -12,11 +13,11 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" @@ -183,7 +184,8 @@ def _read_h5( elif not h5_path.endswith(".h5"): raise ValueError(f"File is not of type .h5: {os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: - feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore + feats_ds = cast(h5py.Dataset, f["feats"]) + feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2ad4b91b..924ceebb 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -10,11 +10,11 @@ from numpy import ndarray from tqdm import tqdm -from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index de49c369..45092f4f 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -9,13 +9,13 @@ from torch import Tensor from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.encoding.encoder.chief import CHIEF from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 9cb3f6f5..4c0a2f6b 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -9,12 +9,12 @@ from gigapath import slide_encoder from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/madeleine.py b/src/stamp/encoding/encoder/madeleine.py index 5798a592..a0c74dcd 100644 --- a/src/stamp/encoding/encoder/madeleine.py +++ b/src/stamp/encoding/encoder/madeleine.py @@ -3,10 +3,10 @@ import torch from numpy import ndarray -from stamp.cache import STAMP_CACHE_DIR from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import STAMP_CACHE_DIR try: from madeleine.models.factory import create_model_from_pretrained diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 2dba6021..920b3db8 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -10,12 +10,12 @@ from tqdm import tqdm from transformers import AutoModel -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 446d85d6..fb704fe6 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -5,7 +5,7 @@ import logging from collections.abc import Collection, Iterable from pathlib import Path -from typing import cast, no_type_check +from typing import cast import h5py import matplotlib.pyplot as plt @@ -19,7 +19,7 @@ from packaging.version import Version from PIL import Image from torch import Tensor -from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage] +from torch.func import jacrev from stamp.modeling.data import get_coords, get_stride from stamp.modeling.deploy import load_model_from_ckpt @@ -29,6 +29,8 @@ _logger = logging.getLogger("stamp") +_SlideLike = openslide.OpenSlide | openslide.ImageSlide + def _gradcam_per_category( model: torch.nn.Module, @@ -37,23 +39,19 @@ def _gradcam_per_category( ) -> Float[Tensor, "tile category"]: feat_dim = -1 - cam = ( - ( - feats - * jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze(0) - )(feats) - ) - .mean(feat_dim) # type: ignore - .abs() + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze(0) + )(feats), ) + cam = (feats * jac).mean(feat_dim).abs() cam = torch.softmax(cam, dim=-1) - return cam.permute(-1, -2) @@ -79,12 +77,19 @@ def _attention_rollout_single( # --- 2. Rollout computation --- attn_rollout: torch.Tensor | None = None - for layer in model.transformer.layers: # type: ignore - attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights + transformer = getattr(model, "transformer", None) + if transformer is None: + raise RuntimeError("Model does not have a transformer attribute") + for layer in transformer.layers: + attn = getattr(layer, "attn_weights", None) + if attn is None: + first_child = next(iter(layer.children()), None) + if first_child is not None: + attn = getattr(first_child, "attn_weights", None) if attn is None: raise RuntimeError( "SelfAttention.attn_weights not found. " - "Make sure SelfAttention stores them." + "Make sure SelfAttention stores them on the layer or its first child." ) # attn: [heads, seq, seq] @@ -117,15 +122,18 @@ def _gradcam_single( """ feat_dim = -1 - jac = jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze() - )(feats) + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze() + )(feats), + ) - cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile] + cam = (feats * jac).mean(feat_dim).abs() # [tile] return cam @@ -148,17 +156,21 @@ def _vals_to_im( def _show_thumb( - slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None + slide: _SlideLike, + thumb_ax: Axes, + attention: Tensor, + default_slide_mpp: SlideMPP | None, ) -> np.ndarray: mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = slide.get_thumbnail(thumb_size) thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]) return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8] def _get_thumb_array( - slide, + slide: _SlideLike, attention: torch.Tensor, default_slide_mpp: SlideMPP | None, ) -> np.ndarray: @@ -168,12 +180,12 @@ def _get_thumb_array( """ mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = np.array(slide.get_thumbnail(thumb_size)) thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8] return thumb_crop -@no_type_check # beartype<=0.19.0 breaks here for some reason def _show_class_map( class_ax: Axes, top_score_indices: Integer[Tensor, "width height"], @@ -298,13 +310,8 @@ def heatmaps_( raise ValueError( f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only." ) - feats = ( - torch.tensor( - h5["feats"][:] # pyright: ignore[reportIndexIssue] - ) - .float() - .to(device) - ) + feats_np = np.asarray(h5["feats"]) + feats = torch.from_numpy(feats_np).float().to(device) coords_info = get_coords(h5) coords_um = torch.from_numpy(coords_info.coords_um).float() stride_um = Microns(get_stride(coords_um)) @@ -322,9 +329,10 @@ def heatmaps_( model = load_model_from_ckpt(checkpoint_path).eval() # TODO: Update version when a newer model logic breaks heatmaps. - if Version(model.stamp_version) < Version("2.4.0"): + stamp_version = str(getattr(model, "stamp_version", "")) + if Version(stamp_version) < Version("2.4.0"): raise ValueError( - f"model has been built with stamp version {model.stamp_version} " + f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." ) @@ -356,7 +364,7 @@ def heatmaps_( with torch.no_grad(): scores = torch.softmax( - model.model.forward( + model.model( feats.unsqueeze(-2), coords=coords_um.unsqueeze(-2), mask=torch.zeros( diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 21ce69db..5b9a6bcc 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -21,7 +21,7 @@ class TrainConfig(BaseModel): ) feature_dir: Path = Field(description="Directory containing feature files") - ground_truth_label: PandasLabel | None = Field( + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = Field( default=None, description="Name of categorical column in clinical table to train on", ) @@ -64,7 +64,7 @@ class DeploymentConfig(BaseModel): slide_table: Path feature_dir: Path - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" @@ -99,8 +99,29 @@ class TransMILModelParams(BaseModel): dim_hidden: int = 512 +class BarspoonParams(BaseModel): + model_config = ConfigDict(extra="forbid") + d_model: int = 512 + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 + + class LinearModelParams(BaseModel): model_config = ConfigDict(extra="forbid") + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 class ModelParams(BaseModel): @@ -109,6 +130,7 @@ class ModelParams(BaseModel): trans_mil: TransMILModelParams = Field(default_factory=TransMILModelParams) mlp: MlpModelParams = Field(default_factory=MlpModelParams) linear: LinearModelParams = Field(default_factory=LinearModelParams) + barspoon: BarspoonParams = Field(default_factory=BarspoonParams) class AdvancedConfig(BaseModel): diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 0ff037cf..8ddfb03d 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,8 +1,10 @@ import logging +from collections import Counter from collections.abc import Mapping, Sequence -from typing import Any, Final +from typing import Any, cast import numpy as np +import torch from pydantic import BaseModel from sklearn.model_selection import KFold, StratifiedKFold @@ -10,13 +12,8 @@ from stamp.modeling.data import ( PatientData, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, + load_patient_data_, log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, ) from stamp.modeling.deploy import ( _predict, @@ -28,7 +25,6 @@ from stamp.modeling.train import setup_model_for_training, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - FeaturePath, GroundTruth, PatientId, ) @@ -53,67 +49,31 @@ def categorical_crossval_( config: CrossvalConfig, advanced: AdvancedConfig, ) -> None: - feature_type = detect_feature_type(config.feature_dir) + if config.task is None: + raise ValueError( + "task must be set to 'classification' | 'regression' | 'survival'" + ) + + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) _logger.info(f"Detected feature type: {feature_type}") - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label are is required for survival modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for classification or regression modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - ) - patient_to_data: Mapping[PatientId, PatientData] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - ) - elif feature_type == "patient": - patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = { - pid: pd.ground_truth for pid, pd in patient_to_data.items() - } - else: - raise RuntimeError(f"Unsupported feature type: {feature_type}") + patient_to_ground_truth = { + pid: pd.ground_truth for pid, pd in patient_to_data.items() + } + + if feature_type not in ("tile", "slide", "patient"): + raise ValueError(f"Unknown feature type: {feature_type}") config.output_dir.mkdir(parents=True, exist_ok=True) splits_file = config.output_dir / "splits.json" @@ -158,18 +118,51 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) + categories_for_export: ( + dict[str, list] | list + ) = [] # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] + if config.task == "classification": - categories = config.categories or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } - ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) + # Determine categories for training (single-target) and for export (supports multi-target) + if isinstance(config.ground_truth_label, str): + categories = config.categories or sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + } + ) + log_patient_class_summary( + patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, + categories=categories, + ) + categories_for_export = cast(list, categories) + else: + # Multi-target: build a mapping from target label -> sorted list of categories + categories_accum: dict[str, set[GroundTruth]] = {} + for patient_data in patient_to_data.values(): + gt = patient_data.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + # Log summary per target + for t, cats in categories_for_export.items(): + ground_truths = [ + pd.ground_truth.get(t) + for pd in patient_to_data.values() + if isinstance(pd.ground_truth, dict) + and pd.ground_truth.get(t) is not None + ] + counter = Counter(ground_truths) + _logger.info( + f"{t} | Total patients: {len(ground_truths)} | " + + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) + ) + # For training, categories can remain None (inferred later) + categories = config.categories or None else: categories = [] @@ -206,12 +199,18 @@ def categorical_crossval_( }, categories=( categories - or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } + if categories is not None + else ( + sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) + } + ) + if not isinstance(config.ground_truth_label, Sequence) + else None ) ), train_transform=( @@ -263,30 +262,48 @@ def categorical_crossval_( ) if config.task == "survival": - _to_survival_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - cut_off=getattr(model.hparams, "train_pred_median", None), - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_survival_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + cut_off=getattr(model.hparams, "train_pred_median", None), + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) elif config.task == "regression": if config.ground_truth_label is None: raise RuntimeError("Grounf truth label is required for regression") - _to_regression_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_regression_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target regression prediction export not yet supported; skipping CSV save" + ) else: if config.ground_truth_label is None: raise RuntimeError( "Grounf truth label is required for classification" ) _to_prediction_df( - categories=categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, + predictions=cast( + Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], + predictions, + ), patient_label=config.patient_label, ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) @@ -297,24 +314,23 @@ def _get_splits( ) -> _Splits: patients = np.array(list(patient_to_data.keys())) - # Detect survival GT: "time status" - tokens = [str(p.ground_truth).split() for p in patient_to_data.values()] - - if all(len(t) == 2 for t in tokens): - y = np.array([int(t[1]) for t in tokens], dtype=int) - skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0) - iterator = skf.split(patients, y) - else: - skf = KFold(n_splits=n_splits, shuffle=True, random_state=0) - iterator = skf.split(patients) + # Extract ground truth for stratification. + # For multi-target (dict), use the first target's value + y_strat = np.array( + [ + next(iter(gt.values())) if isinstance(gt, dict) else gt + for gt in [patient.ground_truth for patient in patient_to_data.values()] + ] + ) + skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) splits = _Splits( splits=[ _Split( - train_patients=set(patients[train_idx]), - test_patients=set(patients[test_idx]), + train_patients=set(patients[train_indices]), + test_patients=set(patients[test_indices]), ) - for train_idx, test_idx in iterator + for train_indices, test_indices in skf.split(patients, y_strat) ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 21b86176..7b2a2d91 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -3,21 +3,34 @@ import logging from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass +from io import BytesIO # accept in _BinaryIOLike at runtime from itertools import groupby from pathlib import Path -from typing import IO, BinaryIO, Counter, Generic, TextIO, TypeAlias, Union, cast +from typing import ( + IO, + Any, + BinaryIO, + Final, + Generic, + List, + TextIO, + TypeAlias, + Union, + cast, +) import h5py import numpy as np import pandas as pd import torch -from jaxtyping import Float + +# Use beartype's typing for PEP-585 deprecation-safe hints +from beartype.typing import Dict from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset import stamp -from stamp.seed import Seed from stamp.types import ( Bags, BagSize, @@ -35,6 +48,7 @@ Task, TilePixels, ) +from stamp.utils.seed import Seed _logger = logging.getLogger("stamp") @@ -43,14 +57,19 @@ __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" -_Bag: TypeAlias = Float[Tensor, "tile feature"] -_EncodedTarget: TypeAlias = Float[Tensor, "category_is_hot"] | Float[Tensor, "1"] # noqa: F821 -_BinaryIOLike: TypeAlias = Union[BinaryIO, IO[bytes]] +_Bag: TypeAlias = Tensor +_EncodedTarget: TypeAlias = ( + Tensor | dict[str, Tensor] +) # Union of encoded targets or multi-target dict +_BinaryIOLike: TypeAlias = Union[ + BinaryIO, IO[bytes], BytesIO +] # includes io.BytesIO for runtime checks """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] +- multi-target: dict[target_name -> one-hot/regression value] """ -_Coordinates: TypeAlias = Float[Tensor, "tile 2"] +_Coordinates: TypeAlias = Tensor @dataclass @@ -59,12 +78,12 @@ class PatientData(Generic[GroundTruthType]): _ = KW_ONLY ground_truth: GroundTruthType - feature_files: Iterable[FeaturePath | BinaryIO] + feature_files: Iterable[FeaturePath | _BinaryIOLike] def tile_bag_dataloader( *, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None, task: Task, categories: Sequence[Category] | None = None, @@ -74,7 +93,7 @@ def tile_bag_dataloader( transform: Callable[[Tensor], Tensor] | None, ) -> tuple[ DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], ]: """Creates a dataloader from patient data for tile-level (bagged) features. @@ -86,103 +105,139 @@ def tile_bag_dataloader( task='regression': returns float targets """ - if task == "classification": - raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) - categories = ( - categories if categories is not None else list(np.unique(raw_ground_truths)) - ) - # one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) - one_hot = torch.tensor( - raw_ground_truths.reshape(-1, 1) == categories, dtype=torch.float32 - ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=one_hot, - transform=transform, - ) - cats_out: Sequence[Category] = list(categories) - elif task == "regression": - raw_targets = np.array( - [ - np.nan if p.ground_truth is None else float(p.ground_truth) - for p in patient_data - ], - dtype=np.float32, - ) - y = torch.from_numpy(raw_targets).reshape(-1, 1) + targets, cats_out = _parse_targets( + patient_data=patient_data, + task=task, + categories=categories, + ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out = [] + is_multitarget = isinstance(targets[0], dict) - elif task == "survival": # Not yet support logistic-harzard - times: list[float] = [] - events: list[float] = [] + collate_fn = _collate_multitarget if is_multitarget else _collate_to_tuple - for p in patient_data: - if p.ground_truth is None: - times.append(np.nan) - events.append(np.nan) - continue + ds = BagDataset( + bags=[patient.feature_files for patient in patient_data], + bag_size=bag_size, + ground_truths=targets, + transform=transform, + ) + dl = DataLoader( + ds, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + ) - try: - time_str, status_str = p.ground_truth.split(" ", 1) + return ( + cast( + DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + dl, + ), + cats_out, + ) - # Handle missing values encoded as "nan" - if time_str.lower() == "nan": - times.append(np.nan) - else: - times.append(float(time_str)) - - if status_str.lower() == "nan": - events.append(np.nan) - elif status_str.lower() in {"dead", "event", "1", "Yes", "yes"}: - events.append(1.0) - elif status_str.lower() in {"alive", "censored", "0", "No", "no"}: - events.append(0.0) - else: - events.append(np.nan) # unknown status → mark missing - except Exception: +def _parse_targets( + *, + patient_data: Sequence, + task: Task, + categories: Sequence[Category] | None = None, + target_spec: dict[str, Any] | None = None, + target_label: str | None = None, +) -> tuple[ + Union[torch.Tensor, list[dict[str, torch.Tensor]]], + Sequence[Category] | Mapping[str, Sequence[Category]], +]: + """ + Parse raw GroundTruth (str) into model-ready tensors. + This is the ONLY place task semantics live. + """ + + gts = [p.ground_truth for p in patient_data] + + if task == "classification": + if any(isinstance(gt, dict) for gt in gts if gt is not None): + # infer target names from the first non-None dict + first_dict = next(gt for gt in gts if isinstance(gt, dict)) + target_names = list(first_dict.keys()) + + # infer categories per target (ignore None patients, ignore None values) + categories_out: dict[str, list[str]] = {t: [] for t in target_names} + for gt in gts: + if not isinstance(gt, dict): + continue + for t in target_names: + v = gt.get(t) + if v is not None: + categories_out[t].append(v) + + # make unique + sorted + categories_out = { + t: sorted(set(vals)) for t, vals in categories_out.items() + } + + # encode per patient; if gt missing -> all zeros + encoded: list[dict[str, Tensor]] = [] + for gt in gts: + patient_encoded: dict[str, Tensor] = {} + for t in target_names: + cats = categories_out[t] + if not isinstance(gt, dict) or gt.get(t) is None: + one_hot = torch.zeros(len(cats), dtype=torch.float32) + else: + one_hot = torch.tensor( + [gt[t] == c for c in cats], + dtype=torch.float32, + ) + patient_encoded[t] = one_hot + encoded.append(patient_encoded) + + # IMPORTANT: return categories as mapping, not list-of-target-names + return encoded, categories_out + + # single target + unique = {gt for gt in gts if gt is not None} + if len(unique) >= 2 or categories is not None: + raw = np.array([p.ground_truth for p in patient_data]) + categories = categories or list(sorted(unique)) + labels = torch.tensor( + raw.reshape(-1, 1) == categories, + dtype=torch.float32, + ) + return labels, categories + + raise ValueError( + "Only one unique class found in classification task. " + "This is usually a data or configuration error." + ) + + elif task == "regression": + y = torch.tensor( + [np.nan if gt is None else float(gt) for gt in gts], + dtype=torch.float32, + ).reshape(-1, 1) + return y, [] + + elif task == "survival": + times, events = [], [] + for gt in gts: + if gt is None: times.append(np.nan) events.append(np.nan) + continue - # Final tensor shape: (N, 2) - y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + time_str, status_str = gt.split(" ", 1) + times.append(np.nan if time_str.lower() == "nan" else float(time_str)) + events.append(_parse_survival_status(status_str)) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out: Sequence[Category] = [] # survival has no categories + y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + return y, [] else: - raise ValueError(f"Unknown task: {task}") - - return ( - cast( - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader( - ds, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - collate_fn=_collate_to_tuple, - worker_init_fn=Seed.get_loader_worker_init() - if Seed._is_set() - else None, - ), - ), - cats_out, - ) + raise ValueError(f"Unsupported task: {task}") def _collate_to_tuple( @@ -210,6 +265,24 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def _collate_multitarget( + items: list[tuple[_Bag, _Coordinates, BagSize, Dict[str, Tensor]]], +) -> tuple[Bags, CoordinatesBatch, BagSizes, Dict[str, Tensor]]: + bags = torch.stack([b for b, _, _, _ in items]) + coords = torch.stack([c for _, c, _, _ in items]) + bag_sizes = torch.tensor([s for _, _, s, _ in items]) + + acc: Dict[str, List[Tensor]] = {} + + for _, _, _, tdict in items: + for k, v in tdict.items(): + acc.setdefault(k, []).append(v) + + targets: Dict[str, Tensor] = {k: torch.stack(v, dim=0) for k, v in acc.items()} + + return bags, coords, bag_sizes, targets + + def patient_feature_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], @@ -237,21 +310,31 @@ def create_dataloader( *, feature_type: str, task: Task, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None = None, batch_size: int, shuffle: bool, num_workers: int, transform: Callable[[Tensor], Tensor] | None, - categories: Sequence[Category] | None = None, -) -> tuple[DataLoader, Sequence[Category]]: + categories: Sequence[Category] | Mapping[str, Sequence[Category]] | None = None, +) -> tuple[DataLoader, Sequence[Category] | Mapping[str, Sequence[Category]]]: """Unified dataloader for all feature types and tasks.""" if feature_type == "tile": + # For multi-target classification, categories may be a mapping from + # target name to per-target categories. _parse_targets (used inside + # tile_bag_dataloader) only consumes explicit categories for the + # single-target case, so we pass a sequence or None here. + cats_arg: Sequence[Category] | None + if isinstance(categories, Mapping): + cats_arg = None + else: + cats_arg = categories + return tile_bag_dataloader( patient_data=patient_data, bag_size=bag_size, task=task, - categories=categories, + categories=cats_arg, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, @@ -263,21 +346,32 @@ def create_dataloader( if task == "classification": raw = np.array([p.ground_truth for p in patient_data]) - categories = categories or list(np.unique(raw)) - labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) - elif task == "regression": + categories_out = categories or list(np.unique(raw)) labels = torch.tensor( - [ - float(gt) - for gt in (p.ground_truth for p in patient_data) - if gt is not None - ], - dtype=torch.float32, - ).reshape(-1, 1) + raw.reshape(-1, 1) == categories_out, dtype=torch.float32 + ) + elif task == "regression": + values: list[float] = [] + for gt in (p.ground_truth for p in patient_data): + if gt is None: + continue + if isinstance(gt, dict): + # Use first value for multi-target regression + first_val = next(iter(gt.values())) + values.append(float(first_val)) + else: + values.append(float(gt)) + + labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: - t, e = (p.ground_truth or "nan nan").split(" ", 1) + if isinstance(p.ground_truth, dict): + # Multi-target survival: use first target + val = list(p.ground_truth.values())[0] + t, e = (val or "nan nan").split(" ", 1) + else: + t, e = (p.ground_truth or "nan nan").split(" ", 1) times.append(float(t) if t.lower() != "nan" else np.nan) events.append(_parse_survival_status(e)) @@ -340,7 +434,7 @@ def load_patient_level_data( clini_table: Path, feature_dir: Path, patient_label: PandasLabel, - ground_truth_label: PandasLabel | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, time_label: PandasLabel | None = None, status_label: PandasLabel | None = None, feature_ext: str = ".h5", @@ -419,7 +513,7 @@ class BagDataset(Dataset[tuple[_Bag, _Coordinates, BagSize, _EncodedTarget]]): If `bag_size` is None, all the samples will be used. """ - ground_truths: Float[Tensor, "index category_is_hot"] | Float[Tensor, "index 1"] + ground_truths: Tensor | list[dict[str, Tensor]] # ground_truths: Bool[Tensor, "index category_is_hot"] # """The ground truth for each bag, one-hot encoded.""" @@ -444,9 +538,19 @@ def __getitem__( for bag_file in self.bags[index]: with h5py.File(bag_file, "r") as h5: if "feats" in h5: - arr = h5["feats"][:] # pyright: ignore[reportIndexIssue] # original STAMP files + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[:] # original STAMP files else: - arr = h5["patch_embeddings"][:] # type: ignore # your Kronos files + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[:] # your Kronos files feats.append(torch.from_numpy(arr)) coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) @@ -480,7 +584,7 @@ class PatientFeatureDataset(Dataset): def __init__( self, - feature_files: Sequence[FeaturePath | BinaryIO], + feature_files: Sequence[FeaturePath | _BinaryIOLike], ground_truths: Tensor, # shape: [num_samples, num_classes] transform: Callable[[Tensor], Tensor] | None, ): @@ -496,7 +600,12 @@ def __len__(self): def __getitem__(self, idx: int): feature_file = self.feature_files[idx] with h5py.File(feature_file, "r") as h5: - feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[:]) # Accept [V] or [1, V] if feats.ndim == 2 and feats.shape[0] == 1: feats = feats[0] @@ -529,7 +638,7 @@ def mpp(self) -> SlideMPP: def get_coords(feature_h5: h5py.File) -> CoordsInfo: - # --- NEW: handle missing coords ----multiplex data bypass: no coords found; generated fake coords + # NEW: handle missing coords - multiplex data bypass: no coords found; generated fake coords if "coords" not in feature_h5: feats_obj = feature_h5["patch_embeddings"] @@ -545,10 +654,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: tile_size_px = TilePixels(0) return CoordsInfo(coords_um, tile_size_um, tile_size_px) - coords: np.ndarray = feature_h5["coords"][:] # type: ignore - coords_um: np.ndarray | None = None + coords_obj = feature_h5["coords"] + if not isinstance(coords_obj, h5py.Dataset): + raise RuntimeError( + f"{feature_h5.filename}: expected 'coords' to be an HDF5 dataset but got {type(coords_obj)}" + ) + coords: np.ndarray = coords_obj[:] tile_size_um: Microns | None = None tile_size_px: TilePixels | None = None + coords_um: np.ndarray | None = None if (tile_size := feature_h5.attrs.get("tile_size", None)) and feature_h5.attrs.get( "unit", None ) == "um": @@ -583,7 +697,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: ) if not tile_size_px and "tile_size_px" in feature_h5.attrs: - tile_size_px = TilePixels(int(feature_h5.attrs["tile_size_px"])) # pyright: ignore[reportArgumentType] + tile_size_px_attr = feature_h5.attrs.get("tile_size_px") + if tile_size_px_attr is not None and isinstance( + tile_size_px_attr, (int, float) + ): + tile_size_px = TilePixels(int(tile_size_px_attr)) + else: + raise RuntimeError( + "Invalid or missing 'tile_size_px' attribute in the feature file." + ) if not tile_size_um or coords_um is None: raise RuntimeError( @@ -627,33 +749,71 @@ def patient_to_ground_truth_from_clini_table_( *, clini_table_path: Path | TextIO, patient_label: PandasLabel, - ground_truth_label: PandasLabel, -) -> dict[PatientId, GroundTruth]: - """Loads the patients and their ground truths from a clini table.""" + ground_truth_label: PandasLabel | Sequence[PandasLabel], +) -> ( + dict[PatientId, GroundTruth | None] | dict[PatientId, dict[str, GroundTruth | None]] +): + """Loads the patients and their ground truths from a clini table. + + `ground_truth_label` may be either a single column name (str) or a sequence + of column names. In the latter case the returned mapping will contain a + dict mapping column -> value for each patient (supporting multi-target + setups). + """ + # Normalize to list for uniform handling + if isinstance(ground_truth_label, str): + cols = [patient_label, ground_truth_label] + multi = False + target_cols_inner: list[PandasLabel] = [] + else: + cols = [patient_label, *list(ground_truth_label)] + multi = True + target_cols_inner = [c for c in cols if c != patient_label] + clini_df = read_table( clini_table_path, - usecols=[patient_label, ground_truth_label], + usecols=cols, dtype=str, - ).dropna() + ) + + # If multi-target, keep rows where at least one target is present; for + # single target behave like before and drop rows missing the value. + if multi: + clini_df = clini_df.dropna(subset=target_cols_inner, how="all") + else: + clini_df = clini_df.dropna(subset=[ground_truth_label]) + try: - patient_to_ground_truth: Mapping[PatientId, GroundTruth] = clini_df.set_index( - patient_label, verify_integrity=True - )[ground_truth_label].to_dict() + if multi: + # Build mapping patient -> {col: value} + result: dict[PatientId, dict[str, GroundTruth | None]] = {} + for _, row in clini_df.iterrows(): + pid = row[patient_label] + # Convert pandas nan to None and keep strings otherwise + result[pid] = { + col: (None if pd.isna(row[col]) else str(row[col])) + for col in target_cols_inner + } + return result + else: + patient_to_ground_truth: Mapping[PatientId, str] = cast( + Mapping[PatientId, str], + clini_df.set_index(patient_label, verify_integrity=True)[ + cast(PandasLabel, ground_truth_label) + ].to_dict(), + ) + return cast(dict[PatientId, GroundTruth | None], patient_to_ground_truth) except KeyError as e: if patient_label not in clini_df: raise ValueError( f"{patient_label} was not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - elif ground_truth_label not in clini_df: + else: raise ValueError( - f"{ground_truth_label} was not found in clini table " + f"One or more ground truth columns were not found in clini table " f"(columns in clini table: {clini_df.columns})" ) from e - else: - raise e from e - - return patient_to_ground_truth def patient_to_survival_from_clini_table_( @@ -667,7 +827,6 @@ def patient_to_survival_from_clini_table_( Loads patients and their survival ground truths (time + event) from a clini table. Returns - ------- dict[PatientId, GroundTruth] Mapping patient_id -> "time status" (e.g. "302 dead", "476 alive"). """ @@ -780,7 +939,9 @@ def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: def filter_complete_patient_data_( *, - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth: Mapping[ + PatientId, GroundTruth | dict[str, GroundTruth] | None + ], slide_to_patient: Mapping[FeaturePath, PatientId], drop_patients_with_missing_ground_truth: bool, ) -> Mapping[PatientId, PatientData]: @@ -865,7 +1026,7 @@ def _log_patient_slide_feature_inconsistencies( ) -def get_stride(coords: Float[Tensor, "tile 2"]) -> float: +def get_stride(coords: Tensor) -> float: """Gets the minimum step width between any two coordintes.""" xs: Tensor = coords[:, 0].unique(sorted=True) ys: Tensor = coords[:, 1].unique(sorted=True) @@ -919,26 +1080,130 @@ def _parse_survival_status(value) -> int | None: ) +def load_patient_data_( + *, + feature_dir: Path, + clini_table: Path, + slide_table: Path | None, + task: Task, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + time_label: PandasLabel | None, + status_label: PandasLabel | None, + patient_label: PandasLabel, + filename_label: PandasLabel, + drop_patients_with_missing_ground_truth: bool = True, +) -> tuple[Mapping[PatientId, PatientData], str]: + """Load patient data based on feature type (tile, slide, or patient). + + This consolidates the common data loading logic used across train, crossval, and deploy. + + Returns: + (patient_to_data, feature_type) + """ + feature_type = detect_feature_type(feature_dir) + + if feature_type in ("tile", "slide"): + if slide_table is None: + raise ValueError("A slide table is required for tile/slide-level features") + + # Load ground truth based on task + if task == "survival": + if time_label is None or status_label is None: + raise ValueError( + "Both time_label and status_label are required for survival modeling" + ) + patient_to_ground_truth = patient_to_survival_from_clini_table_( + clini_table_path=clini_table, + time_label=time_label, + status_label=status_label, + patient_label=patient_label, + ) + else: + if ground_truth_label is None: + raise ValueError( + "Ground truth label is required for classification or regression modeling" + ) + patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( + clini_table_path=clini_table, + ground_truth_label=ground_truth_label, + patient_label=patient_label, + ) + + # Link slides to patients + slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( + slide_to_patient_from_slide_table_( + slide_table_path=slide_table, + feature_dir=feature_dir, + patient_label=patient_label, + filename_label=filename_label, + ) + ) + + # Filter to complete patient data + patient_to_data = filter_complete_patient_data_( + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | dict[str, GroundTruth] | None], + patient_to_ground_truth, + ), + slide_to_patient=slide_to_patient, + drop_patients_with_missing_ground_truth=drop_patients_with_missing_ground_truth, + ) + elif feature_type == "patient": + patient_to_data = load_patient_level_data( + task=task, + clini_table=clini_table, + feature_dir=feature_dir, + patient_label=patient_label, + ground_truth_label=ground_truth_label, + time_label=time_label, + status_label=status_label, + ) + else: + raise RuntimeError(f"Unknown feature type: {feature_type}") + + return patient_to_data, feature_type + + def log_patient_class_summary( *, patient_to_data: Mapping[PatientId, PatientData], categories: Sequence[Category] | None, - prefix: str = "", ) -> None: + """ + Logs class distribution. + Supports both single-target and multi-target classification. + """ + ground_truths = [ - pd.ground_truth - for pd in patient_to_data.values() - if pd.ground_truth is not None + p.ground_truth for p in patient_to_data.values() if p.ground_truth is not None ] if not ground_truths: - _logger.warning(f"{prefix}No ground truths available to summarize.") + _logger.warning("No ground truths available for summary.") return - cats = categories or sorted(set(ground_truths)) - counter = Counter(ground_truths) + # Multi-target + if isinstance(ground_truths[0], dict): + # Collect per-target values + per_target: dict[str, list] = {} - _logger.info( - f"{prefix}Total patients: {len(ground_truths)} | " - + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) - ) + for gt in ground_truths: + for key, value in gt.items(): + per_target.setdefault(key, []).append(value) + + for target_name, values in per_target.items(): + counts = {} + for v in values: + counts[v] = counts.get(v, 0) + 1 + + _logger.info( + f"[Multi-target] Target '{target_name}' distribution: {counts}" + ) + + # Single-target + else: + counts = {} + for gt in ground_truths: + counts[gt] = counts.get(gt, 0) + 1 + + _logger.info(f"Class distribution: {counts}") diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 905c6005..d3b29ebd 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -7,7 +7,7 @@ import numpy as np import pandas as pd import torch -from jaxtyping import Float +import torch.nn.functional as F from lightning.pytorch.accelerators.accelerator import Accelerator from stamp.modeling.data import ( @@ -20,7 +20,7 @@ slide_to_patient_from_slide_table_, ) from stamp.modeling.registry import ModelName, load_model_class -from stamp.types import GroundTruth, PandasLabel, PatientId +from stamp.types import Category, GroundTruth, PandasLabel, PatientId __all__ = ["deploy_categorical_model_"] @@ -32,6 +32,13 @@ Logit: TypeAlias = float +# Prediction type aliases +PredictionSingle: TypeAlias = torch.Tensor +PredictionMulti: TypeAlias = dict[str, torch.Tensor] +PredictionsType: TypeAlias = Mapping[ + PatientId, Union[PredictionSingle, PredictionMulti] +] + def load_model_from_ckpt(path: Union[str, Path]): ckpt = torch.load(path, map_location="cpu", weights_only=False) @@ -50,7 +57,7 @@ def deploy_categorical_model_( clini_table: Path | None, slide_table: Path | None, feature_dir: Path, - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, patient_label: PandasLabel, @@ -126,7 +133,12 @@ def deploy_categorical_model_( # classification/regression: still use ground_truth_label if ( len( - ground_truth_labels := set(model.ground_truth_label for model in models) + ground_truth_labels := { + tuple(model.ground_truth_label) + if isinstance(model.ground_truth_label, list) + else (model.ground_truth_label,) + for model in models + } ) != 1 ): @@ -145,17 +157,21 @@ def deploy_categorical_model_( f"{ground_truth_label} vs {model_ground_truth_label}" ) - ground_truth_label = ground_truth_label or model_ground_truth_label + ground_truth_label = ground_truth_label or cast( + PandasLabel, model_ground_truth_label + ) output_dir.mkdir(exist_ok=True, parents=True) model_categories = None if task == "classification": # Ensure the categories were the same between all models - category_sets = {tuple(m.categories) for m in models} + category_sets = { + tuple(cast(Sequence[GroundTruth], m.categories)) for m in models + } if len(category_sets) != 1: raise RuntimeError(f"Categories differ between models: {category_sets}") - model_categories = list(models[0].categories) + model_categories = list(cast(Sequence[GroundTruth], models[0].categories)) # Data loading logic if feature_type in ("tile", "slide"): @@ -171,6 +187,15 @@ def deploy_categorical_model_( ) if clini_table is not None: if task == "survival": + if not hasattr(models[0], "time_label") or not isinstance( + models[0].time_label, str + ): + raise AttributeError("Model is missing valid 'time_label' (str).") + if not hasattr(models[0], "status_label") or not isinstance( + models[0].status_label, str + ): + raise AttributeError("Model is missing valid 'status_label' (str).") + patient_to_ground_truth = patient_to_survival_from_clini_table_( clini_table_path=clini_table, patient_label=patient_label, @@ -192,7 +217,10 @@ def deploy_categorical_model_( patient_id: None for patient_id in set(slide_to_patient.values()) } patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth, + ), slide_to_patient=slide_to_patient, drop_patients_with_missing_ground_truth=False, ) @@ -241,16 +269,39 @@ def deploy_categorical_model_( "regression": _to_regression_prediction_df, "survival": _to_survival_prediction_df, }[task] - all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 + all_predictions: list[PredictionsType] = [] + categories_for_export: ( + Sequence[Category] | Mapping[str, Sequence[Category]] | None + ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): predictions = _predict( model=model, - test_dl=test_dl, # pyright: ignore[reportPossiblyUnboundVariable] + test_dl=test_dl, patient_ids=patient_ids, accelerator=accelerator, ) all_predictions.append(predictions) + if isinstance(next(iter(predictions.values())), dict): + # Multi-target case: gather categories across all targets for export (use model categories if available, else infer from GT) + categories_accum: dict[str, set[GroundTruth]] = {} + + for pd_item in patient_to_data.values(): + gt = pd_item.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + + else: + # Single-target case: use categories from model if available, else infer from GT + if task == "classification": + categories_for_export = models[0].categories + else: + categories_for_export = [] + # cut-off values from survival ckpt cut_off = ( getattr(model.hparams, "train_pred_median", None) @@ -261,7 +312,7 @@ def deploy_categorical_model_( # Only save individual model files when deploying multiple models (ensemble) if len(models) > 1: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -270,7 +321,7 @@ def deploy_categorical_model_( ).to_csv(output_dir / f"patient-preds-{model_i}.csv", index=False) else: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -279,17 +330,29 @@ def deploy_categorical_model_( ).to_csv(output_dir / "patient-preds.csv", index=False) if task == "classification": - # TODO we probably also want to save the 95% confidence interval in addition to the mean + # compute mean prediction across models (supports single- and multi-target) + mean_preds: dict[PatientId, object] = {} + for pid in patient_ids: + model_preds = cast( + list[torch.Tensor], [preds[pid] for preds in all_predictions] + ) + firstp = model_preds[0] + if isinstance(firstp, dict): + # per-target averaging + mean_preds[pid] = { + t: torch.stack([p[t] for p in model_preds]).mean(dim=0) + for t in firstp.keys() + } + else: + mean_preds[pid] = torch.stack(model_preds).mean(dim=0) + + assert categories_for_export is not None, ( + "categories_for_export must be set before use" + ) df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions={ - # Mean prediction - patient_id: torch.stack( - [predictions[patient_id] for predictions in all_predictions] - ).mean(dim=0) - for patient_id in patient_ids - }, + predictions=mean_preds, patient_label=patient_label, ground_truth_label=ground_truth_label, ).to_csv(output_dir / "patient-preds_95_confidence_interval.csv", index=False) @@ -301,7 +364,7 @@ def _predict( test_dl: torch.utils.data.DataLoader, patient_ids: Sequence[PatientId], accelerator: str | Accelerator, -) -> Mapping[PatientId, Float[torch.Tensor, "..."]]: +) -> PredictionsType: model = model.eval() torch.set_float32_matmul_precision("medium") @@ -310,8 +373,11 @@ def _predict( getattr(model, "train_patients", []) ) | set(getattr(model, "valid_patients", [])) if overlap := patients_used_for_training & set(patient_ids): - raise ValueError( - f"some of the patients in the validation set were used during training: {overlap}" + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. Overlapping IDs: %s", + len(overlap), + sorted(overlap), ) trainer = lightning.Trainer( @@ -320,51 +386,190 @@ def _predict( logger=False, ) - raw_preds = torch.concat(cast(list[torch.Tensor], trainer.predict(model, test_dl))) + outs = trainer.predict(model, test_dl) + + if not outs: + return {} + + first = outs[0] + + # Multi-target case: each element of outs is a dict[target_label -> tensor] + if isinstance(first, dict): + per_target_lists: dict[str, list[torch.Tensor]] = {} + for out in outs: + if not isinstance(out, dict): + raise RuntimeError("Mixed prediction output types from model") + for k, v in out.items(): + per_target_lists.setdefault(k, []).append(v) + + per_target_tensors: dict[str, torch.Tensor] = { + k: torch.cat(vlist, dim=0) for k, vlist in per_target_lists.items() + } + + if getattr(model.hparams, "task", None) == "classification": + for k in list(per_target_tensors.keys()): + per_target_tensors[k] = torch.softmax(per_target_tensors[k], dim=1) + + # build per-patient dicts + num_preds = next(iter(per_target_tensors.values())).shape[0] + predictions: dict[PatientId, dict[str, torch.Tensor]] = {} + for i, pid in enumerate(patient_ids[:num_preds]): + predictions[pid] = { + k: per_target_tensors[k][i] for k in per_target_tensors.keys() + } + + return predictions + + # Single-target case: each element of outs is a tensor + outs_single = cast(list[torch.Tensor], outs) + + raw_preds = torch.cat(outs_single, dim=0) if getattr(model.hparams, "task", None) == "classification": - predictions = torch.softmax(raw_preds, dim=1) + raw_preds = torch.softmax(raw_preds, dim=1) elif getattr(model.hparams, "task", None) == "survival": - predictions = raw_preds.squeeze(-1) # (N,) risk scores - else: # regression - predictions = raw_preds + raw_preds = raw_preds.squeeze(-1) + + result: dict[PatientId, torch.Tensor] = { + pid: raw_preds[i] for i, pid in enumerate(patient_ids) + } - return dict(zip(patient_ids, predictions, strict=True)) + return result def _to_prediction_df( *, - categories: Sequence[GroundTruth], - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], - predictions: Mapping[PatientId, torch.Tensor], + categories: Sequence[GroundTruth] | Mapping[str, Sequence[GroundTruth]], + patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] + | Mapping[PatientId, dict[str, GroundTruth | None]], + predictions: Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], patient_label: PandasLabel, - ground_truth_label: PandasLabel, + ground_truth_label: PandasLabel | Sequence[PandasLabel], **kwargs, ) -> pd.DataFrame: - """Compiles deployment results into a DataFrame.""" - return pd.DataFrame( - [ - { - patient_label: patient_id, - ground_truth_label: patient_to_ground_truth.get(patient_id), - "pred": categories[int(prediction.argmax())], - **{ - f"{ground_truth_label}_{category}": prediction[i_cat].item() - for i_cat, category in enumerate(categories) - }, - "loss": ( - torch.nn.functional.cross_entropy( - prediction.reshape(1, -1), - torch.tensor(np.where(np.array(categories) == ground_truth)[0]), - ).item() - if (ground_truth := patient_to_ground_truth.get(patient_id)) - is not None - else None - ), - } - for patient_id, prediction in predictions.items() - ] - ).sort_values(by="loss") + """Compiles deployment results into a DataFrame. + + Supports single-target and multi-target classification. + - Single-target: `predictions` maps patient -> tensor and `categories` is a sequence. + - Multi-target: `predictions` maps patient -> dict[target_label -> tensor] and + `categories` is a mapping from target_label -> sequence of category names. + """ + first_pred = next(iter(predictions.values())) + + # Multi-target predictions: dict per patient + if isinstance(first_pred, dict): + # determine target labels + target_labels = list(cast(dict, first_pred).keys()) + + # prepare categories mapping + if isinstance(categories, dict): + cats_map = categories + else: + # try infer categories list ordering: assume categories is a sequence-of-sequences + cats_map = {} + if isinstance(categories, Sequence): + try: + for i, t in enumerate(target_labels): + cats_map[t] = list( + cast(Sequence[Sequence[GroundTruth]], categories)[i] + ) + except Exception: + cats_map = {} + + # infer missing category lists from ground truth + if any(t not in cats_map for t in target_labels): + inferred: dict[str, set] = {t: set() for t in target_labels} + for pid, gt in patient_to_ground_truth.items(): + if isinstance(gt, dict): + for t in target_labels: + val = gt.get(t) + if val is not None: + inferred[t].add(val) + for t in target_labels: + if t not in cats_map: + cats_map[t] = sorted(inferred.get(t, [])) + + rows = [] + for pid, pred_dict in predictions.items(): + row: dict = {patient_label: pid} + gt_entry = patient_to_ground_truth.get(pid) + # ground truths per target + for t in target_labels: + if isinstance(gt_entry, dict): + row[t] = gt_entry.get(t) + else: + row[t] = gt_entry + + total_loss = 0.0 + has_loss = False + for t in target_labels: + tensor = cast(dict[str, torch.Tensor], pred_dict)[t] + probs = tensor.detach().cpu() + cats: Sequence[GroundTruth] = cast( + Sequence[GroundTruth], + cats_map.get(t, []), + ) + if probs.numel() == 1: + row[f"pred_{t}"] = float(probs.item()) + else: + pred_idx = int(probs.argmax().item()) + row[f"pred_{t}"] = ( + cats[pred_idx] if pred_idx < len(cats) else pred_idx + ) + for i_cat, cat in enumerate(cats): + if i_cat < probs.shape[0]: + row[f"{t}_{cat}"] = float(probs[i_cat].item()) + else: + row[f"{t}_{cat}"] = None + + if isinstance(gt_entry, dict) and (gt := gt_entry.get(t)) is not None: + try: + target_index = int(np.where(np.array(cats) == gt)[0][0]) + loss = torch.nn.functional.cross_entropy( + probs.reshape(1, -1), torch.tensor([target_index]) + ).item() + total_loss += loss + has_loss = True + except Exception: + pass + + row["loss"] = total_loss if has_loss else None + rows.append(row) + + return pd.DataFrame(rows) + + # Single-target (original behaviour) + if not all(isinstance(p, torch.Tensor) for p in predictions.values()): + raise TypeError("Single-target block received multi-target dict predictions.") + + predictions = cast(Mapping[PatientId, torch.Tensor], predictions) + + rows = [] + for pid, prediction in predictions.items(): + gt = patient_to_ground_truth.get(pid) + cats = cast(Sequence[GroundTruth], categories) + pred_idx = int(prediction.argmax()) + row = { + patient_label: pid, + ground_truth_label: gt, + "pred": cats[pred_idx], + **{ + f"{ground_truth_label}_{category}": float(prediction[i_cat].item()) + for i_cat, category in enumerate(cats) + }, + "loss": ( + torch.nn.functional.cross_entropy( + prediction.reshape(1, -1), + torch.tensor(np.where(np.array(cats) == gt)[0]), + ).item() + if gt is not None + else None + ), + } + rows.append(row) + + return pd.DataFrame(rows).sort_values(by="loss") def _to_regression_prediction_df( @@ -383,8 +588,6 @@ def _to_regression_prediction_df( - pred (float) - loss (per-sample L1 loss if GT available, else None) """ - import torch.nn.functional as F - return pd.DataFrame( [ { diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 59a0a3aa..0b6a3885 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -6,14 +6,21 @@ from typing import Any, TypeAlias import lightning -import numpy as np import torch + +# Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype +from beartype.typing import Mapping from jaxtyping import Bool, Float from packaging.version import Version from torch import Tensor, nn, optim from torchmetrics.classification import MulticlassAUROC import stamp +from stamp.modeling.models.barspoon import ( + EncDecTransformer, + LitMilClassificationMixin, + TargetLabel, +) from stamp.modeling.models.cox import neg_partial_log_likelihood from stamp.types import ( Bags, @@ -143,6 +150,19 @@ def on_train_batch_end(self, outputs, batch, batch_idx): ) +class _TileLevelMixin: + """Mixin for tile-level models providing shared MIL masking logic.""" + + @staticmethod + def _mask_from_bags(bags: Bags, bag_sizes: BagSizes) -> Bool[Tensor, "batch tile"]: + """Create attention mask for padded tiles in variable-length bags.""" + max_possible_bag_size = bags.size(1) + mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( + 0 + ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) + return mask + + class LitBaseClassifier(Base): """ PyTorch Lightning wrapper for tile level and patient level clasification. @@ -194,15 +214,16 @@ def __init__( self.class_weights = category_weights self.valid_auroc = MulticlassAUROC(len(categories)) # Number classes - self.categories = np.array(categories) + self.categories = list(categories) self.hparams.update({"task": "classification"}) -class LitTileClassifier(LitBaseClassifier): +class LitTileClassifier(_TileLevelMixin, LitBaseClassifier): """ - PyTorch Lightning wrapper for the model used in weakly supervised - learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data. + PyTorch Lightning wrapper for tile-level MIL classification. + + Used in weakly supervised settings for whole-slide images or patch-based data. """ supported_features = ["tile"] @@ -244,7 +265,6 @@ def _step( ) if step_name == "validation": - # TODO this is a bit ugly, we'd like to have `_step` without special cases self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) self.log( f"{step_name}_auroc", @@ -286,31 +306,19 @@ def predict_step( # adding a mask here will *drastically* and *unbearably* increase memory usage return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideClassifier(LitBaseClassifier): - """ - PyTorch Lightning wrapper for MLPClassifier. - """ + """PyTorch Lightning wrapper for slide/patient-level classification.""" supported_features = ["slide"] def forward(self, x: Tensor) -> Tensor: return self.model(x) - def _step(self, batch, step_name: str): - feats, targets = batch + def _step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str + ) -> Loss: + feats, targets = list(batch) # Works for both tuple and list logits = self.model(feats.float()) loss = nn.functional.cross_entropy( logits, @@ -336,17 +344,25 @@ def _step(self, batch, step_name: str): ) return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats) @@ -397,9 +413,10 @@ def _compute_loss(y_true: Tensor, y_pred: Tensor) -> Loss: return nn.functional.l1_loss(y_true, y_pred) -class LitTileRegressor(LitBaseRegressor): +class LitTileRegressor(_TileLevelMixin, LitBaseRegressor): """ - PyTorch Lightning wrapper for weakly supervised / MIL regression at tile/patient level. + PyTorch Lightning wrapper for tile-level MIL regression. + Produces a single continuous output per bag (dim_output = 1). """ @@ -486,38 +503,22 @@ def predict_step( # keep memory usage low as in classifier return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideRegressor(LitBaseRegressor): - """ - PyTorch Lightning wrapper for slide-level or patient-level regression. - Produces a single continuous output per slide (dim_output = 1). - """ + """PyTorch Lightning wrapper for slide/patient-level regression.""" supported_features = ["slide"] def forward(self, feats: Tensor) -> Tensor: - """Forward pass for slide-level features.""" return self.model(feats.float()) def _step( self, *, - batch: tuple[Tensor, Tensor], + batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str, ) -> Loss: - feats, targets = batch + feats, targets = list(batch) # Works for both tuple and list preds = self.model(feats.float(), mask=None) # (B, 1) y = targets.to(preds).float() @@ -534,7 +535,6 @@ def _step( ) if step_name == "validation": - # same metrics as LitTileRegressor p = preds.squeeze(-1) t = y.squeeze(-1) self.log( @@ -547,17 +547,25 @@ def _step( return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats.float()) @@ -702,12 +710,12 @@ def on_train_epoch_end(self): self.hparams.update({"train_pred_median": self.train_pred_median}) -class LitTileSurvival(LitSurvivalBase): +class LitTileSurvival(_TileLevelMixin, LitSurvivalBase): """ - Tile-level or patch-level survival analysis. - Expects dataloader batches like: - (bags, coords, bag_sizes, targets) - where targets is shape (B,2): [:,0]=time, [:,1]=event (1=event, 0=censored). + Tile-level survival analysis with Cox proportional hazards loss. + + Expects batches: (bags, coords, bag_sizes, targets) + where targets.shape = (B, 2): [:,0]=time, [:,1]=event (0=censored, 1=event). """ supported_features = ["tile"] @@ -818,3 +826,75 @@ class LitPatientSurvival(LitSlideSurvival): """ supported_features = ["patient"] + + +class LitEncDecTransformer(LitMilClassificationMixin): + def __init__( + self, + *, + dim_input: int, + category_weights: Mapping[TargetLabel, torch.Tensor], + model_class: type[nn.Module] | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + categories: Mapping[str, Sequence[Category]], + # Model parameters + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + # Other hparams + learning_rate: float = 1e-4, + **hparams: Any, + ) -> None: + weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) + super().__init__( + weights=weights_dict, + learning_rate=learning_rate, + ) + _ = hparams # so we don't get unused parameter warnings + + self.model = EncDecTransformer( + d_features=dim_input, + target_n_outs={t: len(w) for t, w in category_weights.items()}, + d_model=d_model, + num_encoder_heads=num_encoder_heads, + num_decoder_heads=num_decoder_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + positional_encoding=positional_encoding, + ) + + self.hparams["supported_features"] = "tile" + self.hparams.update({"task": "classification"}) + # ---- Normalize categories into strict mapping[str, list[str]] ---- + if not isinstance(categories, Mapping): + raise ValueError( + "Multi-target classification requires categories as Mapping[str, Sequence[str]]." + ) + + normalized_categories: dict[str, list[str]] = { + str(k): list(v) for k, v in categories.items() + } + + # Sanity check: head size must match category size + for t, w in category_weights.items(): + if t not in normalized_categories: + raise ValueError(f"Missing categories for target '{t}'") + if len(normalized_categories[t]) != len(w): + raise ValueError( + f"Category mismatch for target '{t}': " + f"{len(normalized_categories[t])} categories " + f"but head has {len(w)} outputs." + ) + + self.ground_truth_label = ground_truth_label + self.categories = normalized_categories + + self.save_hyperparameters() + + def forward(self, *args): + return self.model(*args) diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py new file mode 100644 index 00000000..f841bb3d --- /dev/null +++ b/src/stamp/modeling/models/barspoon.py @@ -0,0 +1,367 @@ +""" +Port from https://github.com/KatherLab/barspoon-transformer +""" + +import re +from typing import Any, TypeAlias + +import lightning +import torch +import torch.nn.functional as F +import torchmetrics +from packaging.version import Version +from torch import nn +from torchmetrics.classification import MulticlassAUROC +from torchmetrics.utilities.data import dim_zero_cat + +import stamp +from stamp.types import Bags, BagSizes, CoordinatesBatch + +__all__ = [ + "EncDecTransformer", + "LitMilClassificationMixin", + "SafeMulticlassAUROC", +] + + +TargetLabel: TypeAlias = str + + +class EncDecTransformer(nn.Module): + """An encoder decoder architecture for multilabel classification tasks + + This architecture is a modified version of the one found in [Attention Is + All You Need][1]: First, we project the features into a lower-dimensional + feature space, to prevent the transformer architecture's complexity from + exploding for high-dimensional features. We add sinusodial [positional + encodings][1]. We then encode these projected input tokens using a + transformer encoder stack. Next, we decode these tokens using a set of + class tokens, one per output label. Finally, we forward each of the decoded + tokens through a fully connected layer to get a label-wise prediction. + + PE1 + | + +--+ v +---+ + t1 ->|FC|--+-->| |--+ + . +--+ | E | | + . | x | | + . +--+ | m | | + tn ->|FC|--+-->| |--+ + +--+ ^ +---+ | + | | + PEn v + +---+ +---+ + c1 ---------------->| |-->|FC1|--> s1 + . | D | +---+ . + . | x | . + . | l | +---+ . + ck ---------------->| |-->|FCk|--> sk + +---+ +---+ + + We opted for this architecture instead of a more traditional [Vision + Transformer][2] to improve performance for multi-label predictions with many + labels. Our experiments have shown that adding too many class tokens to a + vision transformer decreases its performance, as the same weights have to + both process the tiles' information and the class token's processing. Using + an encoder-decoder architecture alleviates these issues, as the data-flow of + the class tokens is completely independent of the encoding of the tiles. + Furthermore, analysis has shown that there is almost no interaction between + the different classes in the decoder. While this points to the decoder + being more powerful than needed in practice, this also means that each + label's prediction is mostly independent of the others. As a consequence, + noisy labels will not negatively impact the accuracy of non-noisy ones. + + In our experiments so far we did not see any improvement by adding + positional encodings. We tried + + 1. [Sinusodal encodings][1] + 2. Adding absolute positions to the feature vector, scaled down so the + maximum value in the training dataset is 1. + + Since neither reduced performance and the author percieves the first one to + be more elegant (as the magnitude of the positional encodings is bounded), + we opted to keep the positional encoding regardless in the hopes of it + improving performance on future tasks. + + The architecture _differs_ from the one descibed in [Attention Is All You + Need][1] as follows: + + 1. There is an initial projection stage to reduce the dimension of the + feature vectors and allow us to use the transformer with arbitrary + features. + 2. Instead of the language translation task described in [Attention Is All + You Need][1], where the tokens of the words translated so far are used + to predict the next word in the sequence, we use a set of fixed, learned + class tokens in conjunction with equally as many independent fully + connected layers to predict multiple labels at once. + + [1]: https://arxiv.org/abs/1706.03762 "Attention Is All You Need" + [2]: https://arxiv.org/abs/2010.11929 + "An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale" + """ + + def __init__( + self, + d_features: int, + target_n_outs: dict[str, int], + *, + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + ) -> None: + super().__init__() + + self.projector = nn.Sequential(nn.Linear(d_features, d_model), nn.ReLU()) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_encoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False + ) + + self.target_labels = target_n_outs.keys() + + # One class token per output label + self.class_tokens = nn.ParameterDict( + { + sanitize(target_label): torch.rand(d_model) + for target_label in target_n_outs + } + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=num_decoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, num_layers=num_decoder_layers + ) + + self.heads = nn.ModuleDict( + { + sanitize(target_label): nn.Linear( + in_features=d_model, out_features=n_out + ) + for target_label, n_out in target_n_outs.items() + } + ) + + self.positional_encoding = positional_encoding + + def forward( + self, + tile_tokens: torch.Tensor, + tile_positions: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, _, _ = tile_tokens.shape + + tile_tokens = self.projector(tile_tokens) # shape: [bs, seq_len, d_model] + + if self.positional_encoding: + # Add positional encodings + d_model = tile_tokens.size(-1) + x = tile_positions.unsqueeze(-1) / 100_000 ** ( + torch.arange(d_model // 4).type_as(tile_positions) / d_model + ) + positional_encodings = torch.cat( + [ + torch.sin(x).flatten(start_dim=-2), + torch.cos(x).flatten(start_dim=-2), + ], + dim=-1, + ) + tile_tokens = tile_tokens + positional_encodings + + tile_tokens = self.transformer_encoder(tile_tokens) + + class_tokens = torch.stack( + [self.class_tokens[sanitize(t)] for t in self.target_labels] + ).expand(batch_size, -1, -1) + class_tokens = self.transformer_decoder(tgt=class_tokens, memory=tile_tokens) + + # Apply the corresponding head to each class token + logits = { + target_label: self.heads[sanitize(target_label)](class_token) + for target_label, class_token in zip( + self.target_labels, + class_tokens.permute(1, 0, 2), # Permute to [target, batch, d_model] + strict=True, + ) + } + + return logits + + +class LitMilClassificationMixin(lightning.LightningModule): + """Makes a module into a multilabel, multiclass Lightning one""" + + supported_features = ["tile"] + + def __init__( + self, + *, + weights: dict[TargetLabel, torch.Tensor], + # Other hparams + learning_rate: float = 1e-4, + stamp_version: Version = Version(stamp.__version__), + **hparams: Any, + ) -> None: + super().__init__() + _ = hparams # So we don't get unused parameter warnings + + # Check if version is compatible. + if stamp_version < Version("2.4.0"): + # Update this as we change our model in incompatible ways! + raise ValueError( + f"model has been built with stamp version {stamp_version} " + f"which is incompatible with the current version." + ) + elif stamp_version > Version(stamp.__version__): + # Let's be strict with models "from the future", + # better fail deadly than have broken results. + raise ValueError( + "model has been built with a stamp version newer than the installed one " + f"({stamp_version} > {stamp.__version__}). " + "Please upgrade stamp to a compatible version." + ) + + self.hparams.update({"task": "classification"}) + + self.learning_rate = learning_rate + + target_aurocs = torchmetrics.MetricCollection( + { + sanitize(target_label): SafeMulticlassAUROC(num_classes=len(weight)) + for target_label, weight in weights.items() + } + ) + for step_name in ["train", "validation", "test"]: + setattr( + self, + f"{step_name}_target_aurocs", + target_aurocs.clone(prefix=f"{step_name}_"), + ) + + self.weights = weights + + self.save_hyperparameters() + + def step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]], + step_name=None, + ): + """Process a batch with structure (feats, coords, bag_sizes, targets). + + Args: + batch: Tuple of (feats, coords, bag_sizes, targets) where: + - feats: bag features [batch, bag_size, feature_dim] + - coords: tile coordinates [batch, bag_size, 2] + - bag_sizes: number of tiles per bag [batch] + - targets: dict mapping target names to one-hot encoded tensors [batch, num_classes] + step_name: Optional step name for logging ('train', 'validation', 'test'). + """ + feats: Bags + coords: CoordinatesBatch + bag_sizes: BagSizes + targets: dict[str, torch.Tensor] + feats, coords, bag_sizes, targets = batch + logits = self(feats, coords) + + # Calculate the cross entropy loss for each target, then sum them + loss = sum( + F.cross_entropy( + (logit := logits[target_label]), + targets[target_label].type_as(logit), + weight=weight.type_as(logit), + ) + for target_label, weight in self.weights.items() + ) + + if step_name: + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + # Update target-wise metrics + for target_label in self.weights: + target_auroc = getattr(self, f"{step_name}_target_aurocs")[ + sanitize(target_label) + ] + is_na = (targets[target_label] == 0).all(dim=1) + target_auroc.update( + logits[target_label][~is_na], + targets[target_label][~is_na].argmax(dim=1), + ) + self.log( + f"{step_name}_{target_label}_auroc", + target_auroc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="train") + + def validation_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="validation") + + def test_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="test") + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + if len(batch) == 2: + feats, positions = batch + else: + feats, positions, _, _ = batch + + logits = self(feats, positions) + + softmaxed = { + target_label: torch.softmax(x, 1) for target_label, x in logits.items() + } + return softmaxed + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def sanitize(x: str) -> str: + return re.sub(r"[^A-Za-z0-9_]", "_", x) + + +class SafeMulticlassAUROC(MulticlassAUROC): + """A Multiclass AUROC that doesn't blow up when no targets are given""" + + def compute(self) -> torch.Tensor: + # Add faux entry if there are none so far + if len(self.preds) == 0: + self.update(torch.zeros(1, self.num_classes), torch.zeros(1).long()) + elif len(dim_zero_cat(self.preds)) == 0: + self.update( + torch.zeros(1, self.num_classes).type_as(self.preds[0]), + torch.zeros(1).long().type_as(self.target[0]), + ) + return super().compute() diff --git a/src/stamp/modeling/models/mlp.py b/src/stamp/modeling/models/mlp.py index e4f8881f..e88a77ca 100644 --- a/src/stamp/modeling/models/mlp.py +++ b/src/stamp/modeling/models/mlp.py @@ -29,7 +29,7 @@ def __init__( layers.append(nn.Dropout(dropout)) in_dim = dim_hidden layers.append(nn.Linear(in_dim, dim_output)) - self.mlp = nn.Sequential(*layers) # type: ignore + self.mlp = nn.Sequential(*layers) @jaxtyped(typechecker=beartype) def forward( diff --git a/src/stamp/modeling/registry.py b/src/stamp/modeling/registry.py index 2205af22..14011084 100644 --- a/src/stamp/modeling/registry.py +++ b/src/stamp/modeling/registry.py @@ -1,6 +1,7 @@ from enum import StrEnum from stamp.modeling.models import ( + LitEncDecTransformer, LitPatientClassifier, LitPatientRegressor, LitPatientSurvival, @@ -21,6 +22,7 @@ class ModelName(StrEnum): MLP = "mlp" TRANS_MIL = "trans_mil" LINEAR = "linear" + BARSPOON = "barspoon" # Map (feature_type, task) → correct Lightning wrapper class @@ -34,6 +36,7 @@ class ModelName(StrEnum): ("patient", "classification"): LitPatientClassifier, ("patient", "regression"): LitPatientRegressor, ("patient", "survival"): LitPatientSurvival, + # ("tile", "multiclass"): LitEncDecTransformer, } @@ -54,6 +57,13 @@ def load_model_class(task: Task, feature_type: str, model_name: ModelName): case ModelName.MLP: from stamp.modeling.models.mlp import MLP as ModelClass + case ModelName.BARSPOON: + from stamp.modeling.models.barspoon import ( + EncDecTransformer as ModelClass, + ) + + LitModelClass = LitEncDecTransformer + case ModelName.LINEAR: from stamp.modeling.models.mlp import ( Linear as ModelClass, diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index b855e2a0..fb2bdb3b 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -2,7 +2,7 @@ import shutil from collections.abc import Callable, Mapping, Sequence from pathlib import Path -from typing import cast +from typing import Any, cast import lightning import torch @@ -18,13 +18,7 @@ PatientData, PatientFeatureDataset, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, - log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, + load_patient_data_, ) from stamp.modeling.registry import ModelName, load_model_class from stamp.modeling.transforms import VaryPrecisionTransform @@ -53,66 +47,25 @@ def train_categorical_model_( advanced: AdvancedConfig, ) -> None: """Trains a model based on the feature type.""" - feature_type = detect_feature_type(config.feature_dir) - _logger.info(f"Detected feature type: {feature_type}") - - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label is required for survival modeling" - ) - patient_to_ground_truth = patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for tile-level modeling" - ) - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - elif feature_type == "patient": - # Patient-level: ignore slide_table - if config.slide_table is not None: - _logger.warning("slide_table is ignored for patient-level features.") - - patient_to_data = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - else: - raise RuntimeError(f"Unknown feature type: {feature_type}") - if config.task is None: raise ValueError( "task must be set to 'classification' | 'regression' | 'survival'" ) + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) + _logger.info(f"Detected feature type: {feature_type}") + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, @@ -145,14 +98,14 @@ def train_categorical_model_( def setup_model_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, feature_type: str, advanced: AdvancedConfig, # Metadata, has no effect on model training - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, clini_table: Path, @@ -193,10 +146,6 @@ def setup_model_for_training( feature_type=feature_type, train_categories=train_categories, ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) # 1. Default to a model if none is specified if advanced.model_name is None: @@ -209,6 +158,7 @@ def setup_model_for_training( LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name ) + print(f"Using Lightning wrapper class: {LitModelClass}") # 3. Validate that the chosen model supports the feature type if feature_type not in LitModelClass.supported_features: @@ -272,7 +222,7 @@ def setup_model_for_training( def setup_dataloaders_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, bag_size: int, @@ -283,7 +233,7 @@ def setup_dataloaders_for_training( ) -> tuple[ DataLoader, DataLoader, - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], int, Sequence[PatientId], Sequence[PatientId], @@ -310,10 +260,25 @@ def setup_dataloaders_for_training( ) if task == "classification": - stratify = ground_truths + # Handle both single and multi-target cases + if ground_truths and isinstance(ground_truths[0], dict): + # Multi-target: use first target for stratification + first_key = list(ground_truths[0].keys())[0] + stratify = [cast(dict, gt)[first_key] for gt in ground_truths] + else: + stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - statuses = [int(gt.split()[1]) for gt in ground_truths] + # Extract event indicator (status) - handle both single and multi-target + statuses = [] + for gt in ground_truths: + if isinstance(gt, dict): + # Multi-target survival: extract from first target + first_key = list(gt.keys())[0] + val = cast(dict, gt)[first_key] + if val: + statuses.append(int(val.split()[1])) + else: + statuses.append(int(gt.split()[1])) stratify = statuses elif task == "regression": stratify = None @@ -321,7 +286,10 @@ def setup_dataloaders_for_training( train_patients, valid_patients = cast( tuple[Sequence[PatientId], Sequence[PatientId]], train_test_split( - list(patient_to_data), stratify=stratify, shuffle=True, random_state=0 + list(patient_to_data), + stratify=cast(Any, stratify), + shuffle=True, + random_state=0, ), ) @@ -441,28 +409,50 @@ def _compute_class_weights_and_check_categories( *, train_dl: DataLoader, feature_type: str, - train_categories: Sequence[str], -) -> torch.Tensor: + train_categories: Sequence[str] | Mapping[str, Sequence[str]], +) -> torch.Tensor | dict[str, torch.Tensor]: """ Computes class weights and checks for category issues. Logs warnings if there are too few or underpopulated categories. Returns normalized category weights as a torch.Tensor. """ if feature_type == "tile": - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + dataset = cast(BagDataset, train_dl.dataset) + + if isinstance(dataset.ground_truths, list): + # Multi-target case: compute weights per target head + weights_per_target: dict[str, torch.Tensor] = {} + + target_keys = dataset.ground_truths[0].keys() + + for key in target_keys: + stacked = torch.stack([gt[key] for gt in dataset.ground_truths], dim=0) + counts = stacked.sum(dim=0) + w = counts.sum() / counts + weights_per_target[key] = w / w.sum() + + return weights_per_target + else: + category_counts = dataset.ground_truths.sum(dim=0) else: - category_counts = cast( - PatientFeatureDataset, train_dl.dataset - ).ground_truths.sum(dim=0) + dataset = cast(PatientFeatureDataset, train_dl.dataset) + category_counts = dataset.ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() if len(train_categories) <= 1: raise ValueError(f"not enough categories to train on: {train_categories}") - elif any(category_counts < 16): + elif (category_counts < 16).any(): + category_counts_list = ( + category_counts.tolist() + if category_counts.dim() > 0 + else [category_counts.item()] + ) underpopulated_categories = { category: int(count) - for category, count in zip(train_categories, category_counts, strict=True) + for category, count in zip( + train_categories, category_counts_list, strict=True + ) if count < 16 } _logger.warning( diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index ab3ff0d2..22f1d90f 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -16,7 +16,6 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor from stamp.preprocessing.tiling import ( @@ -32,6 +31,7 @@ SlidePixels, TilePixels, ) +from stamp.utils.cache import get_processing_code_hash __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck" diff --git a/src/stamp/preprocessing/extractor/chief_ctranspath.py b/src/stamp/preprocessing/extractor/chief_ctranspath.py index 2d2e6b9b..03f5bba6 100644 --- a/src/stamp/preprocessing/extractor/chief_ctranspath.py +++ b/src/stamp/preprocessing/extractor/chief_ctranspath.py @@ -1,6 +1,6 @@ from pathlib import Path -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index ba9a277c..d189e279 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, TypeVar, cast -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown @@ -518,7 +518,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) + self.relative_position_index.view(-1) # pyright: ignore[reportCallIssue] ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], diff --git a/src/stamp/preprocessing/extractor/dinobloom.py b/src/stamp/preprocessing/extractor/dinobloom.py index fb7713ce..54b79c53 100644 --- a/src/stamp/preprocessing/extractor/dinobloom.py +++ b/src/stamp/preprocessing/extractor/dinobloom.py @@ -8,9 +8,9 @@ from torch import nn from torchvision import transforms -from stamp.cache import STAMP_CACHE_DIR from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor +from stamp.utils.cache import STAMP_CACHE_DIR __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck" diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index 82a3efba..ce684ba4 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,6 +461,9 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement + if collection is None: + _logger.error("Document element is None, unable to extract MPP.") + return None images = collection.getElementsByTagName("Image") pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index ec09e1e0..b3243ecc 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -1,3 +1,11 @@ +"""Statistics utilities (wrappers) for classification, regression and survival. + +This module provides a small, stable wrapper `compute_stats_` that dispatches +to the task-specific statistic implementations found in the submodules. +""" + +from __future__ import annotations + from collections.abc import Sequence from pathlib import Path from typing import NewType @@ -7,7 +15,10 @@ from matplotlib import pyplot as plt from pydantic import BaseModel, ConfigDict, Field -from stamp.statistics.categorical import categorical_aggregated_ +from stamp.statistics.categorical import ( + categorical_aggregated_, + categorical_aggregated_multitarget_, +) from stamp.statistics.prc import ( plot_multiple_decorated_precision_recall_curves, plot_single_decorated_precision_recall_curve, @@ -17,23 +28,25 @@ plot_multiple_decorated_roc_curves, plot_single_decorated_roc_curve, ) -from stamp.statistics.survival import ( - _plot_km, - _survival_stats_for_csv, -) +from stamp.statistics.survival import _plot_km, _survival_stats_for_csv from stamp.types import PandasLabel, Task +__all__ = ["StatsConfig", "compute_stats_"] + + __author__ = "Marko van Treeck, Minh Duc Nguyen" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" def _read_table(file: Path, **kwargs) -> pd.DataFrame: - """Loads a dataframe from a file.""" + """Load a dataframe from CSV or XLSX file path. + + This small helper centralizes file IO formatting and keeps callers simple. + """ if isinstance(file, Path) and file.suffix == ".xlsx": return pd.read_excel(file, **kwargs) - else: - return pd.read_csv(file, **kwargs) + return pd.read_csv(file, **kwargs) class StatsConfig(BaseModel): @@ -41,7 +54,7 @@ class StatsConfig(BaseModel): task: Task = Field(default="classification") output_dir: Path pred_csvs: list[Path] - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None true_class: str | None = None time_label: str | None = None status_label: str | None = None @@ -50,47 +63,65 @@ class StatsConfig(BaseModel): _Inches = NewType("_Inches", float) -def compute_stats_( +def _compute_multitarget_classification_stats( *, - task: Task, output_dir: Path, pred_csvs: Sequence[Path], - ground_truth_label: PandasLabel | None = None, - true_class: str | None = None, - time_label: str | None = None, - status_label: str | None = None, + target_labels: Sequence[str], ) -> None: - match task: - case "classification": - if true_class is None or ground_truth_label is None: - raise ValueError( - "both true_class and ground_truth_label are required in statistic configuration" - ) - - preds_dfs = [ - _read_table( - p, - usecols=[ground_truth_label, f"{ground_truth_label}_{true_class}"], - dtype={ - ground_truth_label: str, - f"{ground_truth_label}_{true_class}": float, - }, - ) - for p in pred_csvs - ] - - y_trues = [ - np.array(df[ground_truth_label] == true_class) for df in preds_dfs - ] - y_preds = [ - np.array(df[f"{ground_truth_label}_{true_class}"].values) - for df in preds_dfs - ] - n_bootstrap_samples = 1000 - figure_width = _Inches(3.8) - threshold_cmap = None - - roc_curve_figure_aspect_ratio = 1.08 + """Compute statistics and plots for multi-target classification. + + For each target, creates ROC and PRC curves for each class, + similar to single-target classification. + """ + output_dir.mkdir(parents=True, exist_ok=True) + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + roc_curve_figure_aspect_ratio = 1.08 + + # Validate all target labels exist in CSV + first_df = _read_table(pred_csvs[0], nrows=0) + missing_targets = [t for t in target_labels if t not in first_df.columns] + if missing_targets: + raise ValueError( + f"Target labels not found in CSV: {missing_targets}. Available columns: {list(first_df.columns)}" + ) + + # Process each target + for target_label in target_labels: + # Load data for this target + preds_dfs = [] + for p in pred_csvs: + df = _read_table(p, dtype=str) + # Only keep rows where this target has ground truth + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs.append(df_clean) + + if not preds_dfs: + continue + + # Get unique classes for this target + classes = sorted(preds_dfs[0][target_label].unique()) + + # Create plots for each class in this target + for true_class in classes: + # Extract ground truth and predictions for this class + y_trues = [] + y_preds = [] + + for df in preds_dfs: + prob_col = f"{target_label}_{true_class}" + if prob_col not in df.columns: + continue + + y_trues.append(np.array(df[target_label] == true_class)) + y_preds.append(np.array(df[prob_col].astype(float).values)) + + if not y_trues: + continue + + # Plot ROC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, @@ -101,63 +132,206 @@ def compute_stats_( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, - threshold_cmap=threshold_cmap, + threshold_cmap=None, ) - else: plot_multiple_decorated_roc_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=None, ) fig.tight_layout() - if not output_dir.exists(): - output_dir.mkdir(parents=True, exist_ok=True) - - fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"roc-curve_{target_label}={true_class}.svg") plt.close(fig) + # Plot PRC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, ) + if len(preds_dfs) == 1: plot_single_decorated_precision_recall_curve( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) - else: plot_multiple_decorated_precision_recall_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", ) fig.tight_layout() - fig.savefig(output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"pr-curve_{target_label}={true_class}.svg") plt.close(fig) - categorical_aggregated_( - preds_csvs=pred_csvs, - ground_truth_label=ground_truth_label, - outpath=output_dir, + # Compute aggregated statistics for all targets + categorical_aggregated_multitarget_( + preds_csvs=pred_csvs, + outpath=output_dir, + target_labels=target_labels, + ) + + +def compute_stats_( + *, + task: Task, + output_dir: Path, + pred_csvs: Sequence[Path], + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + true_class: str | None = None, + time_label: str | None = None, + status_label: str | None = None, +) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. + """ + match task: + case "classification": + # Check if multi-target based on ground_truth_label type + is_multitarget = ( + isinstance(ground_truth_label, (list, tuple)) + and len(ground_truth_label) > 1 ) + if is_multitarget: + # Multi-target classification + if not isinstance(ground_truth_label, (list, tuple)): + raise ValueError( + "ground_truth_label must be a list or tuple for multi-target classification" + ) + _compute_multitarget_classification_stats( + output_dir=output_dir, + pred_csvs=pred_csvs, + target_labels=list(ground_truth_label), + ) + else: + # Single-target classification (original behavior) + if true_class is None or ground_truth_label is None: + raise ValueError( + "both true_class and ground_truth_label are required in statistic configuration" + ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for single-target classification" + ) + + preds_dfs = [ + _read_table( + p, + usecols=[ + ground_truth_label, + f"{ground_truth_label}_{true_class}", + ], + dtype={ + ground_truth_label: str, + f"{ground_truth_label}_{true_class}": float, + }, + ) + for p in pred_csvs + ] + + y_trues = [ + np.array(df[ground_truth_label] == true_class) for df in preds_dfs + ] + y_preds = [ + np.array(df[f"{ground_truth_label}_{true_class}"].values) + for df in preds_dfs + ] + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + threshold_cmap = None + + roc_curve_figure_aspect_ratio = 1.08 + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + + if len(preds_dfs) == 1: + plot_single_decorated_roc_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + threshold_cmap=threshold_cmap, + ) + else: + plot_multiple_decorated_roc_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=None, + ) + + fig.tight_layout() + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig( + output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + if len(preds_dfs) == 1: + plot_single_decorated_precision_recall_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + ) + else: + plot_multiple_decorated_precision_recall_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + ) + + fig.tight_layout() + fig.savefig( + output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + categorical_aggregated_( + preds_csvs=pred_csvs, + ground_truth_label=ground_truth_label, + outpath=output_dir, + ) + case "regression": if ground_truth_label is None: raise ValueError( "no ground_truth_label configuration supplied in statistic" ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for regression (multi-target regression not yet supported)" + ) regression_aggregated_( preds_csvs=pred_csvs, ground_truth_label=ground_truth_label, @@ -203,12 +377,7 @@ def compute_stats_( cut_off=cut_off, ) - # ------------------------------------------------------------------ # # Save individual and aggregated CSVs - # ------------------------------------------------------------------ # stats_df = pd.DataFrame(per_fold).transpose() stats_df.index.name = "fold_name" # label the index column stats_df.to_csv(output_dir / "survival-stats_individual.csv", index=True) - - # agg_df = _aggregate_with_ci(stats_df) - # agg_df.to_csv(output_dir / "survival-stats_aggregated.csv", index=True) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 2b5c859e..9d6c4c12 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -21,6 +21,30 @@ ] +def _detect_targets_from_columns(columns: Sequence[str]) -> list[str]: + """Detect target columns from CSV column names. + + Assumes multi-target format where each target has: + - A ground truth column (target name) + - A prediction column (pred_{target}) + - Probability columns ({target}_{class1}, {target}_{class2}, ...) + + Returns: + List of target names detected. + """ + # Convert to list to handle pandas Index + columns = list(columns) + targets = [] + for col in columns: + # Look for columns that start with "pred_" + if col.startswith("pred_"): + target_name = col[5:] # Remove "pred_" prefix + # Verify the target column exists + if target_name in columns: + targets.append(target_name) + return sorted(targets) + + def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """Calculates some stats for categorical prediction tables. @@ -38,29 +62,29 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: # roc_auc stats_df["roc_auc_score"] = [ - metrics.roc_auc_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.roc_auc_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # average_precision stats_df["average_precision_score"] = [ - metrics.average_precision_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.average_precision_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # f1 score y_pred_labels = categories[y_pred.argmax(axis=1)] stats_df["f1_score"] = [ - metrics.f1_score(y_true == cat, y_pred_labels == cat) # pyright: ignore[reportCallIssue,reportArgumentType] - for cat in categories + metrics.f1_score(y_true == cat, y_pred_labels == cat) for cat in categories ] # p values p_values = [] for i, cat in enumerate(categories): - pos_scores = y_pred[:, i][y_true == cat] # pyright: ignore[reportCallIssue,reportArgumentType] - neg_scores = y_pred[:, i][y_true != cat] # pyright: ignore[reportCallIssue,reportArgumentType] - p_values.append(st.ttest_ind(pos_scores, neg_scores).pvalue) # pyright: ignore[reportGeneralTypeIssues, reportAttributeAccessIssue] + pos_scores = y_pred[:, i][y_true == cat] + neg_scores = y_pred[:, i][y_true != cat] + _, p_value = st.ttest_ind(pos_scores, neg_scores) + p_values.append(p_value) stats_df["p_value"] = p_values assert set(_score_labels) & set(stats_df.columns) == set(_score_labels) @@ -110,3 +134,61 @@ def categorical_aggregated_( preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv") stats_df = _aggregate_categorical_stats(preds_df.reset_index()) stats_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_aggregated.csv") + + +def categorical_aggregated_multitarget_( + *, + preds_csvs: Sequence[Path], + outpath: Path, + target_labels: Sequence[str], +) -> None: + """Calculate statistics for multi-target categorical deployments. + + Args: + preds_csvs: CSV files containing predictions. + outpath: Path to save the results to. + target_labels: List of target labels to compute statistics for. + + This will apply `_categorical` to each target in the multi-target setup, + calculate statistics per target, and save both individual and aggregated results. + """ + outpath.mkdir(parents=True, exist_ok=True) + + all_target_stats = {} + + for target_label in target_labels: + # Process each target separately + preds_dfs = {} + for p in preds_csvs: + df = pd.read_csv(p, dtype=str) + # Drop rows where this target's ground truth is missing + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs[Path(p).parent.name] = _categorical(df_clean, target_label) + + if not preds_dfs: + continue + + # Concatenate and save individual stats for this target + preds_df = pd.concat(preds_dfs).sort_index() + preds_df.to_csv(outpath / f"{target_label}_categorical-stats_individual.csv") + + # Aggregate stats for this target + stats_df = _aggregate_categorical_stats(preds_df.reset_index()) + stats_df.to_csv(outpath / f"{target_label}_categorical-stats_aggregated.csv") + + # Store for summary + all_target_stats[target_label] = stats_df + + # Create a combined summary across all targets + if all_target_stats: + summary_dfs = [] + for target_name, stats_df in all_target_stats.items(): + stats_copy = stats_df.copy() + stats_copy.index = pd.MultiIndex.from_product( + [[target_name], stats_copy.index], names=["target", "class"] + ) + summary_dfs.append(stats_copy) + + combined_summary = pd.concat(summary_dfs) + combined_summary.to_csv(outpath / "multitarget_categorical-stats_summary.csv") diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 78fb51cd..87415a7a 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd from lifelines import KaplanMeierFitter -from lifelines.plotting import add_at_risk_counts from lifelines.statistics import logrank_test from lifelines.utils import concordance_index @@ -101,7 +100,7 @@ def _plot_km( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid entries --- + # Clean NaNs and invalid entries df = df.replace(["NaN", "nan", "None", "Inf", "inf"], np.nan) df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] @@ -143,20 +142,14 @@ def _plot_km( if len(high_df) > 0: fitters.append(kmf_high) - if len(fitters) > 0: - add_at_risk_counts(*fitters, ax=ax) - - # log-rank only if both groups exist - if len(low_df) > 0 and len(high_df) > 0: - res = logrank_test( - low_df[time_label], - high_df[time_label], - event_observed_A=low_df[status_label], - event_observed_B=high_df[status_label], - ) - logrank_p = float(res.p_value) - else: - logrank_p = float("nan") + # log-rank and c-index + res = logrank_test( + low_df[time_label], + high_df[time_label], + event_observed_A=low_df[status_label], + event_observed_B=high_df[status_label], + ) + logrank_p = float(res.p_value) c_used, used, *_ = _cindex(time, event, risk) ax.text( diff --git a/src/stamp/types.py b/src/stamp/types.py index f1f571cc..c1ff6873 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,7 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +MultiClassGroundTruth: TypeAlias = tuple[str, ...] FeaturePath = NewType("FeaturePath", Path) Category: TypeAlias = str diff --git a/src/stamp/cache.py b/src/stamp/utils/cache.py similarity index 100% rename from src/stamp/cache.py rename to src/stamp/utils/cache.py diff --git a/src/stamp/config.py b/src/stamp/utils/config.py similarity index 100% rename from src/stamp/config.py rename to src/stamp/utils/config.py diff --git a/src/stamp/seed.py b/src/stamp/utils/seed.py similarity index 100% rename from src/stamp/seed.py rename to src/stamp/utils/seed.py diff --git a/tests/random_data.py b/tests/random_data.py index bd95d1bc..b79c9f42 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -74,13 +74,13 @@ def create_random_dataset( clini_df = pd.DataFrame( patient_to_ground_truth.items(), - columns=["patient", "ground-truth"], # pyright: ignore[reportArgumentType] + columns=["patient", "ground-truth"], ) clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( slide_path_to_patient.items(), - columns=["slide_path", "patient"], # pyright: ignore[reportArgumentType] + columns=["slide_path", "patient"], ) slide_df.to_csv(slide_path, index=False) @@ -130,7 +130,7 @@ def create_random_regression_dataset( # --- Write clini + slide tables --- clini_df = pd.DataFrame(patient_to_target, columns=["patient", "target"]) - clini_df["target"] = clini_df["target"].astype(float) # ✅ ensure numeric dtype + clini_df["target"] = clini_df["target"].astype(float) # ensure numeric dtype clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( @@ -254,6 +254,92 @@ def create_random_patient_level_dataset( return clini_path, slide_path, feat_dir, categories +def create_random_multi_target_dataset( + *, + dir: Path, + n_patients: int, + max_slides_per_patient: int, + min_tiles_per_slide: int, + max_tiles_per_slide: int, + feat_dim: int, + target_labels: Sequence[str], + categories_per_target: Sequence[Sequence[str]], + extractor_name: str = "random-test-generator", + min_slides_per_patient: int = 1, +) -> tuple[Path, Path, Path, Sequence[Sequence[str]]]: + """ + Create a random multi-target tile-level dataset. + + Args: + dir: Directory to create dataset in + n_patients: Number of patients + max_slides_per_patient: Maximum slides per patient + min_tiles_per_slide: Minimum tiles per slide + max_tiles_per_slide: Maximum tiles per slide + feat_dim: Feature dimension + target_labels: Names of the target columns (e.g., ["subtype", "grade"]) + categories_per_target: Categories for each target (e.g., [["A", "B"], ["1", "2", "3"]]) + extractor_name: Name of the extractor + min_slides_per_patient: Minimum slides per patient + + Returns: + Tuple of (clini_path, slide_path, feat_dir, categories_per_target) + """ + if len(target_labels) != len(categories_per_target): + raise ValueError( + "target_labels and categories_per_target must have same length" + ) + + slide_path_to_patient: Mapping[Path, PatientId] = {} + patient_to_ground_truths: Mapping[PatientId, dict[str, str]] = {} + clini_path = dir / "clini.csv" + slide_path = dir / "slide.csv" + + feat_dir = dir / "feats" + feat_dir.mkdir() + + for _ in range(n_patients): + # Random patient ID + patient_id = random_string(16) + + # Generate ground truths for each target + ground_truths = {} + for target_label, categories in zip(target_labels, categories_per_target): + ground_truths[target_label] = random.choice(categories) + + patient_to_ground_truths[patient_id] = ground_truths + + # Generate some slides + for _ in range(random.randint(min_slides_per_patient, max_slides_per_patient)): + slide_path_to_patient[ + create_random_feature_file( + tmp_path=feat_dir, + min_tiles=min_tiles_per_slide, + max_tiles=max_tiles_per_slide, + feat_dim=feat_dim, + extractor_name=extractor_name, + ).relative_to(feat_dir) + ] = patient_id + + # Create clinical table with multiple target columns + clini_data = [] + for patient_id, ground_truths in patient_to_ground_truths.items(): + row = {"patient": patient_id} + row.update(ground_truths) + clini_data.append(row) + + clini_df = pd.DataFrame(clini_data) + clini_df.to_csv(clini_path, index=False) + + slide_df = pd.DataFrame( + slide_path_to_patient.items(), + columns=["slide_path", "patient"], + ) + slide_df.to_csv(slide_path, index=False) + + return clini_path, slide_path, feat_dir, categories_per_target + + def create_random_feature_file( *, tmp_path: Path, diff --git a/tests/test_cache_tiles.py b/tests/test_cache_tiles.py index a665e92c..d9f1411d 100644 --- a/tests/test_cache_tiles.py +++ b/tests/test_cache_tiles.py @@ -6,10 +6,10 @@ import numpy as np import pytest -from stamp.cache import download_file from stamp.preprocessing import Microns, TilePixels from stamp.preprocessing.tiling import _Tile, tiles_with_cache from stamp.types import ImageExtension, SlidePixels +from stamp.utils.cache import download_file def _get_tiles_and_images( diff --git a/tests/test_config.py b/tests/test_config.py index 15b5dd80..fbb53a6c 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,6 @@ # %% from pathlib import Path -from stamp.config import StampConfig from stamp.heatmaps.config import HeatmapConfig from stamp.modeling.config import ( AdvancedConfig, @@ -21,6 +20,7 @@ TilePixels, ) from stamp.statistics import StatsConfig +from stamp.utils.config import StampConfig def test_config_parsing() -> None: diff --git a/tests/test_crossval.py b/tests/test_crossval.py index 184a5c23..5e53f50c 100644 --- a/tests/test_crossval.py +++ b/tests/test_crossval.py @@ -13,7 +13,7 @@ VitModelParams, ) from stamp.modeling.crossval import categorical_crossval_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow diff --git a/tests/test_data.py b/tests/test_data.py index 564d6426..a60d830e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ from pathlib import Path import h5py +import pandas as pd import pytest import torch from random_data import ( @@ -21,9 +22,9 @@ PatientFeatureDataset, filter_complete_patient_data_, get_coords, + patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) -from stamp.seed import Seed from stamp.types import ( BagSize, FeaturePath, @@ -33,6 +34,7 @@ SlideMPP, TilePixels, ) +from stamp.utils.seed import Seed @pytest.mark.filterwarnings("ignore:some patients have no associated slides") @@ -85,6 +87,31 @@ def test_get_cohort_df(tmp_path: Path) -> None: } +def test_patient_to_ground_truth_multi_target(tmp_path: Path) -> None: + """Verify multi-target clini parsing returns dicts and drops rows missing all targets.""" + df = pd.DataFrame( + { + "patient": ["p1", "p2", "p3", "p4"], + "subtype": ["A", None, "B", None], + "grade": ["1", "2", None, None], + } + ) + df.to_csv(tmp_path / "clini.csv", index=False) + + result = patient_to_ground_truth_from_clini_table_( + clini_table_path=tmp_path / "clini.csv", + patient_label="patient", + ground_truth_label=["subtype", "grade"], + ) + + # p4 has both targets missing → dropped + assert "p4" not in result + + assert result["p1"] == {"subtype": "A", "grade": "1"} + assert result["p2"] == {"subtype": None, "grade": "2"} + assert result["p3"] == {"subtype": "B", "grade": None} + + @pytest.mark.parametrize( "feature_file_creator", [make_feature_file, make_old_feature_file], diff --git a/tests/test_deployment.py b/tests/test_deployment.py index de20ea12..7d1d6589 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pytest @@ -24,8 +25,8 @@ ) from stamp.modeling.models.mlp import MLP from stamp.modeling.models.vision_tranformer import VisionTransformer -from stamp.seed import Seed from stamp.types import GroundTruth, PatientId, Task +from stamp.utils.seed import Seed def test_predict_patient_level( @@ -83,7 +84,8 @@ def test_predict_patient_level( assert len(predictions) == len(patient_to_data) for pid in patient_ids: - assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + pred = cast(torch.Tensor, predictions[pid]) + assert pred.shape == torch.Size([3]), "expected one score per class" # Check if scores are consistent between runs and different for different patients more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] @@ -124,11 +126,13 @@ def test_predict_patient_level( assert len(more_predictions) == len(all_patient_ids) # Different patients should give different results assert not torch.allclose( - more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + cast(torch.Tensor, more_predictions[more_patient_ids[0]]), + cast(torch.Tensor, more_predictions[more_patient_ids[1]]), ), "different inputs should give different results" # The same patient should yield the same result assert torch.allclose( - predictions[patient_ids[0]], more_predictions[patient_ids[0]] + cast(torch.Tensor, predictions[patient_ids[0]]), + cast(torch.Tensor, more_predictions[patient_ids[0]]), ), "the same inputs should repeatedly yield the same results" @@ -163,7 +167,7 @@ def test_to_prediction_df(task: str) -> None: ) if task == "classification": preds_df = _to_prediction_df( - categories=list(model.categories), # type: ignore + categories=list(cast(list, model.categories)), patient_to_ground_truth={ PatientId("pat5"): GroundTruth("foo"), PatientId("pat6"): None, @@ -192,13 +196,13 @@ def test_to_prediction_df(task: str) -> None: # Check if no loss / target is given for targets with missing ground truths no_ground_truth = preds_df[preds_df["patient"].isin(["pat6"])] - assert no_ground_truth["target"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert no_ground_truth["loss"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert no_ground_truth["target"].isna().all() + assert no_ground_truth["loss"].isna().all() # Check if loss / target is given for targets with ground truths with_ground_truth = preds_df[preds_df["patient"].isin(["pat5", "pat7"])] - assert (~with_ground_truth["target"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert (~with_ground_truth["loss"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert (~with_ground_truth["target"].isna()).all() + assert (~with_ground_truth["loss"].isna()).all() elif task == "regression": patient_to_ground_truth = {} @@ -295,7 +299,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: div_factor=25.0, ) - # ---- Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) + # Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) if task == "classification": feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) @@ -319,7 +323,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: ) } - # ---- Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) + # Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) test_dl, _ = tile_bag_dataloader( task=task, # "classification" | "regression" | "survival" patient_data=list(patient_to_data.values()), @@ -341,12 +345,17 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: assert len(predictions) == 1 pred = list(predictions.values())[0] if task == "classification": - assert pred.shape == torch.Size([len(categories)]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([len(categories)]) elif task == "regression": - assert pred.shape == torch.Size([1]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([1]) else: # survival # Cox model → scalar log-risk, KM → vector or matrix - assert pred.ndim in (0, 1, 2), f"unexpected survival output shape: {pred.shape}" + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.ndim in (0, 1, 2), ( + f"unexpected survival output shape: {pred_tensor.shape}" + ) # Repeatability predictions2 = _predict( @@ -356,4 +365,6 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: accelerator="cpu", ) for pid in predictions: - assert torch.allclose(predictions[pid], predictions2[pid]) + assert torch.allclose( + cast(torch.Tensor, predictions[pid]), cast(torch.Tensor, predictions2[pid]) + ) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 3edef575..28d7c2c1 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -10,13 +10,13 @@ from huggingface_hub.errors import GatedRepoError from random_data import create_random_dataset, create_random_feature_file, random_string -from stamp.cache import download_file from stamp.encoding import ( EncoderName, init_patient_encoder_, init_slide_encoder_, ) from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import download_file # Contains an accepted input patch-level feature encoder # TODO: Make a class for each extractor instead of a function. This class diff --git a/tests/test_feature_extractors.py b/tests/test_feature_extractors.py index 699c10f6..6323ee8d 100644 --- a/tests/test_feature_extractors.py +++ b/tests/test_feature_extractors.py @@ -7,8 +7,8 @@ import torch from huggingface_hub.errors import GatedRepoError -from stamp.cache import download_file from stamp.preprocessing import ExtractorName, Microns, TilePixels, extract_ +from stamp.utils.cache import download_file @pytest.mark.slow diff --git a/tests/test_model.py b/tests/test_model.py index 1aa6d80a..f74e0a99 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,6 @@ import torch +from stamp.modeling.models.barspoon import EncDecTransformer from stamp.modeling.models.mlp import MLP from stamp.modeling.models.trans_mil import TransMIL from stamp.modeling.models.vision_tranformer import VisionTransformer @@ -162,3 +163,114 @@ def test_trans_mil_inference_reproducibility( ) assert logits1.allclose(logits2) + + +def test_enc_dec_transformer_dims( + batch_size: int = 6, + n_tiles: int = 75, + input_dim: int = 456, + d_model: int = 128, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=256, + positional_encoding=True, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert set(logits.keys()) == set(target_n_outs.keys()) + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_single_target( + batch_size: int = 4, + n_tiles: int = 50, + input_dim: int = 256, + d_model: int = 64, +) -> None: + target_n_outs = {"label": 5} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + assert list(logits.keys()) == ["label"] + assert logits["label"].shape == (batch_size, 5) + + +def test_enc_dec_transformer_no_positional_encoding( + batch_size: int = 4, + n_tiles: int = 30, + input_dim: int = 128, + d_model: int = 64, +) -> None: + target_n_outs = {"a": 2, "b": 3} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=128, + positional_encoding=False, + ) + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + logits = model.forward(bags, coords) + + for target_label, n_out in target_n_outs.items(): + assert logits[target_label].shape == (batch_size, n_out) + + +def test_enc_dec_transformer_inference_reproducibility( + batch_size: int = 5, + n_tiles: int = 40, + input_dim: int = 200, + d_model: int = 64, +) -> None: + target_n_outs = {"subtype": 3, "grade": 4} + model = EncDecTransformer( + d_features=input_dim, + target_n_outs=target_n_outs, + d_model=d_model, + num_encoder_heads=4, + num_decoder_heads=4, + num_encoder_layers=2, + num_decoder_layers=2, + dim_feedforward=128, + ) + model = model.eval() + + bags = torch.rand((batch_size, n_tiles, input_dim)) + coords = torch.rand((batch_size, n_tiles, 2)) + + with torch.inference_mode(): + logits1 = model.forward(bags, coords) + logits2 = model.forward(bags, coords) + + for target_label in target_n_outs: + assert logits1[target_label].allclose(logits2[target_label]) diff --git a/tests/test_statistics.py b/tests/test_statistics.py index 790b98ab..b600d7e7 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -2,6 +2,7 @@ from pathlib import Path import numpy as np +import pandas as pd import torch from random_data import random_patient_preds, random_string @@ -47,3 +48,153 @@ def test_statistics_integration( def test_statistics_integration_for_multiple_patient_preds(tmp_path: Path) -> None: return test_statistics_integration(tmp_path=tmp_path, n_patient_preds=5) + + +def test_statistics_survival_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that survival statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + times = np.random.uniform(30, 2000, size=n_patients) + statuses = np.random.choice([0, 1], size=n_patients, p=[0.3, 0.7]) + risks = np.random.randn(n_patients) + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "day": times, + "status": statuses, + "pred_score": risks, + } + ) + df.to_csv(tmp_path / f"survival-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="survival", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"survival-preds-{i}.csv" for i in range(n_folds)], + time_label="day", + status_label="status", + ) + + assert (tmp_path / "output" / "survival-stats_individual.csv").is_file() + + +def test_statistics_survival_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_survival_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_regression_integration( + *, + tmp_path: Path, + n_folds: int = 1, + n_patients: int = 200, +) -> None: + """Check that regression statistics run without crashing.""" + random.seed(0) + np.random.seed(0) + + for fold_i in range(n_folds): + y_true = np.random.uniform(0, 100, size=n_patients) + y_pred = y_true + np.random.randn(n_patients) * 10 # noisy predictions + df = pd.DataFrame( + { + "patient": [random_string(8) for _ in range(n_patients)], + "target": y_true, + "pred": y_pred, + } + ) + df.to_csv(tmp_path / f"regression-preds-{fold_i}.csv", index=False) + + compute_stats_( + task="regression", + output_dir=tmp_path / "output", + pred_csvs=[tmp_path / f"regression-preds-{i}.csv" for i in range(n_folds)], + ground_truth_label="target", + ) + + assert (tmp_path / "output" / "target_regression-stats_individual.csv").is_file() + assert (tmp_path / "output" / "target_regression-stats_aggregated.csv").is_file() + + +def test_statistics_regression_integration_multiple_folds(tmp_path: Path) -> None: + return test_statistics_regression_integration(tmp_path=tmp_path, n_folds=5) + + +def test_statistics_multi_target_classification_integration( + *, + tmp_path: Path, + n_patient_preds: int = 1, +) -> None: + """Check that multi-target classification statistics run without crashing. + + Multi-target predictions produce separate ground-truth columns per target. + We run compute_stats_ once per target, as the statistics pipeline handles + one target at a time. + """ + random.seed(0) + np.random.seed(0) + torch.random.manual_seed(0) + + categories_per_target = {"subtype": ["A", "B"], "grade": ["1", "2", "3"]} + + for pred_i in range(n_patient_preds): + n_patients = random.randint(100, 500) + data: dict[str, list] = { + "patient": [random_string(8) for _ in range(n_patients)], + } + + for target_label, cats in categories_per_target.items(): + data[target_label] = [random.choice(cats) for _ in range(n_patients)] + probs = torch.softmax(torch.rand(len(cats), n_patients), dim=0) + for j, cat in enumerate(cats): + data[f"{target_label}_{cat}"] = probs[j].tolist() + + pd.DataFrame(data).to_csv( + tmp_path / f"multi-target-preds-{pred_i}.csv", index=False + ) + + # Run statistics per target (as the pipeline would do) + for target_label, cats in categories_per_target.items(): + true_class = cats[0] + compute_stats_( + task="classification", + output_dir=tmp_path / "output" / target_label, + pred_csvs=[ + tmp_path / f"multi-target-preds-{i}.csv" for i in range(n_patient_preds) + ], + ground_truth_label=target_label, + true_class=true_class, + ) + + assert ( + tmp_path + / "output" + / target_label + / f"{target_label}_categorical-stats_aggregated.csv" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"roc-curve_{target_label}={true_class}.svg" + ).is_file() + assert ( + tmp_path + / "output" + / target_label + / f"pr-curve_{target_label}={true_class}.svg" + ).is_file() + + +def test_statistics_multi_target_classification_multiple_preds( + tmp_path: Path, +) -> None: + return test_statistics_multi_target_classification_integration( + tmp_path=tmp_path, n_patient_preds=3 + ) diff --git a/tests/test_train_deploy.py b/tests/test_train_deploy.py index 0180d171..ea5547f1 100644 --- a/tests/test_train_deploy.py +++ b/tests/test_train_deploy.py @@ -8,6 +8,7 @@ import torch from random_data import ( create_random_dataset, + create_random_multi_target_dataset, create_random_patient_level_dataset, create_random_patient_level_survival_dataset, create_random_regression_dataset, @@ -22,8 +23,9 @@ VitModelParams, ) from stamp.modeling.deploy import deploy_categorical_model_ +from stamp.modeling.registry import ModelName from stamp.modeling.train import train_categorical_model_ -from stamp.seed import Seed +from stamp.utils.seed import Seed @pytest.mark.slow @@ -123,6 +125,9 @@ def test_train_deploy_integration( pytest.param(False, True, id="use vary_precision_transform"), ], ) +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_integration( *, tmp_path: Path, @@ -356,6 +361,9 @@ def test_train_deploy_survival_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_regression_integration( *, tmp_path: Path, @@ -465,6 +473,9 @@ def test_train_deploy_patient_level_regression_integration( @pytest.mark.slow +@pytest.mark.filterwarnings( + "ignore:.*violates type hint.*not instance of tuple:UserWarning" +) def test_train_deploy_patient_level_survival_integration( *, tmp_path: Path, @@ -531,3 +542,89 @@ def test_train_deploy_patient_level_survival_integration( accelerator="gpu" if torch.cuda.is_available() else "cpu", num_workers=min(os.cpu_count() or 1, 16), ) + + +@pytest.mark.slow +@pytest.mark.filterwarnings("ignore:No positive samples in targets") +def test_train_deploy_multi_target_integration( + *, + tmp_path: Path, + feat_dim: int = 25, +) -> None: + """Integration test: train + deploy a multi-target tile-level classification model.""" + Seed.set(42) + + (tmp_path / "train").mkdir() + (tmp_path / "deploy").mkdir() + + # Define multi-target setup: subtype (2 categories) and grade (3 categories) + target_labels = ["subtype", "grade"] + categories_per_target = [["A", "B"], ["1", "2", "3"]] + + # Create random multi-target tile-level dataset + train_clini_path, train_slide_path, train_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "train", + n_patients=400, + max_slides_per_patient=3, + min_tiles_per_slide=20, + max_tiles_per_slide=600, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + deploy_clini_path, deploy_slide_path, deploy_feature_dir, _ = ( + create_random_multi_target_dataset( + dir=tmp_path / "deploy", + n_patients=50, + max_slides_per_patient=3, + min_tiles_per_slide=20, + max_tiles_per_slide=600, + feat_dim=feat_dim, + target_labels=target_labels, + categories_per_target=categories_per_target, + ) + ) + + # Build config objects + config = TrainConfig( + task="classification", + clini_table=train_clini_path, + slide_table=train_slide_path, + feature_dir=train_feature_dir, + output_dir=tmp_path / "train_output", + patient_label="patient", + ground_truth_label=target_labels, + filename_label="slide_path", + categories=[cat for cats in categories_per_target for cat in cats], + ) + + advanced = AdvancedConfig( + bag_size=500, + num_workers=min(os.cpu_count() or 1, 16), + batch_size=8, + max_epochs=2, + patience=1, + accelerator="gpu" if torch.cuda.is_available() else "cpu", + model_params=ModelParams(), + model_name=ModelName.BARSPOON, + ) + + # Train + deploy multi-target model + train_categorical_model_(config=config, advanced=advanced) + + deploy_categorical_model_( + output_dir=tmp_path / "deploy_output", + checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"], + clini_table=deploy_clini_path, + slide_table=deploy_slide_path, + feature_dir=deploy_feature_dir, + patient_label="patient", + ground_truth_label=target_labels, + time_label=None, + status_label=None, + filename_label="slide_path", + accelerator="gpu" if torch.cuda.is_available() else "cpu", + num_workers=min(os.cpu_count() or 1, 16), + )