From 98d6a67d776c5a88b86eba17b9129a25f93f2928 Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 13 Feb 2026 14:01:16 +0000 Subject: [PATCH 1/5] add multi-target support; tests, fixes, and docs --- src/stamp/__main__.py | 4 +- src/stamp/config.yaml | 24 +- src/stamp/encoding/encoder/__init__.py | 2 +- src/stamp/encoding/encoder/chief.py | 2 +- src/stamp/encoding/encoder/eagle.py | 2 +- src/stamp/encoding/encoder/gigapath.py | 2 +- src/stamp/encoding/encoder/madeleine.py | 2 +- src/stamp/encoding/encoder/titan.py | 2 +- src/stamp/heatmaps/__init__.py | 8 +- src/stamp/modeling/config.py | 26 +- src/stamp/modeling/crossval.py | 230 ++++---- src/stamp/modeling/data.py | 512 +++++++++++++----- src/stamp/modeling/deploy.py | 321 +++++++++-- src/stamp/modeling/models/__init__.py | 79 ++- src/stamp/modeling/models/barspoon.py | 367 +++++++++++++ src/stamp/modeling/registry.py | 10 + src/stamp/modeling/train.py | 156 +++--- src/stamp/preprocessing/__init__.py | 2 +- .../extractor/chief_ctranspath.py | 2 +- .../preprocessing/extractor/ctranspath.py | 2 +- .../preprocessing/extractor/dinobloom.py | 2 +- src/stamp/preprocessing/tiling.py | 2 +- src/stamp/statistics/__init__.py | 40 +- src/stamp/statistics/survival.py | 12 +- src/stamp/types.py | 1 + src/stamp/{ => utils}/cache.py | 0 src/stamp/{ => utils}/config.py | 0 src/stamp/{ => utils}/seed.py | 0 src/stamp/utils/target_file.py | 351 ++++++++++++ tests/random_data.py | 86 +++ tests/test_cache_tiles.py | 2 +- tests/test_config.py | 2 +- tests/test_crossval.py | 2 +- tests/test_data.py | 2 +- tests/test_deployment.py | 31 +- tests/test_encoders.py | 2 +- tests/test_feature_extractors.py | 2 +- tests/test_model.py | 112 ++++ tests/test_statistics.py | 151 ++++++ tests/test_train_deploy.py | 99 +++- uv.lock | 9 +- 41 files changed, 2210 insertions(+), 453 deletions(-) create mode 100644 src/stamp/modeling/models/barspoon.py rename src/stamp/{ => utils}/cache.py (100%) rename src/stamp/{ => utils}/config.py (100%) rename 
src/stamp/{ => utils}/seed.py (100%) create mode 100644 src/stamp/utils/target_file.py diff --git a/src/stamp/__main__.py b/src/stamp/__main__.py index 4ab8416f..ffa98bae 100755 --- a/src/stamp/__main__.py +++ b/src/stamp/__main__.py @@ -6,14 +6,14 @@ import yaml -from stamp.config import StampConfig from stamp.modeling.config import ( AdvancedConfig, MlpModelParams, ModelParams, VitModelParams, ) -from stamp.seed import Seed +from stamp.utils.config import StampConfig +from stamp.utils.seed import Seed STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") diff --git a/src/stamp/config.yaml b/src/stamp/config.yaml index 796140a5..4f16dcb2 100644 --- a/src/stamp/config.yaml +++ b/src/stamp/config.yaml @@ -4,7 +4,7 @@ preprocessing: # Extractor to use for feature extractor. Possible options are "ctranspath", # "uni", "conch", "chief-ctranspath", "conch1_5", "uni2", "dino-bloom", # "gigapath", "h-optimus-0", "h-optimus-1", "virchow2", "virchow", - # "virchow-full", "musk", "mstar", "plip" + # "virchow-full", "musk", "mstar", "plip", "ticon" # Some of them require requesting access to the respective authors beforehand. extractor: "chief-ctranspath" @@ -73,6 +73,8 @@ crossval: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -130,6 +132,8 @@ training: # Name of the column from the clini table to train on. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -172,6 +176,8 @@ deployment: # Name of the column from the clini to compare predictions to. 
ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # For survival (should be status and follow-up days columns in clini table) # status_label: "event" @@ -197,6 +203,8 @@ statistics: # Name of the target label. ground_truth_label: "KRAS" + # For multi-target classification you may specify a list of columns, + # e.g. ground_truth_label: ["KRAS", "BRAF", "NRAS"] # A lot of the statistics are computed "one-vs-all", i.e. there needs to be # a positive class to calculate the statistics for. @@ -316,7 +324,7 @@ advanced_config: max_lr: 1e-4 div_factor: 25. # Select a model regardless of task - model_name: "vit" # or mlp, trans_mil + model_name: "vit" # or mlp, trans_mil, barspoon model_params: vit: # Vision Transformer @@ -335,3 +343,15 @@ advanced_config: dim_hidden: 512 num_layers: 2 dropout: 0.25 + + # NOTE: Only the `barspoon` model supports multi-target classification + # (i.e. `ground_truth_label` can be a list of column names). Other + # models expect a single target column. 
+ barspoon: # Encoder-Decoder Transformer for multi-target classification + d_model: 512 + num_encoder_heads: 8 + num_decoder_heads: 8 + num_encoder_layers: 2 + num_decoder_layers: 2 + dim_feedforward: 2048 + positional_encoding: true \ No newline at end of file diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 5827e884..86daa54a 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -12,11 +12,11 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.modeling.data import CoordsInfo, get_coords, read_table from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/chief.py b/src/stamp/encoding/encoder/chief.py index 2ad4b91b..924ceebb 100644 --- a/src/stamp/encoding/encoder/chief.py +++ b/src/stamp/encoding/encoder/chief.py @@ -10,11 +10,11 @@ from numpy import ndarray from tqdm import tqdm -from stamp.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest, get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/eagle.py b/src/stamp/encoding/encoder/eagle.py index 03bc833e..9266f315 100644 --- a/src/stamp/encoding/encoder/eagle.py +++ b/src/stamp/encoding/encoder/eagle.py @@ -9,13 +9,13 @@ from torch import Tensor from tqdm import tqdm -from stamp.cache import get_processing_code_hash from 
stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.encoding.encoder.chief import CHIEF from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, PandasLabel +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/gigapath.py b/src/stamp/encoding/encoder/gigapath.py index 9cb3f6f5..4c0a2f6b 100644 --- a/src/stamp/encoding/encoder/gigapath.py +++ b/src/stamp/encoding/encoder/gigapath.py @@ -9,12 +9,12 @@ from gigapath import slide_encoder from tqdm import tqdm -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/encoding/encoder/madeleine.py b/src/stamp/encoding/encoder/madeleine.py index 5798a592..a0c74dcd 100644 --- a/src/stamp/encoding/encoder/madeleine.py +++ b/src/stamp/encoding/encoder/madeleine.py @@ -3,10 +3,10 @@ import torch from numpy import ndarray -from stamp.cache import STAMP_CACHE_DIR from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.preprocessing.config import ExtractorName +from stamp.utils.cache import STAMP_CACHE_DIR try: from madeleine.models.factory import create_model_from_pretrained diff --git a/src/stamp/encoding/encoder/titan.py b/src/stamp/encoding/encoder/titan.py index 1012d98f..568254ca 100644 --- a/src/stamp/encoding/encoder/titan.py +++ b/src/stamp/encoding/encoder/titan.py @@ -10,12 +10,12 @@ from tqdm import tqdm from transformers import 
AutoModel -from stamp.cache import get_processing_code_hash from stamp.encoding.config import EncoderName from stamp.encoding.encoder import Encoder from stamp.modeling.data import CoordsInfo from stamp.preprocessing.config import ExtractorName from stamp.types import DeviceLikeType, Microns, PandasLabel, SlideMPP +from stamp.utils.cache import get_processing_code_hash __author__ = "Juan Pablo Ricapito" __copyright__ = "Copyright (C) 2025 Juan Pablo Ricapito" diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 446d85d6..22fb5250 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -70,14 +70,14 @@ def _attention_rollout_single( device = feats.device - # --- 1. Forward pass to fill attn_weights in each SelfAttention layer --- + # 1. Forward pass to fill attn_weights in each SelfAttention layer _ = model( bags=feats.unsqueeze(0), coords=coords.unsqueeze(0), mask=torch.zeros(1, len(feats), dtype=torch.bool, device=device), ) - # --- 2. Rollout computation --- + # 2. Rollout computation attn_rollout: torch.Tensor | None = None for layer in model.transformer.layers: # type: ignore attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights @@ -96,10 +96,10 @@ def _attention_rollout_single( if attn_rollout is None: raise RuntimeError("No attention maps collected from transformer layers.") - # --- 3. Extract CLS → tiles attention --- + # 3. Extract CLS → tiles attention cls_attn = attn_rollout[0, 1:] # [tile] - # --- 4. Normalize for visualization consistency --- + # 4. 
Normalize for visualization consistency cls_attn = cls_attn - cls_attn.min() cls_attn = cls_attn / (cls_attn.max().clamp(min=1e-8)) diff --git a/src/stamp/modeling/config.py b/src/stamp/modeling/config.py index 21ce69db..5b9a6bcc 100644 --- a/src/stamp/modeling/config.py +++ b/src/stamp/modeling/config.py @@ -21,7 +21,7 @@ class TrainConfig(BaseModel): ) feature_dir: Path = Field(description="Directory containing feature files") - ground_truth_label: PandasLabel | None = Field( + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = Field( default=None, description="Name of categorical column in clinical table to train on", ) @@ -64,7 +64,7 @@ class DeploymentConfig(BaseModel): slide_table: Path feature_dir: Path - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None patient_label: PandasLabel = "PATIENT" filename_label: PandasLabel = "FILENAME" @@ -99,8 +99,29 @@ class TransMILModelParams(BaseModel): dim_hidden: int = 512 +class BarspoonParams(BaseModel): + model_config = ConfigDict(extra="forbid") + d_model: int = 512 + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 + + class LinearModelParams(BaseModel): model_config = ConfigDict(extra="forbid") + num_encoder_heads: int = 8 + num_decoder_heads: int = 8 + num_encoder_layers: int = 2 + num_decoder_layers: int = 2 + dim_feedforward: int = 2048 + positional_encoding: bool = True + # Other hparams + learning_rate: float = 1e-4 class ModelParams(BaseModel): @@ -109,6 +130,7 @@ class ModelParams(BaseModel): trans_mil: TransMILModelParams = Field(default_factory=TransMILModelParams) mlp: MlpModelParams = Field(default_factory=MlpModelParams) linear: LinearModelParams = Field(default_factory=LinearModelParams) + barspoon: BarspoonParams = 
Field(default_factory=BarspoonParams) class AdvancedConfig(BaseModel): diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 43e76f01..2caccecb 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -1,8 +1,10 @@ import logging +from collections import Counter from collections.abc import Mapping, Sequence -from typing import Any, Final +from typing import Any, cast import numpy as np +import torch from pydantic import BaseModel from sklearn.model_selection import KFold, StratifiedKFold @@ -10,13 +12,8 @@ from stamp.modeling.data import ( PatientData, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, + load_patient_data_, log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, ) from stamp.modeling.deploy import ( _predict, @@ -28,7 +25,6 @@ from stamp.modeling.train import setup_model_for_training, train_model_ from stamp.modeling.transforms import VaryPrecisionTransform from stamp.types import ( - FeaturePath, GroundTruth, PatientId, ) @@ -53,67 +49,31 @@ def categorical_crossval_( config: CrossvalConfig, advanced: AdvancedConfig, ) -> None: - feature_type = detect_feature_type(config.feature_dir) + if config.task is None: + raise ValueError( + "task must be set to 'classification' | 'regression' | 'survival'" + ) + + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) _logger.info(f"Detected feature type: {feature_type}") - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise 
ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label are is required for survival modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for classification or regression modeling" - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = ( - patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - ) - slide_to_patient: Final[dict[FeaturePath, PatientId]] = ( - slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - ) - patient_to_data: Mapping[PatientId, PatientData] = ( - filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - drop_patients_with_missing_ground_truth=True, - ) - ) - elif feature_type == "patient": - patient_to_data: Mapping[PatientId, PatientData] = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - patient_to_ground_truth: dict[PatientId, GroundTruth] = { - pid: pd.ground_truth for pid, pd in patient_to_data.items() - } - else: - raise RuntimeError(f"Unsupported feature type: {feature_type}") + patient_to_ground_truth = { + pid: pd.ground_truth for pid, pd 
in patient_to_data.items() + } + + if feature_type not in ("tile", "slide", "patient"): + raise ValueError(f"Unknown feature type: {feature_type}") config.output_dir.mkdir(parents=True, exist_ok=True) splits_file = config.output_dir / "splits.json" @@ -158,18 +118,51 @@ def categorical_crossval_( f"{ground_truths_not_in_split}" ) + categories_for_export: ( + dict[str, list] | list + ) = [] # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] # type: ignore # declare upfront to avoid unbound variable warnings + if config.task == "classification": - categories = config.categories or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } - ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) + # Determine categories for training (single-target) and for export (supports multi-target) + if isinstance(config.ground_truth_label, str): + categories = config.categories or sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + } + ) + log_patient_class_summary( + patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, + categories=categories, + ) + categories_for_export = cast(list, categories) + else: + # Multi-target: build a mapping from target label -> sorted list of categories + categories_accum: dict[str, set[GroundTruth]] = {} + for patient_data in patient_to_data.values(): + gt = patient_data.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v is not None: + categories_accum.setdefault(k, set()).add(v) + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + # Log summary per target + for t, cats in categories_for_export.items(): + ground_truths = [ + pd.ground_truth.get(t) + for pd in patient_to_data.values() + if 
isinstance(pd.ground_truth, dict) + and pd.ground_truth.get(t) is not None + ] + counter = Counter(ground_truths) + _logger.info( + f"{t} | Total patients: {len(ground_truths)} | " + + " | ".join([f"Class {c}: {counter.get(c, 0)}" for c in cats]) + ) + # For training, categories can remain None (inferred later) + categories = config.categories or None else: categories = [] @@ -206,12 +199,18 @@ def categorical_crossval_( }, categories=( categories - or sorted( - { - patient_data.ground_truth - for patient_data in patient_to_data.values() - if patient_data.ground_truth is not None - } + if categories is not None + else ( + sorted( + { + patient_data.ground_truth + for patient_data in patient_to_data.values() + if patient_data.ground_truth is not None + and not isinstance(patient_data.ground_truth, dict) + } + ) + if not isinstance(config.ground_truth_label, Sequence) + else None ) ), train_transform=( @@ -263,30 +262,48 @@ def categorical_crossval_( ) if config.task == "survival": - _to_survival_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, - patient_label=config.patient_label, - cut_off=getattr(model.hparams, "train_pred_median", None), - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_survival_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + cut_off=getattr(model.hparams, "train_pred_median", None), + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target survival prediction export not yet supported; skipping CSV save" + ) elif config.task == "regression": if config.ground_truth_label is None: raise RuntimeError("Grounf truth label is required for regression") - _to_regression_prediction_df( - patient_to_ground_truth=patient_to_ground_truth, - 
predictions=predictions, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - ).to_csv(split_dir / "patient-preds.csv", index=False) + if isinstance(config.ground_truth_label, str): + _to_regression_prediction_df( + patient_to_ground_truth=cast( + Mapping[PatientId, str | None], patient_to_ground_truth + ), + predictions=cast(Mapping[PatientId, torch.Tensor], predictions), + patient_label=config.patient_label, + ground_truth_label=config.ground_truth_label, + ).to_csv(split_dir / "patient-preds.csv", index=False) + else: + _logger.warning( + "Multi-target regression prediction export not yet supported; skipping CSV save" + ) else: if config.ground_truth_label is None: raise RuntimeError( "Grounf truth label is required for classification" ) _to_prediction_df( - categories=categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, - predictions=predictions, + predictions=cast( + Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], + predictions, + ), patient_label=config.patient_label, ground_truth_label=config.ground_truth_label, ).to_csv(split_dir / "patient-preds.csv", index=False) @@ -296,6 +313,16 @@ def _get_splits( *, patient_to_data: Mapping[PatientId, PatientData[Any]], n_splits: int, spliter ) -> _Splits: patients = np.array(list(patient_to_data.keys())) + + # Extract ground truth for stratification. 
+ # For multi-target (dict), use the first target's value + y_strat = np.array( + [ + next(iter(gt.values())) if isinstance(gt, dict) else gt + for gt in [patient.ground_truth for patient in patient_to_data.values()] + ] + ) + skf = spliter(n_splits=n_splits, shuffle=True, random_state=0) splits = _Splits( splits=[ @@ -303,12 +330,7 @@ def _get_splits( train_patients=set(patients[train_indices]), test_patients=set(patients[test_indices]), ) - for train_indices, test_indices in skf.split( - patients, - np.array( - [patient.ground_truth for patient in patient_to_data.values()] - ), - ) + for train_indices, test_indices in skf.split(patients, y_strat) ] ) return splits diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 6cabec64..8b30ab11 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -5,19 +5,29 @@ from dataclasses import KW_ONLY, dataclass from itertools import groupby from pathlib import Path -from typing import IO, BinaryIO, Counter, Generic, TextIO, TypeAlias, Union, cast +from typing import ( + IO, + Any, + BinaryIO, + Dict, + Final, + Generic, + List, + TextIO, + TypeAlias, + Union, + cast, +) import h5py import numpy as np import pandas as pd import torch -from jaxtyping import Float from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset import stamp -from stamp.seed import Seed from stamp.types import ( Bags, BagSize, @@ -35,6 +45,7 @@ Task, TilePixels, ) +from stamp.utils.seed import Seed _logger = logging.getLogger("stamp") @@ -43,14 +54,17 @@ __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" -_Bag: TypeAlias = Float[Tensor, "tile feature"] -_EncodedTarget: TypeAlias = Float[Tensor, "category_is_hot"] | Float[Tensor, "1"] # noqa: F821 +_Bag: TypeAlias = Tensor +_EncodedTarget: TypeAlias = ( + Tensor | dict[str, Tensor] +) # Union of encoded targets or multi-target dict _BinaryIOLike: TypeAlias = 
Union[BinaryIO, IO[bytes]] """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] +- multi-target: dict[target_name -> one-hot/regression value] """ -_Coordinates: TypeAlias = Float[Tensor, "tile 2"] +_Coordinates: TypeAlias = Tensor @dataclass @@ -64,7 +78,7 @@ class PatientData(Generic[GroundTruthType]): def tile_bag_dataloader( *, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None, task: Task, categories: Sequence[Category] | None = None, @@ -74,7 +88,7 @@ def tile_bag_dataloader( transform: Callable[[Tensor], Tensor] | None, ) -> tuple[ DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], ]: """Creates a dataloader from patient data for tile-level (bagged) features. @@ -86,103 +100,139 @@ def tile_bag_dataloader( task='regression': returns float targets """ - if task == "classification": - raw_ground_truths = np.array([patient.ground_truth for patient in patient_data]) - categories = ( - categories if categories is not None else list(np.unique(raw_ground_truths)) - ) - # one_hot = torch.tensor(raw_ground_truths.reshape(-1, 1) == categories) - one_hot = torch.tensor( - raw_ground_truths.reshape(-1, 1) == categories, dtype=torch.float32 - ) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=one_hot, - transform=transform, - ) - cats_out: Sequence[Category] = list(categories) - elif task == "regression": - raw_targets = np.array( - [ - np.nan if p.ground_truth is None else float(p.ground_truth) - for p in patient_data - ], - dtype=np.float32, - ) - y = torch.from_numpy(raw_targets).reshape(-1, 1) + targets, cats_out = _parse_targets( + patient_data=patient_data, + task=task, + categories=categories, + ) - ds = BagDataset( - bags=[patient.feature_files for patient in 
patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out = [] + is_multitarget = isinstance(targets[0], dict) - elif task == "survival": # Not yet support logistic-harzard - times: list[float] = [] - events: list[float] = [] + collate_fn = _collate_multitarget if is_multitarget else _collate_to_tuple - for p in patient_data: - if p.ground_truth is None: - times.append(np.nan) - events.append(np.nan) - continue + ds = BagDataset( + bags=[patient.feature_files for patient in patient_data], + bag_size=bag_size, + ground_truths=targets, + transform=transform, + ) + dl = DataLoader( + ds, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + worker_init_fn=Seed.get_loader_worker_init() if Seed._is_set() else None, + ) - try: - time_str, status_str = p.ground_truth.split(" ", 1) + return ( + cast( + DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], + dl, + ), + cats_out, + ) - # Handle missing values encoded as "nan" - if time_str.lower() == "nan": - times.append(np.nan) - else: - times.append(float(time_str)) - - if status_str.lower() == "nan": - events.append(np.nan) - elif status_str.lower() in {"dead", "event", "1", "Yes", "yes"}: - events.append(1.0) - elif status_str.lower() in {"alive", "censored", "0", "No", "no"}: - events.append(0.0) - else: - events.append(np.nan) # unknown status → mark missing - except Exception: +def _parse_targets( + *, + patient_data: Sequence, + task: Task, + categories: Sequence[Category] | None = None, + target_spec: dict[str, Any] | None = None, + target_label: str | None = None, +) -> tuple[ + Union[torch.Tensor, list[dict[str, torch.Tensor]]], + Sequence[Category] | Mapping[str, Sequence[Category]], +]: + """ + Parse raw GroundTruth (str) into model-ready tensors. + This is the ONLY place task semantics live. 
+ """ + + gts = [p.ground_truth for p in patient_data] + + if task == "classification": + if any(isinstance(gt, dict) for gt in gts if gt is not None): + # infer target names from the first non-None dict + first_dict = next(gt for gt in gts if isinstance(gt, dict)) + target_names = list(first_dict.keys()) + + # infer categories per target (ignore None patients, ignore None values) + categories_out: dict[str, list[str]] = {t: [] for t in target_names} + for gt in gts: + if not isinstance(gt, dict): + continue + for t in target_names: + v = gt.get(t) + if v is not None: + categories_out[t].append(v) + + # make unique + sorted + categories_out = { + t: sorted(set(vals)) for t, vals in categories_out.items() + } + + # encode per patient; if gt missing -> all zeros + encoded: list[dict[str, Tensor]] = [] + for gt in gts: + patient_encoded: dict[str, Tensor] = {} + for t in target_names: + cats = categories_out[t] + if not isinstance(gt, dict) or gt.get(t) is None: + one_hot = torch.zeros(len(cats), dtype=torch.float32) + else: + one_hot = torch.tensor( + [gt[t] == c for c in cats], + dtype=torch.float32, + ) + patient_encoded[t] = one_hot + encoded.append(patient_encoded) + + # IMPORTANT: return categories as mapping, not list-of-target-names + return encoded, categories_out + + # single target + unique = {gt for gt in gts if gt is not None} + if len(unique) >= 2 or categories is not None: + raw = np.array([p.ground_truth for p in patient_data]) + categories = categories or list(sorted(unique)) + labels = torch.tensor( + raw.reshape(-1, 1) == categories, + dtype=torch.float32, + ) + return labels, categories + + raise ValueError( + "Only one unique class found in classification task. " + "This is usually a data or configuration error." 
+ ) + + elif task == "regression": + y = torch.tensor( + [np.nan if gt is None else float(gt) for gt in gts], + dtype=torch.float32, + ).reshape(-1, 1) + return y, [] + + elif task == "survival": + times, events = [], [] + for gt in gts: + if gt is None: times.append(np.nan) events.append(np.nan) + continue - # Final tensor shape: (N, 2) - y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + time_str, status_str = gt.split(" ", 1) + times.append(np.nan if time_str.lower() == "nan" else float(time_str)) + events.append(_parse_survival_status(status_str)) - ds = BagDataset( - bags=[patient.feature_files for patient in patient_data], - bag_size=bag_size, - ground_truths=y, - transform=transform, - ) - cats_out: Sequence[Category] = [] # survival has no categories + y = torch.tensor(np.column_stack([times, events]), dtype=torch.float32) + return y, [] else: - raise ValueError(f"Unknown task: {task}") - - return ( - cast( - DataLoader[tuple[Bags, CoordinatesBatch, BagSizes, EncodedTargets]], - DataLoader( - ds, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - collate_fn=_collate_to_tuple, - worker_init_fn=Seed.get_loader_worker_init() - if Seed._is_set() - else None, - ), - ), - cats_out, - ) + raise ValueError(f"Unsupported task: {task}") def _collate_to_tuple( @@ -210,6 +260,24 @@ def _collate_to_tuple( return (bags, coords, bag_sizes, encoded_targets) +def _collate_multitarget( + items: list[tuple[_Bag, _Coordinates, BagSize, Dict[str, Tensor]]], +) -> tuple[Bags, CoordinatesBatch, BagSizes, Dict[str, Tensor]]: + bags = torch.stack([b for b, _, _, _ in items]) + coords = torch.stack([c for _, c, _, _ in items]) + bag_sizes = torch.tensor([s for _, _, s, _ in items]) + + acc: Dict[str, List[Tensor]] = {} + + for _, _, _, tdict in items: + for k, v in tdict.items(): + acc.setdefault(k, []).append(v) + + targets: Dict[str, Tensor] = {k: torch.stack(v, dim=0) for k, v in acc.items()} + + return bags, coords, bag_sizes, targets 
+ + def patient_feature_dataloader( *, patient_data: Sequence[PatientData[GroundTruth | None]], @@ -237,21 +305,31 @@ def create_dataloader( *, feature_type: str, task: Task, - patient_data: Sequence[PatientData[GroundTruth | None]], + patient_data: Sequence[PatientData[GroundTruth | None | dict]], bag_size: int | None = None, batch_size: int, shuffle: bool, num_workers: int, transform: Callable[[Tensor], Tensor] | None, - categories: Sequence[Category] | None = None, -) -> tuple[DataLoader, Sequence[Category]]: + categories: Sequence[Category] | Mapping[str, Sequence[Category]] | None = None, +) -> tuple[DataLoader, Sequence[Category] | Mapping[str, Sequence[Category]]]: """Unified dataloader for all feature types and tasks.""" if feature_type == "tile": + # For multi-target classification, categories may be a mapping from + # target name to per-target categories. _parse_targets (used inside + # tile_bag_dataloader) only consumes explicit categories for the + # single-target case, so we pass a sequence or None here. 
+ cats_arg: Sequence[Category] | None + if isinstance(categories, Mapping): + cats_arg = None + else: + cats_arg = categories + return tile_bag_dataloader( patient_data=patient_data, bag_size=bag_size, task=task, - categories=categories, + categories=cats_arg, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, @@ -263,21 +341,32 @@ def create_dataloader( if task == "classification": raw = np.array([p.ground_truth for p in patient_data]) - categories = categories or list(np.unique(raw)) - labels = torch.tensor(raw.reshape(-1, 1) == categories, dtype=torch.float32) - elif task == "regression": + categories_out = categories or list(np.unique(raw)) labels = torch.tensor( - [ - float(gt) - for gt in (p.ground_truth for p in patient_data) - if gt is not None - ], - dtype=torch.float32, - ).reshape(-1, 1) + raw.reshape(-1, 1) == categories_out, dtype=torch.float32 + ) + elif task == "regression": + values: list[float] = [] + for gt in (p.ground_truth for p in patient_data): + if gt is None: + continue + if isinstance(gt, dict): + # Use first value for multi-target regression + first_val = next(iter(gt.values())) + values.append(float(first_val)) + else: + values.append(float(gt)) + + labels = torch.tensor(values, dtype=torch.float32).reshape(-1, 1) elif task == "survival": times, events = [], [] for p in patient_data: - t, e = (p.ground_truth or "nan nan").split(" ", 1) + if isinstance(p.ground_truth, dict): + # Multi-target survival: use first target + val = list(p.ground_truth.values())[0] + t, e = (val or "nan nan").split(" ", 1) + else: + t, e = (p.ground_truth or "nan nan").split(" ", 1) times.append(float(t) if t.lower() != "nan" else np.nan) events.append(_parse_survival_status(e)) @@ -340,9 +429,9 @@ def load_patient_level_data( clini_table: Path, feature_dir: Path, patient_label: PandasLabel, - ground_truth_label: PandasLabel | None = None, # <- now optional - time_label: PandasLabel | None = None, # <- for survival - status_label: PandasLabel | 
def patient_to_ground_truth_from_clini_table_(
    *,
    clini_table_path: Path | TextIO,
    patient_label: PandasLabel,
    ground_truth_label: PandasLabel | Sequence[PandasLabel],
) -> (
    dict[PatientId, GroundTruth | None] | dict[PatientId, dict[str, GroundTruth | None]]
):
    """Loads the patients and their ground truths from a clini table.

    `ground_truth_label` may be either a single column name (str) or a sequence
    of column names. In the latter case the returned mapping will contain a
    dict mapping column -> value for each patient (supporting multi-target
    setups).

    Args:
        clini_table_path: Path (or open text handle) of the clinical table,
            read via `read_table` with every column parsed as `str`.
        patient_label: Name of the column holding the patient identifier.
        ground_truth_label: A single target column name, or a sequence of
            target column names for multi-target training.

    Returns:
        Single target: ``patient id -> value``; rows with a missing target
        are dropped. Multi-target: ``patient id -> {column: value or None}``;
        only rows missing *all* target columns are dropped.

    Raises:
        ValueError: If the patient column or any requested target column is
            absent from the table (translated from the underlying KeyError).
    """
    # Normalize to list for uniform handling of str vs. sequence-of-str.
    if isinstance(ground_truth_label, str):
        cols = [patient_label, ground_truth_label]
        multi = False
        target_cols_inner: list[PandasLabel] = []
    else:
        cols = [patient_label, *list(ground_truth_label)]
        multi = True
        target_cols_inner = [c for c in cols if c != patient_label]

    clini_df = read_table(
        clini_table_path,
        usecols=cols,
        dtype=str,
    )

    # If multi-target, keep rows where at least one target is present; for
    # single target behave like before and drop rows missing the value.
    if multi:
        clini_df = clini_df.dropna(subset=target_cols_inner, how="all")
    else:
        clini_df = clini_df.dropna(subset=[ground_truth_label])

    try:
        if multi:
            # Build mapping patient -> {col: value}.
            # NOTE(review): unlike the single-target path below (which uses
            # `verify_integrity=True`), duplicate patient rows are not
            # rejected here — the last row silently wins. Confirm intended.
            result: dict[PatientId, dict[str, GroundTruth | None]] = {}
            for _, row in clini_df.iterrows():
                pid = row[patient_label]
                # Convert pandas nan to None and keep strings otherwise
                result[pid] = {
                    col: (None if pd.isna(row[col]) else str(row[col]))
                    for col in target_cols_inner
                }
            return result
        else:
            patient_to_ground_truth: Mapping[PatientId, str] = cast(
                Mapping[PatientId, str],
                clini_df.set_index(patient_label, verify_integrity=True)[
                    cast(PandasLabel, ground_truth_label)
                ].to_dict(),
            )
            return cast(dict[PatientId, GroundTruth | None], patient_to_ground_truth)
    except KeyError as e:
        # A KeyError here means a requested column never made it into the
        # frame; re-raise as a friendlier ValueError for the user.
        if patient_label not in clini_df:
            raise ValueError(
                f"{patient_label} was not found in clini table "
                f"(columns in clini table: {clini_df.columns})"
            ) from e
        else:
            raise ValueError(
                f"One or more ground truth columns were not found in clini table "
                f"(columns in clini table: {clini_df.columns})"
            ) from e
def log_patient_class_summary(
    *,
    patient_to_data: Mapping[PatientId, PatientData],
    categories: Sequence[Category] | None,
) -> None:
    """Log the class distribution of the available ground truths.

    Supports both single-target ground truths (plain values) and
    multi-target ones (dicts mapping target name -> value).

    Args:
        patient_to_data: Mapping from patient ID to their data record.
        categories: Optional explicit category order for the single-target
            summary; categories absent from the data are reported with a
            count of 0. Ignored for multi-target ground truths, where the
            per-target distributions are reported as observed.
    """
    # Local import keeps this block self-contained; Counter is stdlib.
    from collections import Counter

    ground_truths = [
        p.ground_truth for p in patient_to_data.values() if p.ground_truth is not None
    ]
    if not ground_truths:
        _logger.warning("No ground truths available for summary.")
        return

    # Multi-target: each ground truth is a dict target-name -> value.
    if isinstance(ground_truths[0], dict):
        per_target: dict[str, Counter] = {}
        for gt in ground_truths:
            for target_name, value in gt.items():
                if value is None:
                    # Match the single-target behaviour: missing labels are
                    # not part of the distribution.
                    continue
                per_target.setdefault(target_name, Counter())[value] += 1

        for target_name, counter in per_target.items():
            _logger.info(
                f"[Multi-target] Target '{target_name}' distribution: {dict(counter)}"
            )

    # Single-target: plain values.
    else:
        counter = Counter(ground_truths)
        if categories:
            # Honor the caller-supplied category order, including classes
            # with zero observed samples.
            counts = {c: counter.get(c, 0) for c in categories}
        else:
            counts = dict(counter)

        _logger.info(f"Class distribution: {counts}")
import Category, GroundTruth, PandasLabel, PatientId __all__ = ["deploy_categorical_model_"] @@ -32,6 +32,13 @@ Logit: TypeAlias = float +# Prediction type aliases +PredictionSingle: TypeAlias = torch.Tensor +PredictionMulti: TypeAlias = dict[str, torch.Tensor] +PredictionsType: TypeAlias = Mapping[ + PatientId, Union[PredictionSingle, PredictionMulti] +] + def load_model_from_ckpt(path: Union[str, Path]): ckpt = torch.load(path, map_location="cpu", weights_only=False) @@ -50,7 +57,7 @@ def deploy_categorical_model_( clini_table: Path | None, slide_table: Path | None, feature_dir: Path, - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, patient_label: PandasLabel, @@ -126,7 +133,12 @@ def deploy_categorical_model_( # classification/regression: still use ground_truth_label if ( len( - ground_truth_labels := set(model.ground_truth_label for model in models) + ground_truth_labels := { + tuple(model.ground_truth_label) + if isinstance(model.ground_truth_label, list) + else (model.ground_truth_label,) + for model in models + } ) != 1 ): @@ -145,17 +157,21 @@ def deploy_categorical_model_( f"{ground_truth_label} vs {model_ground_truth_label}" ) - ground_truth_label = ground_truth_label or model_ground_truth_label + ground_truth_label = ground_truth_label or cast( + PandasLabel, model_ground_truth_label + ) output_dir.mkdir(exist_ok=True, parents=True) model_categories = None if task == "classification": # Ensure the categories were the same between all models - category_sets = {tuple(m.categories) for m in models} + category_sets = { + tuple(cast(Sequence[GroundTruth], m.categories)) for m in models + } if len(category_sets) != 1: raise RuntimeError(f"Categories differ between models: {category_sets}") - model_categories = list(models[0].categories) + model_categories = list(cast(Sequence[GroundTruth], models[0].categories)) # Data loading logic if 
feature_type in ("tile", "slide"): @@ -171,6 +187,15 @@ def deploy_categorical_model_( ) if clini_table is not None: if task == "survival": + if not hasattr(models[0], "time_label") or not isinstance( + models[0].time_label, str + ): + raise AttributeError("Model is missing valid 'time_label' (str).") + if not hasattr(models[0], "status_label") or not isinstance( + models[0].status_label, str + ): + raise AttributeError("Model is missing valid 'status_label' (str).") + patient_to_ground_truth = patient_to_survival_from_clini_table_( clini_table_path=clini_table, patient_label=patient_label, @@ -192,7 +217,10 @@ def deploy_categorical_model_( patient_id: None for patient_id in set(slide_to_patient.values()) } patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, + patient_to_ground_truth=cast( + Mapping[PatientId, GroundTruth | None], + patient_to_ground_truth, + ), slide_to_patient=slide_to_patient, drop_patients_with_missing_ground_truth=False, ) @@ -241,7 +269,10 @@ def deploy_categorical_model_( "regression": _to_regression_prediction_df, "survival": _to_survival_prediction_df, }[task] - all_predictions: list[Mapping[PatientId, Float[torch.Tensor, "category"]]] = [] # noqa: F821 + all_predictions: list[PredictionsType] = [] + categories_for_export: ( + Sequence[Category] | Mapping[str, Sequence[Category]] | None + ) = cast(Sequence[Category] | Mapping[str, Sequence[Category]] | None, None) for model_i, model in enumerate(models): predictions = _predict( model=model, @@ -251,6 +282,26 @@ def deploy_categorical_model_( ) all_predictions.append(predictions) + if isinstance(next(iter(predictions.values())), dict): + # Multi-target case: gather categories across all targets for export (use model categories if available, else infer from GT) + categories_accum: dict[str, set[GroundTruth]] = {} + + for pd_item in patient_to_data.values(): + gt = pd_item.ground_truth + if isinstance(gt, dict): + for k, v in gt.items(): + if v 
is not None: + categories_accum.setdefault(k, set()).add(v) + + categories_for_export = {k: sorted(v) for k, v in categories_accum.items()} + + else: + # Single-target case: use categories from model if available, else infer from GT + if task == "classification": + categories_for_export = models[0].categories + else: + categories_for_export = [] + # cut-off values from survival ckpt cut_off = ( getattr(model.hparams, "train_pred_median", None) @@ -261,7 +312,7 @@ def deploy_categorical_model_( # Only save individual model files when deploying multiple models (ensemble) if len(models) > 1: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -270,7 +321,7 @@ def deploy_categorical_model_( ).to_csv(output_dir / f"patient-preds-{model_i}.csv", index=False) else: df_builder( - categories=model_categories, + categories=categories_for_export, patient_to_ground_truth=patient_to_ground_truth, predictions=predictions, patient_label=patient_label, @@ -279,17 +330,29 @@ def deploy_categorical_model_( ).to_csv(output_dir / "patient-preds.csv", index=False) if task == "classification": - # TODO we probably also want to save the 95% confidence interval in addition to the mean + # compute mean prediction across models (supports single- and multi-target) + mean_preds: dict[PatientId, object] = {} + for pid in patient_ids: + model_preds = cast( + list[torch.Tensor], [preds[pid] for preds in all_predictions] + ) + firstp = model_preds[0] + if isinstance(firstp, dict): + # per-target averaging + mean_preds[pid] = { + t: torch.stack([p[t] for p in model_preds]).mean(dim=0) + for t in firstp.keys() + } + else: + mean_preds[pid] = torch.stack(model_preds).mean(dim=0) + + assert categories_for_export is not None, ( + "categories_for_export must be set before use" + ) df_builder( - categories=model_categories, + categories=categories_for_export, 
patient_to_ground_truth=patient_to_ground_truth, - predictions={ - # Mean prediction - patient_id: torch.stack( - [predictions[patient_id] for predictions in all_predictions] - ).mean(dim=0) - for patient_id in patient_ids - }, + predictions=mean_preds, patient_label=patient_label, ground_truth_label=ground_truth_label, ).to_csv(output_dir / "patient-preds_95_confidence_interval.csv", index=False) @@ -301,7 +364,7 @@ def _predict( test_dl: torch.utils.data.DataLoader, patient_ids: Sequence[PatientId], accelerator: str | Accelerator, -) -> Mapping[PatientId, Float[torch.Tensor, "..."]]: +) -> PredictionsType: model = model.eval() torch.set_float32_matmul_precision("medium") @@ -310,8 +373,11 @@ def _predict( getattr(model, "train_patients", []) ) | set(getattr(model, "valid_patients", [])) if overlap := patients_used_for_training & set(patient_ids): - raise ValueError( - f"some of the patients in the validation set were used during training: {overlap}" + _logger.critical( + "DATA LEAKAGE DETECTED: %d patient(s) in deployment set were used " + "during training/validation. 
Overlapping IDs: %s", + len(overlap), + sorted(overlap), ) trainer = lightning.Trainer( @@ -320,51 +386,190 @@ def _predict( logger=False, ) - raw_preds = torch.concat(cast(list[torch.Tensor], trainer.predict(model, test_dl))) + outs = trainer.predict(model, test_dl) + + if not outs: + return {} + + first = outs[0] + + # Multi-target case: each element of outs is a dict[target_label -> tensor] + if isinstance(first, dict): + per_target_lists: dict[str, list[torch.Tensor]] = {} + for out in outs: + if not isinstance(out, dict): + raise RuntimeError("Mixed prediction output types from model") + for k, v in out.items(): + per_target_lists.setdefault(k, []).append(v) + + per_target_tensors: dict[str, torch.Tensor] = { + k: torch.cat(vlist, dim=0) for k, vlist in per_target_lists.items() + } + + if getattr(model.hparams, "task", None) == "classification": + for k in list(per_target_tensors.keys()): + per_target_tensors[k] = torch.softmax(per_target_tensors[k], dim=1) + + # build per-patient dicts + num_preds = next(iter(per_target_tensors.values())).shape[0] + predictions: dict[PatientId, dict[str, torch.Tensor]] = {} + for i, pid in enumerate(patient_ids[:num_preds]): + predictions[pid] = { + k: per_target_tensors[k][i] for k in per_target_tensors.keys() + } + + return predictions + + # Single-target case: each element of outs is a tensor + outs_single = cast(list[torch.Tensor], outs) + + raw_preds = torch.cat(outs_single, dim=0) if getattr(model.hparams, "task", None) == "classification": - predictions = torch.softmax(raw_preds, dim=1) + raw_preds = torch.softmax(raw_preds, dim=1) elif getattr(model.hparams, "task", None) == "survival": - predictions = raw_preds.squeeze(-1) # (N,) risk scores - else: # regression - predictions = raw_preds + raw_preds = raw_preds.squeeze(-1) + + result: dict[PatientId, torch.Tensor] = { + pid: raw_preds[i] for i, pid in enumerate(patient_ids) + } - return dict(zip(patient_ids, predictions, strict=True)) + return result def 
_to_prediction_df( *, - categories: Sequence[GroundTruth], - patient_to_ground_truth: Mapping[PatientId, GroundTruth | None], - predictions: Mapping[PatientId, torch.Tensor], + categories: Sequence[GroundTruth] | Mapping[str, Sequence[GroundTruth]], + patient_to_ground_truth: Mapping[PatientId, GroundTruth | None] + | Mapping[PatientId, dict[str, GroundTruth | None]], + predictions: Mapping[PatientId, torch.Tensor] + | Mapping[PatientId, dict[str, torch.Tensor]], patient_label: PandasLabel, - ground_truth_label: PandasLabel, + ground_truth_label: PandasLabel | Sequence[PandasLabel], **kwargs, ) -> pd.DataFrame: - """Compiles deployment results into a DataFrame.""" - return pd.DataFrame( - [ - { - patient_label: patient_id, - ground_truth_label: patient_to_ground_truth.get(patient_id), - "pred": categories[int(prediction.argmax())], - **{ - f"{ground_truth_label}_{category}": prediction[i_cat].item() - for i_cat, category in enumerate(categories) - }, - "loss": ( - torch.nn.functional.cross_entropy( - prediction.reshape(1, -1), - torch.tensor(np.where(np.array(categories) == ground_truth)[0]), - ).item() - if (ground_truth := patient_to_ground_truth.get(patient_id)) - is not None - else None - ), - } - for patient_id, prediction in predictions.items() - ] - ).sort_values(by="loss") + """Compiles deployment results into a DataFrame. + + Supports single-target and multi-target classification. + - Single-target: `predictions` maps patient -> tensor and `categories` is a sequence. + - Multi-target: `predictions` maps patient -> dict[target_label -> tensor] and + `categories` is a mapping from target_label -> sequence of category names. 
+ """ + first_pred = next(iter(predictions.values())) + + # Multi-target predictions: dict per patient + if isinstance(first_pred, dict): + # determine target labels + target_labels = list(cast(dict, first_pred).keys()) + + # prepare categories mapping + if isinstance(categories, dict): + cats_map = categories + else: + # try infer categories list ordering: assume categories is a sequence-of-sequences + cats_map = {} + if isinstance(categories, Sequence): + try: + for i, t in enumerate(target_labels): + cats_map[t] = list( + cast(Sequence[Sequence[GroundTruth]], categories)[i] + ) + except Exception: + cats_map = {} + + # infer missing category lists from ground truth + if any(t not in cats_map for t in target_labels): + inferred: dict[str, set] = {t: set() for t in target_labels} + for pid, gt in patient_to_ground_truth.items(): + if isinstance(gt, dict): + for t in target_labels: + val = gt.get(t) + if val is not None: + inferred[t].add(val) + for t in target_labels: + if t not in cats_map: + cats_map[t] = sorted(inferred.get(t, [])) + + rows = [] + for pid, pred_dict in predictions.items(): + row: dict = {patient_label: pid} + gt_entry = patient_to_ground_truth.get(pid) + # ground truths per target + for t in target_labels: + if isinstance(gt_entry, dict): + row[t] = gt_entry.get(t) + else: + row[t] = gt_entry + + total_loss = 0.0 + has_loss = False + for t in target_labels: + tensor = cast(dict[str, torch.Tensor], pred_dict)[t] + probs = tensor.detach().cpu() + cats: Sequence[GroundTruth] = cast( + Sequence[GroundTruth], + cats_map.get(t, []), + ) + if probs.numel() == 1: + row[f"pred_{t}"] = float(probs.item()) + else: + pred_idx = int(probs.argmax().item()) + row[f"pred_{t}"] = ( + cats[pred_idx] if pred_idx < len(cats) else pred_idx + ) + for i_cat, cat in enumerate(cats): + if i_cat < probs.shape[0]: + row[f"{t}_{cat}"] = float(probs[i_cat].item()) + else: + row[f"{t}_{cat}"] = None + + if isinstance(gt_entry, dict) and (gt := gt_entry.get(t)) is not None: 
+ try: + target_index = int(np.where(np.array(cats) == gt)[0][0]) + loss = torch.nn.functional.cross_entropy( + probs.reshape(1, -1), torch.tensor([target_index]) + ).item() + total_loss += loss + has_loss = True + except Exception: + pass + + row["loss"] = total_loss if has_loss else None + rows.append(row) + + return pd.DataFrame(rows) + + # Single-target (original behaviour) + if not all(isinstance(p, torch.Tensor) for p in predictions.values()): + raise TypeError("Single-target block received multi-target dict predictions.") + + predictions = cast(Mapping[PatientId, torch.Tensor], predictions) + + rows = [] + for pid, prediction in predictions.items(): + gt = patient_to_ground_truth.get(pid) + cats = cast(Sequence[GroundTruth], categories) + pred_idx = int(prediction.argmax()) + row = { + patient_label: pid, + ground_truth_label: gt, + "pred": cats[pred_idx], + **{ + f"{ground_truth_label}_{category}": float(prediction[i_cat].item()) + for i_cat, category in enumerate(cats) + }, + "loss": ( + torch.nn.functional.cross_entropy( + prediction.reshape(1, -1), + torch.tensor(np.where(np.array(cats) == gt)[0]), + ).item() + if gt is not None + else None + ), + } + rows.append(row) + + return pd.DataFrame(rows).sort_values(by="loss") def _to_regression_prediction_df( @@ -383,8 +588,6 @@ def _to_regression_prediction_df( - pred (float) - loss (per-sample L1 loss if GT available, else None) """ - import torch.nn.functional as F - return pd.DataFrame( [ { diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 59a0a3aa..86003c52 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -3,7 +3,7 @@ import inspect from abc import ABC from collections.abc import Iterable, Sequence -from typing import Any, TypeAlias +from typing import Any, Mapping, TypeAlias import lightning import numpy as np @@ -14,6 +14,11 @@ from torchmetrics.classification import MulticlassAUROC import stamp +from 
stamp.modeling.models.barspoon import ( + EncDecTransformer, + LitMilClassificationMixin, + TargetLabel, +) from stamp.modeling.models.cox import neg_partial_log_likelihood from stamp.types import ( Bags, @@ -818,3 +823,75 @@ class LitPatientSurvival(LitSlideSurvival): """ supported_features = ["patient"] + + +class LitEncDecTransformer(LitMilClassificationMixin): + def __init__( + self, + *, + dim_input: int, + category_weights: Mapping[TargetLabel, torch.Tensor], + model_class: type[nn.Module] | None = None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, + categories: Mapping[str, Sequence[Category]], + # Model parameters + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + # Other hparams + learning_rate: float = 1e-4, + **hparams: Any, + ) -> None: + weights_dict: dict[TargetLabel, torch.Tensor] = dict(category_weights) + super().__init__( + weights=weights_dict, + learning_rate=learning_rate, + ) + _ = hparams # so we don't get unused parameter warnings + + self.model = EncDecTransformer( + d_features=dim_input, + target_n_outs={t: len(w) for t, w in category_weights.items()}, + d_model=d_model, + num_encoder_heads=num_encoder_heads, + num_decoder_heads=num_decoder_heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dim_feedforward=dim_feedforward, + positional_encoding=positional_encoding, + ) + + self.hparams["supported_features"] = "tile" + self.hparams.update({"task": "classification"}) + # ---- Normalize categories into strict mapping[str, list[str]] ---- + if not isinstance(categories, Mapping): + raise ValueError( + "Multi-target classification requires categories as Mapping[str, Sequence[str]]." 
+ ) + + normalized_categories: dict[str, list[str]] = { + str(k): list(v) for k, v in categories.items() + } + + # Sanity check: head size must match category size + for t, w in category_weights.items(): + if t not in normalized_categories: + raise ValueError(f"Missing categories for target '{t}'") + if len(normalized_categories[t]) != len(w): + raise ValueError( + f"Category mismatch for target '{t}': " + f"{len(normalized_categories[t])} categories " + f"but head has {len(w)} outputs." + ) + + self.ground_truth_label = ground_truth_label + self.categories = normalized_categories + + self.save_hyperparameters() + + def forward(self, *args): + return self.model(*args) diff --git a/src/stamp/modeling/models/barspoon.py b/src/stamp/modeling/models/barspoon.py new file mode 100644 index 00000000..f841bb3d --- /dev/null +++ b/src/stamp/modeling/models/barspoon.py @@ -0,0 +1,367 @@ +""" +Port from https://github.com/KatherLab/barspoon-transformer +""" + +import re +from typing import Any, TypeAlias + +import lightning +import torch +import torch.nn.functional as F +import torchmetrics +from packaging.version import Version +from torch import nn +from torchmetrics.classification import MulticlassAUROC +from torchmetrics.utilities.data import dim_zero_cat + +import stamp +from stamp.types import Bags, BagSizes, CoordinatesBatch + +__all__ = [ + "EncDecTransformer", + "LitMilClassificationMixin", + "SafeMulticlassAUROC", +] + + +TargetLabel: TypeAlias = str + + +class EncDecTransformer(nn.Module): + """An encoder decoder architecture for multilabel classification tasks + + This architecture is a modified version of the one found in [Attention Is + All You Need][1]: First, we project the features into a lower-dimensional + feature space, to prevent the transformer architecture's complexity from + exploding for high-dimensional features. We add sinusodial [positional + encodings][1]. We then encode these projected input tokens using a + transformer encoder stack. 
Next, we decode these tokens using a set of + class tokens, one per output label. Finally, we forward each of the decoded + tokens through a fully connected layer to get a label-wise prediction. + + PE1 + | + +--+ v +---+ + t1 ->|FC|--+-->| |--+ + . +--+ | E | | + . | x | | + . +--+ | m | | + tn ->|FC|--+-->| |--+ + +--+ ^ +---+ | + | | + PEn v + +---+ +---+ + c1 ---------------->| |-->|FC1|--> s1 + . | D | +---+ . + . | x | . + . | l | +---+ . + ck ---------------->| |-->|FCk|--> sk + +---+ +---+ + + We opted for this architecture instead of a more traditional [Vision + Transformer][2] to improve performance for multi-label predictions with many + labels. Our experiments have shown that adding too many class tokens to a + vision transformer decreases its performance, as the same weights have to + both process the tiles' information and the class token's processing. Using + an encoder-decoder architecture alleviates these issues, as the data-flow of + the class tokens is completely independent of the encoding of the tiles. + Furthermore, analysis has shown that there is almost no interaction between + the different classes in the decoder. While this points to the decoder + being more powerful than needed in practice, this also means that each + label's prediction is mostly independent of the others. As a consequence, + noisy labels will not negatively impact the accuracy of non-noisy ones. + + In our experiments so far we did not see any improvement by adding + positional encodings. We tried + + 1. [Sinusodal encodings][1] + 2. Adding absolute positions to the feature vector, scaled down so the + maximum value in the training dataset is 1. + + Since neither reduced performance and the author percieves the first one to + be more elegant (as the magnitude of the positional encodings is bounded), + we opted to keep the positional encoding regardless in the hopes of it + improving performance on future tasks. 
+ + The architecture _differs_ from the one described in [Attention Is All You + Need][1] as follows: + + 1. There is an initial projection stage to reduce the dimension of the + feature vectors and allow us to use the transformer with arbitrary + features. + 2. Instead of the language translation task described in [Attention Is All + You Need][1], where the tokens of the words translated so far are used + to predict the next word in the sequence, we use a set of fixed, learned + class tokens in conjunction with equally as many independent fully + connected layers to predict multiple labels at once. + + [1]: https://arxiv.org/abs/1706.03762 "Attention Is All You Need" + [2]: https://arxiv.org/abs/2010.11929 + "An Image is Worth 16x16 Words: + Transformers for Image Recognition at Scale" + """ + + def __init__( + self, + d_features: int, + target_n_outs: dict[str, int], + *, + d_model: int = 512, + num_encoder_heads: int = 8, + num_decoder_heads: int = 8, + num_encoder_layers: int = 2, + num_decoder_layers: int = 2, + dim_feedforward: int = 2048, + positional_encoding: bool = True, + ) -> None: + super().__init__() + + self.projector = nn.Sequential(nn.Linear(d_features, d_model), nn.ReLU()) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=num_encoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False + ) + + self.target_labels = target_n_outs.keys() + + # One class token per output label + self.class_tokens = nn.ParameterDict( + { + sanitize(target_label): torch.rand(d_model) + for target_label in target_n_outs + } + ) + + decoder_layer = nn.TransformerDecoderLayer( + d_model=d_model, + nhead=num_decoder_heads, + dim_feedforward=dim_feedforward, + batch_first=True, + norm_first=True, + ) + self.transformer_decoder = nn.TransformerDecoder( + decoder_layer, 
num_layers=num_decoder_layers + ) + + self.heads = nn.ModuleDict( + { + sanitize(target_label): nn.Linear( + in_features=d_model, out_features=n_out + ) + for target_label, n_out in target_n_outs.items() + } + ) + + self.positional_encoding = positional_encoding + + def forward( + self, + tile_tokens: torch.Tensor, + tile_positions: torch.Tensor, + ) -> dict[str, torch.Tensor]: + batch_size, _, _ = tile_tokens.shape + + tile_tokens = self.projector(tile_tokens) # shape: [bs, seq_len, d_model] + + if self.positional_encoding: + # Add positional encodings + d_model = tile_tokens.size(-1) + x = tile_positions.unsqueeze(-1) / 100_000 ** ( + torch.arange(d_model // 4).type_as(tile_positions) / d_model + ) + positional_encodings = torch.cat( + [ + torch.sin(x).flatten(start_dim=-2), + torch.cos(x).flatten(start_dim=-2), + ], + dim=-1, + ) + tile_tokens = tile_tokens + positional_encodings + + tile_tokens = self.transformer_encoder(tile_tokens) + + class_tokens = torch.stack( + [self.class_tokens[sanitize(t)] for t in self.target_labels] + ).expand(batch_size, -1, -1) + class_tokens = self.transformer_decoder(tgt=class_tokens, memory=tile_tokens) + + # Apply the corresponding head to each class token + logits = { + target_label: self.heads[sanitize(target_label)](class_token) + for target_label, class_token in zip( + self.target_labels, + class_tokens.permute(1, 0, 2), # Permute to [target, batch, d_model] + strict=True, + ) + } + + return logits + + +class LitMilClassificationMixin(lightning.LightningModule): + """Makes a module into a multilabel, multiclass Lightning one""" + + supported_features = ["tile"] + + def __init__( + self, + *, + weights: dict[TargetLabel, torch.Tensor], + # Other hparams + learning_rate: float = 1e-4, + stamp_version: Version = Version(stamp.__version__), + **hparams: Any, + ) -> None: + super().__init__() + _ = hparams # So we don't get unused parameter warnings + + # Check if version is compatible. 
+ if stamp_version < Version("2.4.0"): + # Update this as we change our model in incompatible ways! + raise ValueError( + f"model has been built with stamp version {stamp_version} " + f"which is incompatible with the current version." + ) + elif stamp_version > Version(stamp.__version__): + # Let's be strict with models "from the future", + # better fail deadly than have broken results. + raise ValueError( + "model has been built with a stamp version newer than the installed one " + f"({stamp_version} > {stamp.__version__}). " + "Please upgrade stamp to a compatible version." + ) + + self.hparams.update({"task": "classification"}) + + self.learning_rate = learning_rate + + target_aurocs = torchmetrics.MetricCollection( + { + sanitize(target_label): SafeMulticlassAUROC(num_classes=len(weight)) + for target_label, weight in weights.items() + } + ) + for step_name in ["train", "validation", "test"]: + setattr( + self, + f"{step_name}_target_aurocs", + target_aurocs.clone(prefix=f"{step_name}_"), + ) + + self.weights = weights + + self.save_hyperparameters() + + def step( + self, + batch: tuple[Bags, CoordinatesBatch, BagSizes, dict[str, torch.Tensor]], + step_name=None, + ): + """Process a batch with structure (feats, coords, bag_sizes, targets). + + Args: + batch: Tuple of (feats, coords, bag_sizes, targets) where: + - feats: bag features [batch, bag_size, feature_dim] + - coords: tile coordinates [batch, bag_size, 2] + - bag_sizes: number of tiles per bag [batch] + - targets: dict mapping target names to one-hot encoded tensors [batch, num_classes] + step_name: Optional step name for logging ('train', 'validation', 'test'). 
+ """ + feats: Bags + coords: CoordinatesBatch + bag_sizes: BagSizes + targets: dict[str, torch.Tensor] + feats, coords, bag_sizes, targets = batch + logits = self(feats, coords) + + # Calculate the cross entropy loss for each target, then sum them + loss = sum( + F.cross_entropy( + (logit := logits[target_label]), + targets[target_label].type_as(logit), + weight=weight.type_as(logit), + ) + for target_label, weight in self.weights.items() + ) + + if step_name: + self.log( + f"{step_name}_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + ) + + # Update target-wise metrics + for target_label in self.weights: + target_auroc = getattr(self, f"{step_name}_target_aurocs")[ + sanitize(target_label) + ] + is_na = (targets[target_label] == 0).all(dim=1) + target_auroc.update( + logits[target_label][~is_na], + targets[target_label][~is_na].argmax(dim=1), + ) + self.log( + f"{step_name}_{target_label}_auroc", + target_auroc, + on_step=False, + on_epoch=True, + sync_dist=True, + ) + + return loss + + def training_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="train") + + def validation_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="validation") + + def test_step(self, batch, batch_idx): # pyright: ignore[reportIncompatibleMethodOverride] + return self.step(batch, step_name="test") + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + if len(batch) == 2: + feats, positions = batch + else: + feats, positions, _, _ = batch + + logits = self(feats, positions) + + softmaxed = { + target_label: torch.softmax(x, 1) for target_label, x in logits.items() + } + return softmaxed + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def sanitize(x: str) -> str: + return re.sub(r"[^A-Za-z0-9_]", "_", x) + + +class 
SafeMulticlassAUROC(MulticlassAUROC): + """A Multiclass AUROC that doesn't blow up when no targets are given""" + + def compute(self) -> torch.Tensor: + # Add faux entry if there are none so far + if len(self.preds) == 0: + self.update(torch.zeros(1, self.num_classes), torch.zeros(1).long()) + elif len(dim_zero_cat(self.preds)) == 0: + self.update( + torch.zeros(1, self.num_classes).type_as(self.preds[0]), + torch.zeros(1).long().type_as(self.target[0]), + ) + return super().compute() diff --git a/src/stamp/modeling/registry.py b/src/stamp/modeling/registry.py index 2205af22..14011084 100644 --- a/src/stamp/modeling/registry.py +++ b/src/stamp/modeling/registry.py @@ -1,6 +1,7 @@ from enum import StrEnum from stamp.modeling.models import ( + LitEncDecTransformer, LitPatientClassifier, LitPatientRegressor, LitPatientSurvival, @@ -21,6 +22,7 @@ class ModelName(StrEnum): MLP = "mlp" TRANS_MIL = "trans_mil" LINEAR = "linear" + BARSPOON = "barspoon" # Map (feature_type, task) → correct Lightning wrapper class @@ -34,6 +36,7 @@ class ModelName(StrEnum): ("patient", "classification"): LitPatientClassifier, ("patient", "regression"): LitPatientRegressor, ("patient", "survival"): LitPatientSurvival, + # ("tile", "multiclass"): LitEncDecTransformer, } @@ -54,6 +57,13 @@ def load_model_class(task: Task, feature_type: str, model_name: ModelName): case ModelName.MLP: from stamp.modeling.models.mlp import MLP as ModelClass + case ModelName.BARSPOON: + from stamp.modeling.models.barspoon import ( + EncDecTransformer as ModelClass, + ) + + LitModelClass = LitEncDecTransformer + case ModelName.LINEAR: from stamp.modeling.models.mlp import ( Linear as ModelClass, diff --git a/src/stamp/modeling/train.py b/src/stamp/modeling/train.py index b855e2a0..fb2bdb3b 100644 --- a/src/stamp/modeling/train.py +++ b/src/stamp/modeling/train.py @@ -2,7 +2,7 @@ import shutil from collections.abc import Callable, Mapping, Sequence from pathlib import Path -from typing import cast +from typing 
import Any, cast import lightning import torch @@ -18,13 +18,7 @@ PatientData, PatientFeatureDataset, create_dataloader, - detect_feature_type, - filter_complete_patient_data_, - load_patient_level_data, - log_patient_class_summary, - patient_to_ground_truth_from_clini_table_, - patient_to_survival_from_clini_table_, - slide_to_patient_from_slide_table_, + load_patient_data_, ) from stamp.modeling.registry import ModelName, load_model_class from stamp.modeling.transforms import VaryPrecisionTransform @@ -53,66 +47,25 @@ def train_categorical_model_( advanced: AdvancedConfig, ) -> None: """Trains a model based on the feature type.""" - feature_type = detect_feature_type(config.feature_dir) - _logger.info(f"Detected feature type: {feature_type}") - - if feature_type in ("tile", "slide"): - if config.slide_table is None: - raise ValueError("A slide table is required for modeling") - if config.task == "survival": - if config.time_label is None or config.status_label is None: - raise ValueError( - "Both time_label and status_label is required for survival modeling" - ) - patient_to_ground_truth = patient_to_survival_from_clini_table_( - clini_table_path=config.clini_table, - time_label=config.time_label, - status_label=config.status_label, - patient_label=config.patient_label, - ) - else: - if config.ground_truth_label is None: - raise ValueError( - "Ground truth label is required for tile-level modeling" - ) - patient_to_ground_truth = patient_to_ground_truth_from_clini_table_( - clini_table_path=config.clini_table, - ground_truth_label=config.ground_truth_label, - patient_label=config.patient_label, - ) - slide_to_patient = slide_to_patient_from_slide_table_( - slide_table_path=config.slide_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - filename_label=config.filename_label, - ) - patient_to_data = filter_complete_patient_data_( - patient_to_ground_truth=patient_to_ground_truth, - slide_to_patient=slide_to_patient, - 
drop_patients_with_missing_ground_truth=True, - ) - elif feature_type == "patient": - # Patient-level: ignore slide_table - if config.slide_table is not None: - _logger.warning("slide_table is ignored for patient-level features.") - - patient_to_data = load_patient_level_data( - task=config.task, - clini_table=config.clini_table, - feature_dir=config.feature_dir, - patient_label=config.patient_label, - ground_truth_label=config.ground_truth_label, - time_label=config.time_label, - status_label=config.status_label, - ) - else: - raise RuntimeError(f"Unknown feature type: {feature_type}") - if config.task is None: raise ValueError( "task must be set to 'classification' | 'regression' | 'survival'" ) + patient_to_data, feature_type = load_patient_data_( + feature_dir=config.feature_dir, + clini_table=config.clini_table, + slide_table=config.slide_table, + task=config.task, + ground_truth_label=config.ground_truth_label, + time_label=config.time_label, + status_label=config.status_label, + patient_label=config.patient_label, + filename_label=config.filename_label, + drop_patients_with_missing_ground_truth=True, + ) + _logger.info(f"Detected feature type: {feature_type}") + # Train the model (the rest of the logic is unchanged) model, train_dl, valid_dl = setup_model_for_training( patient_to_data=patient_to_data, @@ -145,14 +98,14 @@ def train_categorical_model_( def setup_model_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, train_transform: Callable[[torch.Tensor], torch.Tensor] | None, feature_type: str, advanced: AdvancedConfig, # Metadata, has no effect on model training - ground_truth_label: PandasLabel | None, + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None, time_label: PandasLabel | None, status_label: PandasLabel | None, clini_table: Path, @@ -193,10 +146,6 @@ def 
setup_model_for_training( feature_type=feature_type, train_categories=train_categories, ) - log_patient_class_summary( - patient_to_data={pid: patient_to_data[pid] for pid in patient_to_data}, - categories=categories, - ) # 1. Default to a model if none is specified if advanced.model_name is None: @@ -209,6 +158,7 @@ def setup_model_for_training( LitModelClass, ModelClass = load_model_class( task, feature_type, advanced.model_name ) + print(f"Using Lightning wrapper class: {LitModelClass}") # 3. Validate that the chosen model supports the feature type if feature_type not in LitModelClass.supported_features: @@ -272,7 +222,7 @@ def setup_model_for_training( def setup_dataloaders_for_training( *, - patient_to_data: Mapping[PatientId, PatientData[GroundTruth]], + patient_to_data: Mapping[PatientId, PatientData[GroundTruth | None | dict]], task: Task, categories: Sequence[Category] | None, bag_size: int, @@ -283,7 +233,7 @@ def setup_dataloaders_for_training( ) -> tuple[ DataLoader, DataLoader, - Sequence[Category], + Sequence[Category] | Mapping[str, Sequence[Category]], int, Sequence[PatientId], Sequence[PatientId], @@ -310,10 +260,25 @@ def setup_dataloaders_for_training( ) if task == "classification": - stratify = ground_truths + # Handle both single and multi-target cases + if ground_truths and isinstance(ground_truths[0], dict): + # Multi-target: use first target for stratification + first_key = list(ground_truths[0].keys())[0] + stratify = [cast(dict, gt)[first_key] for gt in ground_truths] + else: + stratify = ground_truths elif task == "survival": - # Extract event indicator (status) - statuses = [int(gt.split()[1]) for gt in ground_truths] + # Extract event indicator (status) - handle both single and multi-target + statuses = [] + for gt in ground_truths: + if isinstance(gt, dict): + # Multi-target survival: extract from first target + first_key = list(gt.keys())[0] + val = cast(dict, gt)[first_key] + if val: + statuses.append(int(val.split()[1])) + else: + 
statuses.append(int(gt.split()[1])) stratify = statuses elif task == "regression": stratify = None @@ -321,7 +286,10 @@ def setup_dataloaders_for_training( train_patients, valid_patients = cast( tuple[Sequence[PatientId], Sequence[PatientId]], train_test_split( - list(patient_to_data), stratify=stratify, shuffle=True, random_state=0 + list(patient_to_data), + stratify=cast(Any, stratify), + shuffle=True, + random_state=0, ), ) @@ -441,28 +409,50 @@ def _compute_class_weights_and_check_categories( *, train_dl: DataLoader, feature_type: str, - train_categories: Sequence[str], -) -> torch.Tensor: + train_categories: Sequence[str] | Mapping[str, Sequence[str]], +) -> torch.Tensor | dict[str, torch.Tensor]: """ Computes class weights and checks for category issues. Logs warnings if there are too few or underpopulated categories. Returns normalized category weights as a torch.Tensor. """ if feature_type == "tile": - category_counts = cast(BagDataset, train_dl.dataset).ground_truths.sum(dim=0) + dataset = cast(BagDataset, train_dl.dataset) + + if isinstance(dataset.ground_truths, list): + # Multi-target case: compute weights per target head + weights_per_target: dict[str, torch.Tensor] = {} + + target_keys = dataset.ground_truths[0].keys() + + for key in target_keys: + stacked = torch.stack([gt[key] for gt in dataset.ground_truths], dim=0) + counts = stacked.sum(dim=0) + w = counts.sum() / counts + weights_per_target[key] = w / w.sum() + + return weights_per_target + else: + category_counts = dataset.ground_truths.sum(dim=0) else: - category_counts = cast( - PatientFeatureDataset, train_dl.dataset - ).ground_truths.sum(dim=0) + dataset = cast(PatientFeatureDataset, train_dl.dataset) + category_counts = dataset.ground_truths.sum(dim=0) cat_ratio_reciprocal = category_counts.sum() / category_counts category_weights = cat_ratio_reciprocal / cat_ratio_reciprocal.sum() if len(train_categories) <= 1: raise ValueError(f"not enough categories to train on: {train_categories}") - 
elif any(category_counts < 16): + elif (category_counts < 16).any(): + category_counts_list = ( + category_counts.tolist() + if category_counts.dim() > 0 + else [category_counts.item()] + ) underpopulated_categories = { category: int(count) - for category, count in zip(train_categories, category_counts, strict=True) + for category, count in zip( + train_categories, category_counts_list, strict=True + ) if count < 16 } _logger.warning( diff --git a/src/stamp/preprocessing/__init__.py b/src/stamp/preprocessing/__init__.py index a1844526..84cb48e9 100755 --- a/src/stamp/preprocessing/__init__.py +++ b/src/stamp/preprocessing/__init__.py @@ -16,7 +16,6 @@ from tqdm import tqdm import stamp -from stamp.cache import get_processing_code_hash from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor from stamp.preprocessing.tiling import ( @@ -32,6 +31,7 @@ SlidePixels, TilePixels, ) +from stamp.utils.cache import get_processing_code_hash __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck" diff --git a/src/stamp/preprocessing/extractor/chief_ctranspath.py b/src/stamp/preprocessing/extractor/chief_ctranspath.py index 2d2e6b9b..03f5bba6 100644 --- a/src/stamp/preprocessing/extractor/chief_ctranspath.py +++ b/src/stamp/preprocessing/extractor/chief_ctranspath.py @@ -1,6 +1,6 @@ from pathlib import Path -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index ba9a277c..387e7947 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -7,7 +7,7 @@ from pathlib import Path from typing import Optional, TypeVar, cast -from stamp.cache import STAMP_CACHE_DIR, file_digest +from stamp.utils.cache import STAMP_CACHE_DIR, file_digest try: import gdown 
diff --git a/src/stamp/preprocessing/extractor/dinobloom.py b/src/stamp/preprocessing/extractor/dinobloom.py index fb7713ce..54b79c53 100644 --- a/src/stamp/preprocessing/extractor/dinobloom.py +++ b/src/stamp/preprocessing/extractor/dinobloom.py @@ -8,9 +8,9 @@ from torch import nn from torchvision import transforms -from stamp.cache import STAMP_CACHE_DIR from stamp.preprocessing.config import ExtractorName from stamp.preprocessing.extractor import Extractor +from stamp.utils.cache import STAMP_CACHE_DIR __author__ = "Marko van Treeck" __copyright__ = "Copyright (C) 2022-2025 Marko van Treeck" diff --git a/src/stamp/preprocessing/tiling.py b/src/stamp/preprocessing/tiling.py index 82a3efba..e09a06fa 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,7 +461,7 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement - images = collection.getElementsByTagName("Image") + images = collection.getElementsByTagName("Image") # pyright: ignore[reportOptionalMemberAccess] pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) except Exception: diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index ec09e1e0..bdbef1fa 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -1,3 +1,11 @@ +"""Statistics utilities (wrappers) for classification, regression and survival. + +This module provides a small, stable wrapper `compute_stats_` that dispatches +to the task-specific statistic implementations found in the submodules. 
+""" + +from __future__ import annotations + from collections.abc import Sequence from pathlib import Path from typing import NewType @@ -17,23 +25,25 @@ plot_multiple_decorated_roc_curves, plot_single_decorated_roc_curve, ) -from stamp.statistics.survival import ( - _plot_km, - _survival_stats_for_csv, -) +from stamp.statistics.survival import _plot_km, _survival_stats_for_csv from stamp.types import PandasLabel, Task +__all__ = ["StatsConfig", "compute_stats_"] + + __author__ = "Marko van Treeck, Minh Duc Nguyen" __copyright__ = "Copyright (C) 2022-2024 Marko van Treeck, Minh Duc Nguyen" __license__ = "MIT" def _read_table(file: Path, **kwargs) -> pd.DataFrame: - """Loads a dataframe from a file.""" + """Load a dataframe from CSV or XLSX file path. + + This small helper centralizes file IO formatting and keeps callers simple. + """ if isinstance(file, Path) and file.suffix == ".xlsx": return pd.read_excel(file, **kwargs) - else: - return pd.read_csv(file, **kwargs) + return pd.read_csv(file, **kwargs) class StatsConfig(BaseModel): @@ -60,6 +70,11 @@ def compute_stats_( time_label: str | None = None, status_label: str | None = None, ) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. 
+ """ match task: case "classification": if true_class is None or ground_truth_label is None: @@ -105,7 +120,6 @@ def compute_stats_( n_bootstrap_samples=n_bootstrap_samples, threshold_cmap=threshold_cmap, ) - else: plot_multiple_decorated_roc_curves( ax=ax, @@ -116,9 +130,7 @@ def compute_stats_( ) fig.tight_layout() - if not output_dir.exists(): - output_dir.mkdir(parents=True, exist_ok=True) - + output_dir.mkdir(parents=True, exist_ok=True) fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") plt.close(fig) @@ -134,7 +146,6 @@ def compute_stats_( title=f"{ground_truth_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) - else: plot_multiple_decorated_precision_recall_curves( ax=ax, @@ -203,12 +214,7 @@ def compute_stats_( cut_off=cut_off, ) - # ------------------------------------------------------------------ # # Save individual and aggregated CSVs - # ------------------------------------------------------------------ # stats_df = pd.DataFrame(per_fold).transpose() stats_df.index.name = "fold_name" # label the index column stats_df.to_csv(output_dir / "survival-stats_individual.csv", index=True) - - # agg_df = _aggregate_with_ci(stats_df) - # agg_df.to_csv(output_dir / "survival-stats_aggregated.csv", index=True) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 063793cf..7c298a54 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -46,7 +46,7 @@ def _survival_stats_for_csv( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid events before computing stats --- + # Clean NaNs and invalid events before computing stats df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] if len(df) == 0: @@ -56,10 +56,10 @@ def _survival_stats_for_csv( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- Concordance index --- + # Concordance 
index c_index, n_pairs = _cindex(time, event, risk) - # --- Log-rank test (median split) --- + # Log-rank test (median split) median_risk = float(cut_off) if cut_off is not None else float(np.nanmedian(risk)) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -101,7 +101,7 @@ def _plot_km( if risk_label is None: risk_label = "pred_score" - # --- Clean NaNs and invalid entries --- + # Clean NaNs and invalid entries df = df.replace(["NaN", "nan", "None", "Inf", "inf"], np.nan) df = df.dropna(subset=[time_label, status_label, risk_label]).copy() df = df[df[status_label].isin([0, 1])] @@ -113,7 +113,7 @@ def _plot_km( event = np.asarray(df[status_label], dtype=int) risk = np.asarray(df[risk_label], dtype=float) - # --- split groups --- + # split groups median_risk = float(cut_off) if cut_off is not None else np.nanmedian(risk) low_mask = risk <= median_risk high_mask = risk > median_risk @@ -138,7 +138,7 @@ def _plot_km( add_at_risk_counts(kmf_low, kmf_high, ax=ax) - # --- log-rank and c-index --- + # log-rank and c-index res = logrank_test( low_df[time_label], high_df[time_label], diff --git a/src/stamp/types.py b/src/stamp/types.py index f1f571cc..c1ff6873 100644 --- a/src/stamp/types.py +++ b/src/stamp/types.py @@ -37,6 +37,7 @@ PatientId: TypeAlias = str GroundTruth: TypeAlias = str +MultiClassGroundTruth: TypeAlias = tuple[str, ...] 
FeaturePath = NewType("FeaturePath", Path) Category: TypeAlias = str diff --git a/src/stamp/cache.py b/src/stamp/utils/cache.py similarity index 100% rename from src/stamp/cache.py rename to src/stamp/utils/cache.py diff --git a/src/stamp/config.py b/src/stamp/utils/config.py similarity index 100% rename from src/stamp/config.py rename to src/stamp/utils/config.py diff --git a/src/stamp/seed.py b/src/stamp/utils/seed.py similarity index 100% rename from src/stamp/seed.py rename to src/stamp/utils/seed.py diff --git a/src/stamp/utils/target_file.py b/src/stamp/utils/target_file.py new file mode 100644 index 00000000..08c30b6b --- /dev/null +++ b/src/stamp/utils/target_file.py @@ -0,0 +1,351 @@ +"""Automatically generate target information from clini table + +# The `barspoon-targets 2.0` File Format + +A barspoon target file is a [TOML][1] file with the following entries: + + - A `version` key mapping to a version string `"barspoon-targets "`, where + `` is a [PEP-440 version string][2] compatible with `2.0`. + - A `targets` table, the keys of which are target labels (as found in the + clinical table) and the values specify exactly one of the following: + 1. A categorical target label, marked by the presence of a `categories` + key-value pair. + 2. A target label to quantize, marked by the presence of a `thresholds` + key-value pair. + 3. A target format defined in in a later version of barspoon targets. + A target may only ever have one of the fields `categories` or `thresholds`. + A definition of these entries can be found below. + +[1]: https://toml.io "Tom's Obvious Minimal Language" +[2]: https://peps.python.org/pep-0440/ + "PEP 440 - Version Identification and Dependency Specification" + +## Categorical Target Label + +A categorical target is a target table with a key-value pair `categories`. +`categories` contains a list of lists of literal strings. 
Each list of strings +will be treated as one category, with all literal strings within that list being +treated as one representative for that category. This allows the user to easily +group related classes into one large class (i.e. `"True", "1", "Yes"` could all +be unified into the same category). + +### Category Weights + +It is possible to assign a weight to each category, to e.g. weigh rarer classes +more heavily. The weights are stored in a table `targets.LABEL.class_weights`, +whose keys are the first representative of each category, and the values of which +are the weight of the category as a floating point number. + +## Target Label to Quantize + +If a target has the `thresholds` option key set, it is interpreted as a +continuous target which has to be quantized. `thresholds` has to be a list of +floating point numbers [t_0, t_n], n > 1 containing the thresholds of the bins +to quantize the values into. A categorical target will be quantized into bins + +```asciimath +b_0 = [t_0; t_1], b_1 = (t_1; t_2], ... b_(n-1) = (t_(n-1); t_n] +``` + +The bins will be treated as categories with names +`f"[{t_0:+1.2e};{t_1:+1.2e}]"` for the first bin and +`f"({t_i:+1.2e};{t_(i+1):+1.2e}]"` for all other bins
+""" + +import logging +from pathlib import Path +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + Sequence, + TextIO, + Tuple, +) + +import numpy as np +import numpy.typing as npt +import pandas as pd +import torch +import torch.nn.functional as F +from packaging.specifiers import Specifier + + +def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: + if not isinstance(path, Path): + return pd.read_csv(path, **kwargs) + elif path.suffix == ".xlsx": + return pd.read_excel(path, **kwargs) + elif path.suffix == ".csv": + return pd.read_csv(path, **kwargs) + else: + raise ValueError( + "table to load has to either be an excel (`*.xlsx`) or csv (`*.csv`) file." + ) + + +__all__ = ["build_targets", "decode_targets"] + + +class TargetSpec(NamedTuple): + version: str + targets: Dict[str, Dict[str, Any]] + + +class EncodedTarget(NamedTuple): + categories: List[str] + encoded: torch.Tensor + weight: torch.Tensor + + +def encode_category( + *, + clini_df: pd.DataFrame, + target_label: str, + categories: Sequence[List[str]], + class_weights: Optional[Dict[str, float]] = None, + **ignored, +) -> Tuple[List[str], torch.Tensor, torch.Tensor]: + # Map each category to its index + category_map = {member: idx for idx, cat in enumerate(categories) for member in cat} + + # Map each item to it's category's index, mapping nans to num_classes+1 + # This way we can easily discard the NaN column later + indexes = clini_df[target_label].map(lambda c: category_map.get(c, len(categories))) + indexes = torch.tensor(indexes.values) + + # Discard nan column + one_hot = F.one_hot(indexes, num_classes=len(categories) + 1)[:, :-1] + + # Class weights + if class_weights is not None: + weight = torch.tensor([class_weights[c[0]] for c in categories]) + else: + # No class weights given; use normalized inverse frequency + counts = one_hot.sum(dim=0) + weight = (w := (counts.sum() / counts)) / w.sum() + + # Warn user of unused labels + if ignored: + logging.warn(f"ignored 
labels in target {target_label}: {ignored}") + + return [c[0] for c in categories], one_hot, weight + + +def encode_quantize( + *, + clini_df: pd.DataFrame, + target_label: str, + thresholds: npt.NDArray[np.floating[Any]], + class_weights: Optional[Dict[str, float]] = None, + **ignored, +) -> Tuple[List[str], torch.Tensor, torch.Tensor]: + # Warn user of unused labels + if ignored: + logging.warn(f"ignored labels in target {target_label}: {ignored}") + + n_bins = len(thresholds) - 1 + numeric_vals = torch.tensor(pd.to_numeric(clini_df[target_label]).values).reshape( + -1, 1 + ) + + # Map each value to a class index as follows: + # 1. If the value is NaN or less than the left-most threshold, use class + # index 0 + # 2. If it is between the left-most and the right-most threshold, set it to + # the bin number (starting from 1) + # 3. If it is larger than the right-most threshold, set it to N_bins + 1 + bin_index = ( + (numeric_vals > torch.tensor(thresholds).reshape(1, -1)).count_nonzero(1) + # For the first bucket, we have to include the lower threshold + + (numeric_vals.reshape(-1) == thresholds[0]) + ) + # One hot encode and discard nan columns (first and last col) + one_hot = F.one_hot(bin_index, num_classes=n_bins + 2)[:, 1:-1] + + # Class weights + categories = [ + f"[{thresholds[0]:+1.2e};{thresholds[1]:+1.2e}]", + *( + f"({lower:+1.2e};{upper:+1.2e}]" + for lower, upper in zip(thresholds[1:-1], thresholds[2:], strict=True) + ), + ] + + if class_weights is not None: + weight = torch.tensor([class_weights[c] for c in categories]) + else: + # No class weights given; use normalized inverse frequency + counts = one_hot.sum(0) + weight = (w := (np.divide(counts.sum(), counts, where=counts > 0))) / w.sum() + + return categories, one_hot, weight + + +def decode_targets( + encoded: torch.Tensor, + *, + target_labels: Sequence[str], + targets: Dict[str, Any], + version: str = "barspoon-targets 2.0", + **ignored, +) -> List[np.ndarray]: + name, version = version.split(" 
") + spec = Specifier("~=2.0") + + if not (name == "barspoon-targets" and spec.contains(version)): + raise ValueError( + f"incompatible target file: expected barspoon-targets{spec}, found `{name} {version}`" + ) + + # Warn user of unused labels + if ignored: + logging.warn(f"ignored parameters: {ignored}") + + decoded_targets = [] + curr_col = 0 + for target_label in target_labels: + info = targets[target_label] + + if (categories := info.get("categories")) is not None: + # Add another column which is one iff all the other values are zero + encoded_target = encoded[:, curr_col : curr_col + len(categories)] + is_none = ~encoded_target.any(dim=1).view(-1, 1) + encoded_target = torch.cat([encoded_target, is_none], dim=1) + + # Decode to class labels + representatives = np.array([c[0] for c in categories] + [None]) + category_index = encoded_target.argmax(dim=1) + decoded = representatives[category_index] + decoded_targets.append(decoded) + + curr_col += len(categories) + + elif (thresholds := info.get("thresholds")) is not None: + n_bins = len(thresholds) - 1 + encoded_target = encoded[:, curr_col : curr_col + n_bins] + is_none = ~encoded_target.any(dim=1).view(-1, 1) + encoded_target = torch.cat([encoded_target, is_none], dim=1) + + bin_edges = [-np.inf, *thresholds, np.inf] + representatives = np.array( + [ + f"[{lower:+1.2e};{upper:+1.2e})" + for lower, upper in zip(bin_edges[:-1], bin_edges[1:]) + ] + ) + decoded = representatives[encoded_target.argmax(dim=1)] + + decoded_targets.append(decoded) + + curr_col += n_bins + + else: + raise ValueError(f"cannot decode {target_label}: no target info") + + return decoded_targets + + +def build_targets( + *, + clini_tables: Sequence[Path], + categorical_labels: Sequence[str], + category_min_count: int = 32, + quantize: Sequence[tuple[str, int]] = (), +) -> Dict[str, EncodedTarget]: + clini_df = pd.concat([read_table(c) for c in clini_tables]) + encoded_targets: Dict[str, EncodedTarget] = {} + + # categorical targets + for 
target_label in categorical_labels: + counts = clini_df[target_label].value_counts() + well_supported = counts[counts >= category_min_count] + + if len(well_supported) <= 1: + continue + + categories = [[str(cat)] for cat in well_supported.index] + + weights = well_supported.sum() / well_supported + weights /= weights.sum() + + representatives, encoded, weight = encode_category( + clini_df=clini_df, + target_label=target_label, + categories=categories, + class_weights=weights.to_dict(), + ) + + encoded_targets[target_label] = EncodedTarget( + categories=representatives, + encoded=encoded, + weight=weight, + ) + + # quantized targets + for target_label, bincount in quantize: + vals = pd.to_numeric(clini_df[target_label]).dropna() + + if vals.empty: + continue + + vals_clamped = vals.replace( + { + -np.inf: vals[vals != -np.inf].min(), + np.inf: vals[vals != np.inf].max(), + } + ) + + thresholds = np.array( + [ + -np.inf, + *np.quantile(vals_clamped, q=np.linspace(0, 1, bincount + 1))[1:-1], + np.inf, + ], + dtype=float, + ) + + representatives, encoded, weight = encode_quantize( + clini_df=clini_df, + target_label=target_label, + thresholds=thresholds, + ) + + if encoded.shape[1] <= 1: + continue + + encoded_targets[target_label] = EncodedTarget( + categories=representatives, + encoded=encoded, + weight=weight, + ) + + return encoded_targets + + +if __name__ == "__main__": + encoded = build_targets( + clini_tables=[ + Path( + "/mnt/bulk-neptune/nguyenmin/stamp-dev/experiments/survival_prediction/TCGA-CRC-DX_CLINI.xlsx" + ) + ], + categorical_labels=["BRAF", "KRAS", "NRAS"], + category_min_count=32, + quantize=[], + ) + for name, enc in encoded.items(): + print(name, enc.encoded.shape) diff --git a/tests/random_data.py b/tests/random_data.py index bd95d1bc..c7c36880 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -254,6 +254,92 @@ def create_random_patient_level_dataset( return clini_path, slide_path, feat_dir, categories +def 
def create_random_multi_target_dataset(
    *,
    dir: Path,
    n_patients: int,
    max_slides_per_patient: int,
    min_tiles_per_slide: int,
    max_tiles_per_slide: int,
    feat_dim: int,
    target_labels: Sequence[str],
    categories_per_target: Sequence[Sequence[str]],
    extractor_name: str = "random-test-generator",
    min_slides_per_patient: int = 1,
) -> tuple[Path, Path, Path, Sequence[Sequence[str]]]:
    """
    Create a random multi-target tile-level dataset.

    Args:
        dir: Directory to create dataset in
        n_patients: Number of patients
        max_slides_per_patient: Maximum slides per patient
        min_tiles_per_slide: Minimum tiles per slide
        max_tiles_per_slide: Maximum tiles per slide
        feat_dim: Feature dimension
        target_labels: Names of the target columns (e.g., ["subtype", "grade"])
        categories_per_target: Categories for each target (e.g., [["A", "B"], ["1", "2", "3"]])
        extractor_name: Name of the extractor
        min_slides_per_patient: Minimum slides per patient

    Returns:
        Tuple of (clini_path, slide_path, feat_dir, categories_per_target)

    Raises:
        ValueError: If `target_labels` and `categories_per_target` differ
            in length.
    """
    if len(target_labels) != len(categories_per_target):
        raise ValueError(
            "target_labels and categories_per_target must have same length"
        )

    # BUGFIX(typing): both maps are mutated below, so annotate them with the
    # mutable `dict` type instead of the read-only `Mapping` protocol.
    slide_path_to_patient: dict[Path, PatientId] = {}
    patient_to_ground_truths: dict[PatientId, dict[str, str]] = {}
    clini_path = dir / "clini.csv"
    slide_path = dir / "slide.csv"

    feat_dir = dir / "feats"
    feat_dir.mkdir()

    for _ in range(n_patients):
        # Random patient ID.
        patient_id = random_string(16)

        # Draw one random category per target column.
        ground_truths = {
            target_label: random.choice(categories)
            for target_label, categories in zip(
                target_labels, categories_per_target, strict=True
            )
        }
        patient_to_ground_truths[patient_id] = ground_truths

        # Generate some slides, keyed by their path relative to `feat_dir`.
        for _ in range(
            random.randint(min_slides_per_patient, max_slides_per_patient)
        ):
            feature_file = create_random_feature_file(
                tmp_path=feat_dir,
                min_tiles=min_tiles_per_slide,
                max_tiles=max_tiles_per_slide,
                feat_dim=feat_dim,
                extractor_name=extractor_name,
            )
            slide_path_to_patient[feature_file.relative_to(feat_dir)] = patient_id

    # Clinical table: one row per patient with all target columns.
    clini_df = pd.DataFrame(
        [
            {"patient": patient_id, **ground_truths}
            for patient_id, ground_truths in patient_to_ground_truths.items()
        ]
    )
    clini_df.to_csv(clini_path, index=False)

    slide_df = pd.DataFrame(
        slide_path_to_patient.items(),
        columns=["slide_path", "patient"],
    )
    slide_df.to_csv(slide_path, index=False)

    return clini_path, slide_path, feat_dir, categories_per_target
stamp.utils.seed import Seed @pytest.mark.slow diff --git a/tests/test_data.py b/tests/test_data.py index 564d6426..3c86a931 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -23,7 +23,6 @@ get_coords, slide_to_patient_from_slide_table_, ) -from stamp.seed import Seed from stamp.types import ( BagSize, FeaturePath, @@ -33,6 +32,7 @@ SlideMPP, TilePixels, ) +from stamp.utils.seed import Seed @pytest.mark.filterwarnings("ignore:some patients have no associated slides") diff --git a/tests/test_deployment.py b/tests/test_deployment.py index de20ea12..4e1570cc 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import numpy as np import pytest @@ -24,8 +25,8 @@ ) from stamp.modeling.models.mlp import MLP from stamp.modeling.models.vision_tranformer import VisionTransformer -from stamp.seed import Seed from stamp.types import GroundTruth, PatientId, Task +from stamp.utils.seed import Seed def test_predict_patient_level( @@ -83,7 +84,8 @@ def test_predict_patient_level( assert len(predictions) == len(patient_to_data) for pid in patient_ids: - assert predictions[pid].shape == torch.Size([3]), "expected one score per class" + pred = cast(torch.Tensor, predictions[pid]) + assert pred.shape == torch.Size([3]), "expected one score per class" # Check if scores are consistent between runs and different for different patients more_patient_ids = [PatientId(f"pat{i}") for i in range(8, 11)] @@ -124,11 +126,13 @@ def test_predict_patient_level( assert len(more_predictions) == len(all_patient_ids) # Different patients should give different results assert not torch.allclose( - more_predictions[more_patient_ids[0]], more_predictions[more_patient_ids[1]] + cast(torch.Tensor, more_predictions[more_patient_ids[0]]), + cast(torch.Tensor, more_predictions[more_patient_ids[1]]), ), "different inputs should give different results" # The same patient should yield the same result assert torch.allclose( 
- predictions[patient_ids[0]], more_predictions[patient_ids[0]] + cast(torch.Tensor, predictions[patient_ids[0]]), + cast(torch.Tensor, more_predictions[patient_ids[0]]), ), "the same inputs should repeatedly yield the same results" @@ -295,7 +299,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: div_factor=25.0, ) - # ---- Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) + # Build tile-level feature file so batch = (bags, coords, bag_sizes, gt) if task == "classification": feature_file = make_old_feature_file( feats=torch.rand(23, dim_feats), coords=torch.rand(23, 2) @@ -319,7 +323,7 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: ) } - # ---- Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) + # Use tile_bag_dataloader for ALL tasks (so batch has 4 elements) test_dl, _ = tile_bag_dataloader( task=task, # "classification" | "regression" | "survival" patient_data=list(patient_to_data.values()), @@ -341,12 +345,17 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: assert len(predictions) == 1 pred = list(predictions.values())[0] if task == "classification": - assert pred.shape == torch.Size([len(categories)]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([len(categories)]) elif task == "regression": - assert pred.shape == torch.Size([1]) + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.shape == torch.Size([1]) else: # survival # Cox model → scalar log-risk, KM → vector or matrix - assert pred.ndim in (0, 1, 2), f"unexpected survival output shape: {pred.shape}" + pred_tensor = cast(torch.Tensor, pred) + assert pred_tensor.ndim in (0, 1, 2), ( + f"unexpected survival output shape: {pred_tensor.shape}" + ) # Repeatability predictions2 = _predict( @@ -356,4 +365,6 @@ def test_mil_predict_generic(tmp_path: Path, task: Task) -> None: accelerator="cpu", ) for pid in predictions: - assert torch.allclose(predictions[pid], 
def test_enc_dec_transformer_dims(
    batch_size: int = 6,
    n_tiles: int = 75,
    input_dim: int = 456,
    d_model: int = 128,
) -> None:
    """A multi-target forward pass yields one logit tensor per target."""
    head_sizes = {"subtype": 3, "grade": 4}
    model = EncDecTransformer(
        d_features=input_dim,
        target_n_outs=head_sizes,
        d_model=d_model,
        num_encoder_heads=4,
        num_decoder_heads=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=256,
        positional_encoding=True,
    )

    feats = torch.rand((batch_size, n_tiles, input_dim))
    positions = torch.rand((batch_size, n_tiles, 2))
    logits = model.forward(feats, positions)

    assert set(logits.keys()) == set(head_sizes.keys())
    for label, n_out in head_sizes.items():
        assert logits[label].shape == (batch_size, n_out)


def test_enc_dec_transformer_single_target(
    batch_size: int = 4,
    n_tiles: int = 50,
    input_dim: int = 256,
    d_model: int = 64,
) -> None:
    """The degenerate single-target case produces exactly one head."""
    model = EncDecTransformer(
        d_features=input_dim,
        target_n_outs={"label": 5},
        d_model=d_model,
        num_encoder_heads=4,
        num_decoder_heads=4,
        num_encoder_layers=1,
        num_decoder_layers=1,
        dim_feedforward=128,
    )

    logits = model.forward(
        torch.rand((batch_size, n_tiles, input_dim)),
        torch.rand((batch_size, n_tiles, 2)),
    )

    assert list(logits.keys()) == ["label"]
    assert logits["label"].shape == (batch_size, 5)


def test_enc_dec_transformer_no_positional_encoding(
    batch_size: int = 4,
    n_tiles: int = 30,
    input_dim: int = 128,
    d_model: int = 64,
) -> None:
    """Disabling the positional encoding must not change output shapes."""
    head_sizes = {"a": 2, "b": 3}
    model = EncDecTransformer(
        d_features=input_dim,
        target_n_outs=head_sizes,
        d_model=d_model,
        num_encoder_heads=4,
        num_decoder_heads=4,
        num_encoder_layers=1,
        num_decoder_layers=1,
        dim_feedforward=128,
        positional_encoding=False,
    )

    logits = model.forward(
        torch.rand((batch_size, n_tiles, input_dim)),
        torch.rand((batch_size, n_tiles, 2)),
    )

    for label, n_out in head_sizes.items():
        assert logits[label].shape == (batch_size, n_out)
def test_statistics_survival_integration(
    *,
    tmp_path: Path,
    n_folds: int = 1,
    n_patients: int = 200,
) -> None:
    """Smoke test: survival statistics run end to end without crashing."""
    random.seed(0)
    np.random.seed(0)

    for fold in range(n_folds):
        # Synthesize follow-up times, event indicators and risk scores.
        fold_df = pd.DataFrame(
            {
                "patient": [random_string(8) for _ in range(n_patients)],
                "day": np.random.uniform(30, 2000, size=n_patients),
                "status": np.random.choice([0, 1], size=n_patients, p=[0.3, 0.7]),
                "pred_score": np.random.randn(n_patients),
            }
        )
        fold_df.to_csv(tmp_path / f"survival-preds-{fold}.csv", index=False)

    compute_stats_(
        task="survival",
        output_dir=tmp_path / "output",
        pred_csvs=[tmp_path / f"survival-preds-{i}.csv" for i in range(n_folds)],
        time_label="day",
        status_label="status",
    )

    assert (tmp_path / "output" / "survival-stats_individual.csv").is_file()


def test_statistics_survival_integration_multiple_folds(tmp_path: Path) -> None:
    """Same smoke test, but across several folds."""
    return test_statistics_survival_integration(tmp_path=tmp_path, n_folds=5)


def test_statistics_regression_integration(
    *,
    tmp_path: Path,
    n_folds: int = 1,
    n_patients: int = 200,
) -> None:
    """Smoke test: regression statistics run end to end without crashing."""
    random.seed(0)
    np.random.seed(0)

    for fold in range(n_folds):
        truth = np.random.uniform(0, 100, size=n_patients)
        # Predictions are the truth plus Gaussian noise.
        noisy = truth + np.random.randn(n_patients) * 10
        pd.DataFrame(
            {
                "patient": [random_string(8) for _ in range(n_patients)],
                "target": truth,
                "pred": noisy,
            }
        ).to_csv(tmp_path / f"regression-preds-{fold}.csv", index=False)

    compute_stats_(
        task="regression",
        output_dir=tmp_path / "output",
        pred_csvs=[tmp_path / f"regression-preds-{i}.csv" for i in range(n_folds)],
        ground_truth_label="target",
    )

    for suffix in ("individual", "aggregated"):
        assert (
            tmp_path / "output" / f"target_regression-stats_{suffix}.csv"
        ).is_file()


def test_statistics_regression_integration_multiple_folds(tmp_path: Path) -> None:
    """Same smoke test, but across several folds."""
    return test_statistics_regression_integration(tmp_path=tmp_path, n_folds=5)


def test_statistics_multi_target_classification_integration(
    *,
    tmp_path: Path,
    n_patient_preds: int = 1,
) -> None:
    """Smoke test: multi-target classification statistics do not crash.

    Multi-target predictions carry one ground-truth column per target; the
    statistics pipeline handles a single target at a time, so compute_stats_
    is invoked once per target.
    """
    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)

    categories_per_target = {"subtype": ["A", "B"], "grade": ["1", "2", "3"]}

    for pred_i in range(n_patient_preds):
        n_patients = random.randint(100, 500)
        columns: dict[str, list] = {
            "patient": [random_string(8) for _ in range(n_patients)],
        }

        for target_label, cats in categories_per_target.items():
            # Ground truth plus one probability column per class.
            columns[target_label] = [random.choice(cats) for _ in range(n_patients)]
            probs = torch.softmax(torch.rand(len(cats), n_patients), dim=0)
            for row, cat in enumerate(cats):
                columns[f"{target_label}_{cat}"] = probs[row].tolist()

        pd.DataFrame(columns).to_csv(
            tmp_path / f"multi-target-preds-{pred_i}.csv", index=False
        )

    # One statistics run per target, mirroring the pipeline's behaviour.
    for target_label, cats in categories_per_target.items():
        true_class = cats[0]
        out_dir = tmp_path / "output" / target_label
        compute_stats_(
            task="classification",
            output_dir=out_dir,
            pred_csvs=[
                tmp_path / f"multi-target-preds-{i}.csv"
                for i in range(n_patient_preds)
            ],
            ground_truth_label=target_label,
            true_class=true_class,
        )

        assert (
            out_dir / f"{target_label}_categorical-stats_aggregated.csv"
        ).is_file()
        assert (out_dir / f"roc-curve_{target_label}={true_class}.svg").is_file()
        assert (out_dir / f"pr-curve_{target_label}={true_class}.svg").is_file()


def test_statistics_multi_target_classification_multiple_preds(
    tmp_path: Path,
) -> None:
    """Same smoke test, but with several prediction files."""
    return test_statistics_multi_target_classification_integration(
        tmp_path=tmp_path, n_patient_preds=3
    )
@pytest.mark.slow
@pytest.mark.filterwarnings("ignore:No positive samples in targets")
def test_train_deploy_multi_target_integration(
    *,
    tmp_path: Path,
    feat_dim: int = 25,
) -> None:
    """Integration test: train + deploy a multi-target tile-level classification model."""
    Seed.set(42)

    (tmp_path / "train").mkdir()
    (tmp_path / "deploy").mkdir()

    # Multi-target setup: "subtype" with two categories, "grade" with three.
    target_labels = ["subtype", "grade"]
    categories_per_target = [["A", "B"], ["1", "2", "3"]]

    def _make_split(split_dir: Path, n_patients: int) -> tuple[Path, Path, Path]:
        # Random multi-target tile-level dataset for one split.
        clini, slide, feats, _ = create_random_multi_target_dataset(
            dir=split_dir,
            n_patients=n_patients,
            max_slides_per_patient=3,
            min_tiles_per_slide=20,
            max_tiles_per_slide=600,
            feat_dim=feat_dim,
            target_labels=target_labels,
            categories_per_target=categories_per_target,
        )
        return clini, slide, feats

    train_clini, train_slide, train_feats = _make_split(tmp_path / "train", 400)
    deploy_clini, deploy_slide, deploy_feats = _make_split(tmp_path / "deploy", 50)

    # Build config objects.
    config = TrainConfig(
        task="classification",
        clini_table=train_clini,
        slide_table=train_slide,
        feature_dir=train_feats,
        output_dir=tmp_path / "train_output",
        patient_label="patient",
        ground_truth_label=target_labels,
        filename_label="slide_path",
        categories=[cat for cats in categories_per_target for cat in cats],
    )

    advanced = AdvancedConfig(
        bag_size=500,
        num_workers=min(os.cpu_count() or 1, 16),
        batch_size=8,
        max_epochs=2,
        patience=1,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        model_params=ModelParams(),
        model_name=ModelName.BARSPOON,
    )

    # Train, then deploy the multi-target model onto the held-out split.
    train_categorical_model_(config=config, advanced=advanced)

    deploy_categorical_model_(
        output_dir=tmp_path / "deploy_output",
        checkpoint_paths=[tmp_path / "train_output" / "model.ckpt"],
        clini_table=deploy_clini,
        slide_table=deploy_slide,
        feature_dir=deploy_feats,
        patient_label="patient",
        ground_truth_label=target_labels,
        time_label=None,
        status_label=None,
        filename_label="slide_path",
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        num_workers=min(os.cpu_count() or 1, 16),
    )
[[package]] name = "stamp" -version = "2.3.0" +version = "2.4.0" source = { editable = "." } dependencies = [ { name = "beartype" }, { name = "einops" }, { name = "h5py" }, { name = "jaxtyping" }, + { name = "lifelines" }, { name = "lightning" }, { name = "matplotlib" }, { name = "numpy" }, @@ -3807,7 +3808,6 @@ gigapath = [ { name = "fvcore" }, { name = "gigapath" }, { name = "iopath" }, - { name = "lifelines" }, { name = "monai" }, { name = "scikit-image" }, { name = "scikit-survival" }, @@ -3828,7 +3828,6 @@ gpu = [ { name = "huggingface-hub" }, { name = "iopath" }, { name = "jinja2" }, - { name = "lifelines" }, { name = "madeleine" }, { name = "mamba-ssm" }, { name = "monai" }, @@ -3920,7 +3919,7 @@ requires-dist = [ { name = "iopath", marker = "extra == 'gigapath'" }, { name = "jaxtyping", specifier = ">=0.3.2" }, { name = "jinja2", marker = "extra == 'cobra'", specifier = ">=3.1.4" }, - { name = "lifelines", marker = "extra == 'gigapath'" }, + { name = "lifelines", specifier = ">=0.28.0" }, { name = "lightning", specifier = ">=2.5.2" }, { name = "madeleine", marker = "extra == 'madeleine'", git = "https://github.com/mahmoodlab/MADELEINE.git?rev=de7c85acc2bdad352e6df8eee5694f8b6f288012" }, { name = "mamba-ssm", marker = "extra == 'cobra'", specifier = ">=2.2.6.post3" }, @@ -4747,4 +4746,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/24/fd/725b8e73ac2a50e78a4534ac43c6addf5c1c2d65380dd48a9169cc6739a9/yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d", size = 86591, upload-time = "2025-06-10T00:45:25.793Z" }, { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, { url = 
"https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, -] \ No newline at end of file +] From d8cc268200732f851c36ab59e4cbb1bae6f5c782 Mon Sep 17 00:00:00 2001 From: mducducd Date: Fri, 13 Feb 2026 14:16:16 +0000 Subject: [PATCH 2/5] add multi-target support --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3069dddf..9a0ed68b 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ STAMP is an **end‑to‑end, weakly‑supervised deep‑learning pipeline** tha * 🎓 **Beginner‑friendly & expert‑ready**: Zero‑code CLI and YAML config for routine use; optional code‑level customization for advanced research. * 🧩 **Model‑rich**: Out‑of‑the‑box support for **+20 foundation models** at [tile level](getting-started.md#feature-extraction) (e.g., *Virchow‑v2*, *UNI‑v2*) and [slide level](getting-started.md#slide-level-encoding) (e.g., *TITAN*, *COBRA*). * 🔬 **Weakly‑supervised**: End‑to‑end MIL with Transformer aggregation for training, cross‑validation and external deployment; no pixel‑level labels required. -* 🧮 **Multi-task learning**: Unified framework for **classification**, **regression**, and **cox-based survival analysis**. +* 🧮 **Multi-task learning**: Unified framework for **classification**, **multi-target classification**, **regression**, and **cox-based survival analysis**. * 📊 **Stats & results**: Built‑in metrics and patient‑level predictions, ready for analysis and reporting. * 🖼️ **Explainable**: Generates heatmaps and top‑tile exports out‑of‑the‑box for transparent model auditing and publication‑ready figures. 
* 🤝 **Collaborative by design**: Clinicians drive hypothesis & interpretation while engineers handle compute; STAMP’s modular CLI mirrors real‑world workflows and tracks every step for full reproducibility. From f7fa2c599a2e34ddd8cf67ae764e887995ec80f7 Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 16 Feb 2026 14:12:38 +0000 Subject: [PATCH 3/5] add multi-target statistics --- src/stamp/statistics/categorical.py | 82 +++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 2b5c859e..30a03a86 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -21,6 +21,30 @@ ] +def _detect_targets_from_columns(columns: Sequence[str]) -> list[str]: + """Detect target columns from CSV column names. + + Assumes multi-target format where each target has: + - A ground truth column (target name) + - A prediction column (pred_{target}) + - Probability columns ({target}_{class1}, {target}_{class2}, ...) + + Returns: + List of target names detected. + """ + # Convert to list to handle pandas Index + columns = list(columns) + targets = [] + for col in columns: + # Look for columns that start with "pred_" + if col.startswith("pred_"): + target_name = col[5:] # Remove "pred_" prefix + # Verify the target column exists + if target_name in columns: + targets.append(target_name) + return sorted(targets) + + def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: """Calculates some stats for categorical prediction tables. 
@@ -110,3 +134,61 @@ def categorical_aggregated_( preds_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_individual.csv") stats_df = _aggregate_categorical_stats(preds_df.reset_index()) stats_df.to_csv(outpath / f"{ground_truth_label}_categorical-stats_aggregated.csv") + + +def categorical_aggregated_multitarget_( + *, + preds_csvs: Sequence[Path], + outpath: Path, + target_labels: Sequence[str], +) -> None: + """Calculate statistics for multi-target categorical deployments. + + Args: + preds_csvs: CSV files containing predictions. + outpath: Path to save the results to. + target_labels: List of target labels to compute statistics for. + + This will apply `_categorical` to each target in the multi-target setup, + calculate statistics per target, and save both individual and aggregated results. + """ + outpath.mkdir(parents=True, exist_ok=True) + + all_target_stats = {} + + for target_label in target_labels: + # Process each target separately + preds_dfs = {} + for p in preds_csvs: + df = pd.read_csv(p, dtype=str) + # Drop rows where this target's ground truth is missing + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs[Path(p).parent.name] = _categorical(df_clean, target_label) + + if not preds_dfs: + continue + + # Concatenate and save individual stats for this target + preds_df = pd.concat(preds_dfs).sort_index() + preds_df.to_csv(outpath / f"{target_label}_categorical-stats_individual.csv") + + # Aggregate stats for this target + stats_df = _aggregate_categorical_stats(preds_df.reset_index()) + stats_df.to_csv(outpath / f"{target_label}_categorical-stats_aggregated.csv") + + # Store for summary + all_target_stats[target_label] = stats_df + + # Create a combined summary across all targets + if all_target_stats: + summary_dfs = [] + for target_name, stats_df in all_target_stats.items(): + stats_copy = stats_df.copy() + stats_copy.index = pd.MultiIndex.from_product( + [[target_name], stats_copy.index], names=["target", 
"class"] + ) + summary_dfs.append(stats_copy) + + combined_summary = pd.concat(summary_dfs) + combined_summary.to_csv(outpath / "multitarget_categorical-stats_summary.csv") From 747143f65b57edd50df8725b63f78b089f6a5a3f Mon Sep 17 00:00:00 2001 From: mducducd Date: Mon, 16 Feb 2026 15:28:08 +0000 Subject: [PATCH 4/5] refactor --- src/stamp/encoding/__init__.py | 2 +- src/stamp/encoding/encoder/__init__.py | 4 +- src/stamp/heatmaps/__init__.py | 100 ++--- src/stamp/modeling/crossval.py | 2 +- src/stamp/modeling/data.py | 53 ++- src/stamp/modeling/deploy.py | 2 +- src/stamp/modeling/models/__init__.py | 125 ++++--- src/stamp/modeling/models/mlp.py | 2 +- .../preprocessing/extractor/ctranspath.py | 2 +- src/stamp/preprocessing/tiling.py | 5 +- src/stamp/statistics/__init__.py | 271 +++++++++++--- src/stamp/statistics/categorical.py | 14 +- src/stamp/utils/target_file.py | 351 ------------------ tests/random_data.py | 6 +- tests/test_data.py | 27 ++ tests/test_deployment.py | 10 +- 16 files changed, 432 insertions(+), 544 deletions(-) delete mode 100644 src/stamp/utils/target_file.py diff --git a/src/stamp/encoding/__init__.py b/src/stamp/encoding/__init__.py index 9cb873bb..02d46594 100644 --- a/src/stamp/encoding/__init__.py +++ b/src/stamp/encoding/__init__.py @@ -73,7 +73,7 @@ def init_slide_encoder_( selected_encoder = encoder case _ as unreachable: - assert_never(unreachable) # type: ignore + assert_never(unreachable) selected_encoder.encode_slides_( output_dir=output_dir, diff --git a/src/stamp/encoding/encoder/__init__.py b/src/stamp/encoding/encoder/__init__.py index 86daa54a..4720ef9b 100644 --- a/src/stamp/encoding/encoder/__init__.py +++ b/src/stamp/encoding/encoder/__init__.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from pathlib import Path from tempfile import NamedTemporaryFile +from typing import cast import h5py import numpy as np @@ -183,7 +184,8 @@ def _read_h5( elif not h5_path.endswith(".h5"): raise ValueError(f"File is not of type .h5: 
{os.path.basename(h5_path)}") with h5py.File(h5_path, "r") as f: - feats: Tensor = torch.tensor(f["feats"][:], dtype=self.precision) # type: ignore + feats_ds = cast(h5py.Dataset, f["feats"]) + feats: Tensor = torch.tensor(feats_ds[:], dtype=self.precision) coords: CoordsInfo = get_coords(f) extractor: str = f.attrs.get("extractor", "") if extractor == "": diff --git a/src/stamp/heatmaps/__init__.py b/src/stamp/heatmaps/__init__.py index 22fb5250..fb704fe6 100644 --- a/src/stamp/heatmaps/__init__.py +++ b/src/stamp/heatmaps/__init__.py @@ -5,7 +5,7 @@ import logging from collections.abc import Collection, Iterable from pathlib import Path -from typing import cast, no_type_check +from typing import cast import h5py import matplotlib.pyplot as plt @@ -19,7 +19,7 @@ from packaging.version import Version from PIL import Image from torch import Tensor -from torch.func import jacrev # pyright: ignore[reportPrivateImportUsage] +from torch.func import jacrev from stamp.modeling.data import get_coords, get_stride from stamp.modeling.deploy import load_model_from_ckpt @@ -29,6 +29,8 @@ _logger = logging.getLogger("stamp") +_SlideLike = openslide.OpenSlide | openslide.ImageSlide + def _gradcam_per_category( model: torch.nn.Module, @@ -37,23 +39,19 @@ def _gradcam_per_category( ) -> Float[Tensor, "tile category"]: feat_dim = -1 - cam = ( - ( - feats - * jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze(0) - )(feats) - ) - .mean(feat_dim) # type: ignore - .abs() + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze(0) + )(feats), ) + cam = (feats * jac).mean(feat_dim).abs() cam = torch.softmax(cam, dim=-1) - return cam.permute(-1, -2) @@ -70,21 +68,28 @@ def _attention_rollout_single( device = feats.device - # 1. Forward pass to fill attn_weights in each SelfAttention layer + # --- 1. 
Forward pass to fill attn_weights in each SelfAttention layer --- _ = model( bags=feats.unsqueeze(0), coords=coords.unsqueeze(0), mask=torch.zeros(1, len(feats), dtype=torch.bool, device=device), ) - # 2. Rollout computation + # --- 2. Rollout computation --- attn_rollout: torch.Tensor | None = None - for layer in model.transformer.layers: # type: ignore - attn = getattr(layer[0], "attn_weights", None) # SelfAttention.attn_weights + transformer = getattr(model, "transformer", None) + if transformer is None: + raise RuntimeError("Model does not have a transformer attribute") + for layer in transformer.layers: + attn = getattr(layer, "attn_weights", None) + if attn is None: + first_child = next(iter(layer.children()), None) + if first_child is not None: + attn = getattr(first_child, "attn_weights", None) if attn is None: raise RuntimeError( "SelfAttention.attn_weights not found. " - "Make sure SelfAttention stores them." + "Make sure SelfAttention stores them on the layer or its first child." ) # attn: [heads, seq, seq] @@ -96,10 +101,10 @@ def _attention_rollout_single( if attn_rollout is None: raise RuntimeError("No attention maps collected from transformer layers.") - # 3. Extract CLS → tiles attention + # --- 3. Extract CLS → tiles attention --- cls_attn = attn_rollout[0, 1:] # [tile] - # 4. Normalize for visualization consistency + # --- 4. 
Normalize for visualization consistency --- cls_attn = cls_attn - cls_attn.min() cls_attn = cls_attn / (cls_attn.max().clamp(min=1e-8)) @@ -117,15 +122,18 @@ def _gradcam_single( """ feat_dim = -1 - jac = jacrev( - lambda bags: model.forward( - bags.unsqueeze(0), - coords=coords.unsqueeze(0), - mask=None, - ).squeeze() - )(feats) + jac = cast( + Tensor, + jacrev( + lambda bags: model.forward( + bags.unsqueeze(0), + coords=coords.unsqueeze(0), + mask=None, + ).squeeze() + )(feats), + ) - cam = (feats * jac).mean(feat_dim).abs() # type: ignore # [tile] + cam = (feats * jac).mean(feat_dim).abs() # [tile] return cam @@ -148,17 +156,21 @@ def _vals_to_im( def _show_thumb( - slide, thumb_ax: Axes, attention: Tensor, default_slide_mpp: SlideMPP | None + slide: _SlideLike, + thumb_ax: Axes, + attention: Tensor, + default_slide_mpp: SlideMPP | None, ) -> np.ndarray: mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int)) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = slide.get_thumbnail(thumb_size) thumb_ax.imshow(np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8]) return np.array(thumb)[: attention.shape[0] * 8, : attention.shape[1] * 8] def _get_thumb_array( - slide, + slide: _SlideLike, attention: torch.Tensor, default_slide_mpp: SlideMPP | None, ) -> np.ndarray: @@ -168,12 +180,12 @@ def _get_thumb_array( """ mpp = get_slide_mpp_(slide, default_mpp=default_slide_mpp) dims_um = np.array(slide.dimensions) * mpp - thumb = np.array(slide.get_thumbnail(np.round(dims_um * 8 / 256).astype(int))) + thumb_size = tuple(np.round(dims_um * 8 / 256).astype(int).tolist()) + thumb = np.array(slide.get_thumbnail(thumb_size)) thumb_crop = thumb[: attention.shape[0] * 8, : attention.shape[1] * 8] return thumb_crop -@no_type_check # beartype<=0.19.0 breaks here for some reason def _show_class_map( class_ax: Axes, 
top_score_indices: Integer[Tensor, "width height"], @@ -298,13 +310,8 @@ def heatmaps_( raise ValueError( f"Feature file {h5_path} is a slide or patient level feature. Heatmaps are currently supported for tile-level features only." ) - feats = ( - torch.tensor( - h5["feats"][:] # pyright: ignore[reportIndexIssue] - ) - .float() - .to(device) - ) + feats_np = np.asarray(h5["feats"]) + feats = torch.from_numpy(feats_np).float().to(device) coords_info = get_coords(h5) coords_um = torch.from_numpy(coords_info.coords_um).float() stride_um = Microns(get_stride(coords_um)) @@ -322,9 +329,10 @@ def heatmaps_( model = load_model_from_ckpt(checkpoint_path).eval() # TODO: Update version when a newer model logic breaks heatmaps. - if Version(model.stamp_version) < Version("2.4.0"): + stamp_version = str(getattr(model, "stamp_version", "")) + if Version(stamp_version) < Version("2.4.0"): raise ValueError( - f"model has been built with stamp version {model.stamp_version} " + f"model has been built with stamp version {stamp_version} " f"which is incompatible with the current version." 
) @@ -356,7 +364,7 @@ def heatmaps_( with torch.no_grad(): scores = torch.softmax( - model.model.forward( + model.model( feats.unsqueeze(-2), coords=coords_um.unsqueeze(-2), mask=torch.zeros( diff --git a/src/stamp/modeling/crossval.py b/src/stamp/modeling/crossval.py index 2caccecb..8ddfb03d 100644 --- a/src/stamp/modeling/crossval.py +++ b/src/stamp/modeling/crossval.py @@ -121,7 +121,7 @@ def categorical_crossval_( categories_for_export: ( dict[str, list] | list ) = [] # declare upfront to avoid unbound variable warnings - categories: Sequence[GroundTruth] | list | None = [] # type: ignore # declare upfront to avoid unbound variable warnings + categories: Sequence[GroundTruth] | list | None = [] if config.task == "classification": # Determine categories for training (single-target) and for export (supports multi-target) diff --git a/src/stamp/modeling/data.py b/src/stamp/modeling/data.py index 8b30ab11..eadb42f8 100755 --- a/src/stamp/modeling/data.py +++ b/src/stamp/modeling/data.py @@ -3,13 +3,13 @@ import logging from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import KW_ONLY, dataclass +from io import BytesIO # accept in _BinaryIOLike at runtime from itertools import groupby from pathlib import Path from typing import ( IO, Any, BinaryIO, - Dict, Final, Generic, List, @@ -23,6 +23,9 @@ import numpy as np import pandas as pd import torch + +# Use beartype's typing for PEP-585 deprecation-safe hints +from beartype.typing import Dict from packaging.version import Version from torch import Tensor from torch.utils.data import DataLoader, Dataset @@ -58,7 +61,9 @@ _EncodedTarget: TypeAlias = ( Tensor | dict[str, Tensor] ) # Union of encoded targets or multi-target dict -_BinaryIOLike: TypeAlias = Union[BinaryIO, IO[bytes]] +_BinaryIOLike: TypeAlias = Union[ + BinaryIO, IO[bytes], BytesIO +] # includes io.BytesIO for runtime checks """The ground truth, encoded numerically - classification: one-hot float [C] - regression: float [1] 
@@ -73,7 +78,7 @@ class PatientData(Generic[GroundTruthType]): _ = KW_ONLY ground_truth: GroundTruthType - feature_files: Iterable[FeaturePath | BinaryIO] + feature_files: Iterable[FeaturePath | _BinaryIOLike] def tile_bag_dataloader( @@ -533,9 +538,19 @@ def __getitem__( for bag_file in self.bags[index]: with h5py.File(bag_file, "r") as h5: if "feats" in h5: - arr = h5["feats"][:] # pyright: ignore[reportIndexIssue] # original STAMP files + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + arr = feats_obj[:] # original STAMP files else: - arr = h5["patch_embeddings"][:] # type: ignore # your Kronos files + embeddings_obj = h5["patch_embeddings"] + if not isinstance(embeddings_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'patch_embeddings' to be an HDF5 dataset but got {type(embeddings_obj)}" + ) + arr = embeddings_obj[:] # your Kronos files feats.append(torch.from_numpy(arr)) coords_um.append(torch.from_numpy(get_coords(h5).coords_um)) @@ -569,7 +584,7 @@ class PatientFeatureDataset(Dataset): def __init__( self, - feature_files: Sequence[FeaturePath | BinaryIO], + feature_files: Sequence[FeaturePath | _BinaryIOLike], ground_truths: Tensor, # shape: [num_samples, num_classes] transform: Callable[[Tensor], Tensor] | None, ): @@ -585,7 +600,12 @@ def __len__(self): def __getitem__(self, idx: int): feature_file = self.feature_files[idx] with h5py.File(feature_file, "r") as h5: - feats = torch.from_numpy(h5["feats"][:]) # pyright: ignore[reportIndexIssue] + feats_obj = h5["feats"] + if not isinstance(feats_obj, h5py.Dataset): + raise RuntimeError( + f"expected 'feats' to be an HDF5 dataset but got {type(feats_obj)}" + ) + feats = torch.from_numpy(feats_obj[:]) # Accept [V] or [1, V] if feats.ndim == 2 and feats.shape[0] == 1: feats = feats[0] @@ -634,10 +654,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: tile_size_px = 
TilePixels(0) return CoordsInfo(coords_um, tile_size_um, tile_size_px) - coords: np.ndarray = feature_h5["coords"][:] # type: ignore - coords_um: np.ndarray | None = None + coords_obj = feature_h5["coords"] + if not isinstance(coords_obj, h5py.Dataset): + raise RuntimeError( + f"{feature_h5.filename}: expected 'coords' to be an HDF5 dataset but got {type(coords_obj)}" + ) + coords: np.ndarray = coords_obj[:] tile_size_um: Microns | None = None tile_size_px: TilePixels | None = None + coords_um: np.ndarray | None = None if (tile_size := feature_h5.attrs.get("tile_size", None)) and feature_h5.attrs.get( "unit", None ) == "um": @@ -672,7 +697,15 @@ def get_coords(feature_h5: h5py.File) -> CoordsInfo: ) if not tile_size_px and "tile_size_px" in feature_h5.attrs: - tile_size_px = TilePixels(int(feature_h5.attrs["tile_size_px"])) # pyright: ignore[reportArgumentType] + tile_size_px_attr = feature_h5.attrs.get("tile_size_px") + if tile_size_px_attr is not None and isinstance( + tile_size_px_attr, (int, float) + ): + tile_size_px = TilePixels(int(tile_size_px_attr)) + else: + raise RuntimeError( + "Invalid or missing 'tile_size_px' attribute in the feature file." 
+ ) if not tile_size_um or coords_um is None: raise RuntimeError( diff --git a/src/stamp/modeling/deploy.py b/src/stamp/modeling/deploy.py index 10272328..d3b29ebd 100644 --- a/src/stamp/modeling/deploy.py +++ b/src/stamp/modeling/deploy.py @@ -276,7 +276,7 @@ def deploy_categorical_model_( for model_i, model in enumerate(models): predictions = _predict( model=model, - test_dl=test_dl, # pyright: ignore[reportPossiblyUnboundVariable] + test_dl=test_dl, patient_ids=patient_ids, accelerator=accelerator, ) diff --git a/src/stamp/modeling/models/__init__.py b/src/stamp/modeling/models/__init__.py index 86003c52..0b6a3885 100644 --- a/src/stamp/modeling/models/__init__.py +++ b/src/stamp/modeling/models/__init__.py @@ -3,11 +3,13 @@ import inspect from abc import ABC from collections.abc import Iterable, Sequence -from typing import Any, Mapping, TypeAlias +from typing import Any, TypeAlias import lightning -import numpy as np import torch + +# Use beartype.typing.Mapping to avoid PEP-585 deprecation warnings in beartype +from beartype.typing import Mapping from jaxtyping import Bool, Float from packaging.version import Version from torch import Tensor, nn, optim @@ -148,6 +150,19 @@ def on_train_batch_end(self, outputs, batch, batch_idx): ) +class _TileLevelMixin: + """Mixin for tile-level models providing shared MIL masking logic.""" + + @staticmethod + def _mask_from_bags(bags: Bags, bag_sizes: BagSizes) -> Bool[Tensor, "batch tile"]: + """Create attention mask for padded tiles in variable-length bags.""" + max_possible_bag_size = bags.size(1) + mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( + 0 + ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) + return mask + + class LitBaseClassifier(Base): """ PyTorch Lightning wrapper for tile level and patient level clasification. 
@@ -199,15 +214,16 @@ def __init__( self.class_weights = category_weights self.valid_auroc = MulticlassAUROC(len(categories)) # Number classes - self.categories = np.array(categories) + self.categories = list(categories) self.hparams.update({"task": "classification"}) -class LitTileClassifier(LitBaseClassifier): +class LitTileClassifier(_TileLevelMixin, LitBaseClassifier): """ - PyTorch Lightning wrapper for the model used in weakly supervised - learning settings, such as Multiple Instance Learning (MIL) for whole-slide images or patch-based data. + PyTorch Lightning wrapper for tile-level MIL classification. + + Used in weakly supervised settings for whole-slide images or patch-based data. """ supported_features = ["tile"] @@ -249,7 +265,6 @@ def _step( ) if step_name == "validation": - # TODO this is a bit ugly, we'd like to have `_step` without special cases self.valid_auroc.update(logits, targets.long().argmax(dim=-1)) self.log( f"{step_name}_auroc", @@ -291,31 +306,19 @@ def predict_step( # adding a mask here will *drastically* and *unbearably* increase memory usage return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideClassifier(LitBaseClassifier): - """ - PyTorch Lightning wrapper for MLPClassifier. 
- """ + """PyTorch Lightning wrapper for slide/patient-level classification.""" supported_features = ["slide"] def forward(self, x: Tensor) -> Tensor: return self.model(x) - def _step(self, batch, step_name: str): - feats, targets = batch + def _step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str + ) -> Loss: + feats, targets = list(batch) # Works for both tuple and list logits = self.model(feats.float()) loss = nn.functional.cross_entropy( logits, @@ -341,17 +344,25 @@ def _step(self, batch, step_name: str): ) return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch, "test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats) @@ -402,9 +413,10 @@ def _compute_loss(y_true: Tensor, y_pred: Tensor) -> Loss: return nn.functional.l1_loss(y_true, y_pred) -class LitTileRegressor(LitBaseRegressor): +class LitTileRegressor(_TileLevelMixin, LitBaseRegressor): """ - PyTorch Lightning wrapper for weakly supervised / MIL regression at tile/patient level. + PyTorch Lightning wrapper for tile-level MIL regression. + Produces a single continuous output per bag (dim_output = 1). 
""" @@ -491,38 +503,22 @@ def predict_step( # keep memory usage low as in classifier return self.model(bags, coords=coords, mask=None) - def _mask_from_bags( - *, - bags: Bags, - bag_sizes: BagSizes, - ) -> Bool[Tensor, "batch tile"]: - max_possible_bag_size = bags.size(1) - mask = torch.arange(max_possible_bag_size).type_as(bag_sizes).unsqueeze( - 0 - ).repeat(len(bags), 1) >= bag_sizes.unsqueeze(1) - - return mask - class LitSlideRegressor(LitBaseRegressor): - """ - PyTorch Lightning wrapper for slide-level or patient-level regression. - Produces a single continuous output per slide (dim_output = 1). - """ + """PyTorch Lightning wrapper for slide/patient-level regression.""" supported_features = ["slide"] def forward(self, feats: Tensor) -> Tensor: - """Forward pass for slide-level features.""" return self.model(feats.float()) def _step( self, *, - batch: tuple[Tensor, Tensor], + batch: tuple[Tensor, Tensor] | list[Tensor], step_name: str, ) -> Loss: - feats, targets = batch + feats, targets = list(batch) # Works for both tuple and list preds = self.model(feats.float(), mask=None) # (B, 1) y = targets.to(preds).float() @@ -539,7 +535,6 @@ def _step( ) if step_name == "validation": - # same metrics as LitTileRegressor p = preds.squeeze(-1) t = y.squeeze(-1) self.log( @@ -552,17 +547,25 @@ def _step( return loss - def training_step(self, batch, batch_idx): + def training_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="training") - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="validation") - def test_step(self, batch, batch_idx): + def test_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Loss: return self._step(batch=batch, step_name="test") - def predict_step(self, batch, batch_idx): - feats, _ = batch + def 
predict_step( + self, batch: tuple[Tensor, Tensor] | list[Tensor], batch_idx: int + ) -> Tensor: + feats, _ = batch if isinstance(batch, tuple) else batch return self.model(feats.float()) @@ -707,12 +710,12 @@ def on_train_epoch_end(self): self.hparams.update({"train_pred_median": self.train_pred_median}) -class LitTileSurvival(LitSurvivalBase): +class LitTileSurvival(_TileLevelMixin, LitSurvivalBase): """ - Tile-level or patch-level survival analysis. - Expects dataloader batches like: - (bags, coords, bag_sizes, targets) - where targets is shape (B,2): [:,0]=time, [:,1]=event (1=event, 0=censored). + Tile-level survival analysis with Cox proportional hazards loss. + + Expects batches: (bags, coords, bag_sizes, targets) + where targets.shape = (B, 2): [:,0]=time, [:,1]=event (0=censored, 1=event). """ supported_features = ["tile"] diff --git a/src/stamp/modeling/models/mlp.py b/src/stamp/modeling/models/mlp.py index e4f8881f..e88a77ca 100644 --- a/src/stamp/modeling/models/mlp.py +++ b/src/stamp/modeling/models/mlp.py @@ -29,7 +29,7 @@ def __init__( layers.append(nn.Dropout(dropout)) in_dim = dim_hidden layers.append(nn.Linear(in_dim, dim_output)) - self.mlp = nn.Sequential(*layers) # type: ignore + self.mlp = nn.Sequential(*layers) @jaxtyped(typechecker=beartype) def forward( diff --git a/src/stamp/preprocessing/extractor/ctranspath.py b/src/stamp/preprocessing/extractor/ctranspath.py index 387e7947..d189e279 100644 --- a/src/stamp/preprocessing/extractor/ctranspath.py +++ b/src/stamp/preprocessing/extractor/ctranspath.py @@ -518,7 +518,7 @@ def forward(self, x, mask: Optional[torch.Tensor] = None): attn = q @ k.transpose(-2, -1) relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) + self.relative_position_index.view(-1) # pyright: ignore[reportCallIssue] ].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], diff --git a/src/stamp/preprocessing/tiling.py 
b/src/stamp/preprocessing/tiling.py index e09a06fa..ce684ba4 100644 --- a/src/stamp/preprocessing/tiling.py +++ b/src/stamp/preprocessing/tiling.py @@ -461,7 +461,10 @@ def _extract_mpp_from_metadata(slide: openslide.AbstractSlide) -> SlideMPP | Non return None doc = minidom.parseString(xml_path) collection = doc.documentElement - images = collection.getElementsByTagName("Image") # pyright: ignore[reportOptionalMemberAccess] + if collection is None: + _logger.error("Document element is None, unable to extract MPP.") + return None + images = collection.getElementsByTagName("Image") pixels = images[0].getElementsByTagName("Pixels") mpp = float(pixels[0].getAttribute("PhysicalSizeX")) except Exception: diff --git a/src/stamp/statistics/__init__.py b/src/stamp/statistics/__init__.py index bdbef1fa..b3243ecc 100644 --- a/src/stamp/statistics/__init__.py +++ b/src/stamp/statistics/__init__.py @@ -15,7 +15,10 @@ from matplotlib import pyplot as plt from pydantic import BaseModel, ConfigDict, Field -from stamp.statistics.categorical import categorical_aggregated_ +from stamp.statistics.categorical import ( + categorical_aggregated_, + categorical_aggregated_multitarget_, +) from stamp.statistics.prc import ( plot_multiple_decorated_precision_recall_curves, plot_single_decorated_precision_recall_curve, @@ -51,7 +54,7 @@ class StatsConfig(BaseModel): task: Task = Field(default="classification") output_dir: Path pred_csvs: list[Path] - ground_truth_label: PandasLabel | None = None + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None true_class: str | None = None time_label: str | None = None status_label: str | None = None @@ -60,52 +63,65 @@ class StatsConfig(BaseModel): _Inches = NewType("_Inches", float) -def compute_stats_( +def _compute_multitarget_classification_stats( *, - task: Task, output_dir: Path, pred_csvs: Sequence[Path], - ground_truth_label: PandasLabel | None = None, - true_class: str | None = None, - time_label: str | None = None, - 
status_label: str | None = None, + target_labels: Sequence[str], ) -> None: - """Compute and save statistics for the provided task and prediction CSVs. + """Compute statistics and plots for multi-target classification. - This wrapper keeps the external API stable while delegating the detailed - computations and plotting to the submodules under `stamp.statistics.*`. + For each target, creates ROC and PRC curves for each class, + similar to single-target classification. """ - match task: - case "classification": - if true_class is None or ground_truth_label is None: - raise ValueError( - "both true_class and ground_truth_label are required in statistic configuration" - ) - - preds_dfs = [ - _read_table( - p, - usecols=[ground_truth_label, f"{ground_truth_label}_{true_class}"], - dtype={ - ground_truth_label: str, - f"{ground_truth_label}_{true_class}": float, - }, - ) - for p in pred_csvs - ] - - y_trues = [ - np.array(df[ground_truth_label] == true_class) for df in preds_dfs - ] - y_preds = [ - np.array(df[f"{ground_truth_label}_{true_class}"].values) - for df in preds_dfs - ] - n_bootstrap_samples = 1000 - figure_width = _Inches(3.8) - threshold_cmap = None - - roc_curve_figure_aspect_ratio = 1.08 + output_dir.mkdir(parents=True, exist_ok=True) + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + roc_curve_figure_aspect_ratio = 1.08 + + # Validate all target labels exist in CSV + first_df = _read_table(pred_csvs[0], nrows=0) + missing_targets = [t for t in target_labels if t not in first_df.columns] + if missing_targets: + raise ValueError( + f"Target labels not found in CSV: {missing_targets}. 
Available columns: {list(first_df.columns)}" + ) + + # Process each target + for target_label in target_labels: + # Load data for this target + preds_dfs = [] + for p in pred_csvs: + df = _read_table(p, dtype=str) + # Only keep rows where this target has ground truth + df_clean = df.dropna(subset=[target_label]) + if len(df_clean) > 0: + preds_dfs.append(df_clean) + + if not preds_dfs: + continue + + # Get unique classes for this target + classes = sorted(preds_dfs[0][target_label].unique()) + + # Create plots for each class in this target + for true_class in classes: + # Extract ground truth and predictions for this class + y_trues = [] + y_preds = [] + + for df in preds_dfs: + prob_col = f"{target_label}_{true_class}" + if prob_col not in df.columns: + continue + + y_trues.append(np.array(df[target_label] == true_class)) + y_preds.append(np.array(df[prob_col].astype(float).values)) + + if not y_trues: + continue + + # Plot ROC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, @@ -116,34 +132,35 @@ def compute_stats_( ax=ax, y_true=y_trues[0], y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, - threshold_cmap=threshold_cmap, + threshold_cmap=None, ) else: plot_multiple_decorated_roc_curves( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=None, ) fig.tight_layout() - output_dir.mkdir(parents=True, exist_ok=True) - fig.savefig(output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"roc-curve_{target_label}={true_class}.svg") plt.close(fig) + # Plot PRC curve fig, ax = plt.subplots( figsize=(figure_width, figure_width * roc_curve_figure_aspect_ratio), dpi=300, ) + if len(preds_dfs) == 1: plot_single_decorated_precision_recall_curve( ax=ax, y_true=y_trues[0], 
y_score=y_preds[0], - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", n_bootstrap_samples=n_bootstrap_samples, ) else: @@ -151,24 +168,170 @@ def compute_stats_( ax=ax, y_trues=y_trues, y_scores=y_preds, - title=f"{ground_truth_label} = {true_class}", + title=f"{target_label} = {true_class}", ) fig.tight_layout() - fig.savefig(output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg") + fig.savefig(output_dir / f"pr-curve_{target_label}={true_class}.svg") plt.close(fig) - categorical_aggregated_( - preds_csvs=pred_csvs, - ground_truth_label=ground_truth_label, - outpath=output_dir, + # Compute aggregated statistics for all targets + categorical_aggregated_multitarget_( + preds_csvs=pred_csvs, + outpath=output_dir, + target_labels=target_labels, + ) + + +def compute_stats_( + *, + task: Task, + output_dir: Path, + pred_csvs: Sequence[Path], + ground_truth_label: PandasLabel | Sequence[PandasLabel] | None = None, + true_class: str | None = None, + time_label: str | None = None, + status_label: str | None = None, +) -> None: + """Compute and save statistics for the provided task and prediction CSVs. + + This wrapper keeps the external API stable while delegating the detailed + computations and plotting to the submodules under `stamp.statistics.*`. 
+ """ + match task: + case "classification": + # Check if multi-target based on ground_truth_label type + is_multitarget = ( + isinstance(ground_truth_label, (list, tuple)) + and len(ground_truth_label) > 1 ) + if is_multitarget: + # Multi-target classification + if not isinstance(ground_truth_label, (list, tuple)): + raise ValueError( + "ground_truth_label must be a list or tuple for multi-target classification" + ) + _compute_multitarget_classification_stats( + output_dir=output_dir, + pred_csvs=pred_csvs, + target_labels=list(ground_truth_label), + ) + else: + # Single-target classification (original behavior) + if true_class is None or ground_truth_label is None: + raise ValueError( + "both true_class and ground_truth_label are required in statistic configuration" + ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for single-target classification" + ) + + preds_dfs = [ + _read_table( + p, + usecols=[ + ground_truth_label, + f"{ground_truth_label}_{true_class}", + ], + dtype={ + ground_truth_label: str, + f"{ground_truth_label}_{true_class}": float, + }, + ) + for p in pred_csvs + ] + + y_trues = [ + np.array(df[ground_truth_label] == true_class) for df in preds_dfs + ] + y_preds = [ + np.array(df[f"{ground_truth_label}_{true_class}"].values) + for df in preds_dfs + ] + n_bootstrap_samples = 1000 + figure_width = _Inches(3.8) + threshold_cmap = None + + roc_curve_figure_aspect_ratio = 1.08 + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + + if len(preds_dfs) == 1: + plot_single_decorated_roc_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + threshold_cmap=threshold_cmap, + ) + else: + plot_multiple_decorated_roc_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + 
n_bootstrap_samples=None, + ) + + fig.tight_layout() + output_dir.mkdir(parents=True, exist_ok=True) + fig.savefig( + output_dir / f"roc-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + fig, ax = plt.subplots( + figsize=( + figure_width, + figure_width * roc_curve_figure_aspect_ratio, + ), + dpi=300, + ) + if len(preds_dfs) == 1: + plot_single_decorated_precision_recall_curve( + ax=ax, + y_true=y_trues[0], + y_score=y_preds[0], + title=f"{ground_truth_label} = {true_class}", + n_bootstrap_samples=n_bootstrap_samples, + ) + else: + plot_multiple_decorated_precision_recall_curves( + ax=ax, + y_trues=y_trues, + y_scores=y_preds, + title=f"{ground_truth_label} = {true_class}", + ) + + fig.tight_layout() + fig.savefig( + output_dir / f"pr-curve_{ground_truth_label}={true_class}.svg" + ) + plt.close(fig) + + categorical_aggregated_( + preds_csvs=pred_csvs, + ground_truth_label=ground_truth_label, + outpath=output_dir, + ) + case "regression": if ground_truth_label is None: raise ValueError( "no ground_truth_label configuration supplied in statistic" ) + if not isinstance(ground_truth_label, str): + raise ValueError( + "ground_truth_label must be a string for regression (multi-target regression not yet supported)" + ) regression_aggregated_( preds_csvs=pred_csvs, ground_truth_label=ground_truth_label, diff --git a/src/stamp/statistics/categorical.py b/src/stamp/statistics/categorical.py index 30a03a86..9d6c4c12 100755 --- a/src/stamp/statistics/categorical.py +++ b/src/stamp/statistics/categorical.py @@ -62,29 +62,29 @@ def _categorical(preds_df: pd.DataFrame, target_label: str) -> pd.DataFrame: # roc_auc stats_df["roc_auc_score"] = [ - metrics.roc_auc_score(y_true == cat, y_pred[:, i]) # pyright: ignore[reportCallIssue,reportArgumentType] + metrics.roc_auc_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # average_precision stats_df["average_precision_score"] = [ - metrics.average_precision_score(y_true == cat, y_pred[:, i]) # 
pyright: ignore[reportCallIssue,reportArgumentType] + metrics.average_precision_score(y_true == cat, y_pred[:, i]) for i, cat in enumerate(categories) ] # f1 score y_pred_labels = categories[y_pred.argmax(axis=1)] stats_df["f1_score"] = [ - metrics.f1_score(y_true == cat, y_pred_labels == cat) # pyright: ignore[reportCallIssue,reportArgumentType] - for cat in categories + metrics.f1_score(y_true == cat, y_pred_labels == cat) for cat in categories ] # p values p_values = [] for i, cat in enumerate(categories): - pos_scores = y_pred[:, i][y_true == cat] # pyright: ignore[reportCallIssue,reportArgumentType] - neg_scores = y_pred[:, i][y_true != cat] # pyright: ignore[reportCallIssue,reportArgumentType] - p_values.append(st.ttest_ind(pos_scores, neg_scores).pvalue) # pyright: ignore[reportGeneralTypeIssues, reportAttributeAccessIssue] + pos_scores = y_pred[:, i][y_true == cat] + neg_scores = y_pred[:, i][y_true != cat] + _, p_value = st.ttest_ind(pos_scores, neg_scores) + p_values.append(p_value) stats_df["p_value"] = p_values assert set(_score_labels) & set(stats_df.columns) == set(_score_labels) diff --git a/src/stamp/utils/target_file.py b/src/stamp/utils/target_file.py deleted file mode 100644 index 08c30b6b..00000000 --- a/src/stamp/utils/target_file.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Automatically generate target information from clini table - -# The `barspoon-targets 2.0` File Format - -A barspoon target file is a [TOML][1] file with the following entries: - - - A `version` key mapping to a version string `"barspoon-targets "`, where - `` is a [PEP-440 version string][2] compatible with `2.0`. - - A `targets` table, the keys of which are target labels (as found in the - clinical table) and the values specify exactly one of the following: - 1. A categorical target label, marked by the presence of a `categories` - key-value pair. - 2. A target label to quantize, marked by the presence of a `thresholds` - key-value pair. - 3. 
A target format defined in in a later version of barspoon targets. - A target may only ever have one of the fields `categories` or `thresholds`. - A definition of these entries can be found below. - -[1]: https://toml.io "Tom's Obvious Minimal Language" -[2]: https://peps.python.org/pep-0440/ - "PEP 440 - Version Identification and Dependency Specification" - -## Categorical Target Label - -A categorical target is a target table with a key-value pair `categories`. -`categories` contains a list of lists of literal strings. Each list of strings -will be treated as one category, with all literal strings within that list being -treated as one representative for that category. This allows the user to easily -group related classes into one large class (i.e. `"True", "1", "Yes"` could all -be unified into the same category). - -### Category Weights - -It is possible to assign a weight to each category, to e.g. weigh rarer classes -more heavily. The weights are stored in a table `targets.LABEL.class_weights`, -whose keys is the first representative of each category, and the values of which -is the weight of the category as a floating point number. - -## Target Label to Quantize - -If a target has the `thresholds` option key set, it is interpreted as a -continuous target which has to be quantized. `thresholds` has to be a list of -floating point numbers [t_0, t_n], n > 1 containing the thresholds of the bins -to quantize the values into. A categorical target will be quantized into bins - -```asciimath -b_0 = [t_0; t_1], b_1 = (t_1; b_2], ... b_(n-1) = (t_(n-1); t_n] -``` - -The bins will be treated as categories with names -`f"[{t_0:+1.2e};{t_1:+1.2e}]"` for the first bin and -`f"({t_i:+1.2e};{t_(i+1):+1.2e}]"` for all other bins - -To avoid confusion, we recommend to also format the `thresholds` list the same -way. - -The bins can also be weighted. See _Categorical Target Label: Category Weights_ -for details. 
- - > Experience has shown that many labels contain non-negative values with a - > disproportionate amount (more than n_samples/n_bins) of zeroes. We thus - > decided to make the _right_ side of each bin inclusive, as the bin (-A,0] - > then naturally includes those zero values. -""" - -import logging -from pathlib import Path -from typing import ( - Any, - Dict, - List, - NamedTuple, - Optional, - Sequence, - TextIO, - Tuple, -) - -import numpy as np -import numpy.typing as npt -import pandas as pd -import torch -import torch.nn.functional as F -from packaging.specifiers import Specifier - - -def read_table(path: Path | TextIO, **kwargs) -> pd.DataFrame: - if not isinstance(path, Path): - return pd.read_csv(path, **kwargs) - elif path.suffix == ".xlsx": - return pd.read_excel(path, **kwargs) - elif path.suffix == ".csv": - return pd.read_csv(path, **kwargs) - else: - raise ValueError( - "table to load has to either be an excel (`*.xlsx`) or csv (`*.csv`) file." - ) - - -__all__ = ["build_targets", "decode_targets"] - - -class TargetSpec(NamedTuple): - version: str - targets: Dict[str, Dict[str, Any]] - - -class EncodedTarget(NamedTuple): - categories: List[str] - encoded: torch.Tensor - weight: torch.Tensor - - -def encode_category( - *, - clini_df: pd.DataFrame, - target_label: str, - categories: Sequence[List[str]], - class_weights: Optional[Dict[str, float]] = None, - **ignored, -) -> Tuple[List[str], torch.Tensor, torch.Tensor]: - # Map each category to its index - category_map = {member: idx for idx, cat in enumerate(categories) for member in cat} - - # Map each item to it's category's index, mapping nans to num_classes+1 - # This way we can easily discard the NaN column later - indexes = clini_df[target_label].map(lambda c: category_map.get(c, len(categories))) - indexes = torch.tensor(indexes.values) - - # Discard nan column - one_hot = F.one_hot(indexes, num_classes=len(categories) + 1)[:, :-1] - - # Class weights - if class_weights is not None: - weight = 
torch.tensor([class_weights[c[0]] for c in categories]) - else: - # No class weights given; use normalized inverse frequency - counts = one_hot.sum(dim=0) - weight = (w := (counts.sum() / counts)) / w.sum() - - # Warn user of unused labels - if ignored: - logging.warn(f"ignored labels in target {target_label}: {ignored}") - - return [c[0] for c in categories], one_hot, weight - - -def encode_quantize( - *, - clini_df: pd.DataFrame, - target_label: str, - thresholds: npt.NDArray[np.floating[Any]], - class_weights: Optional[Dict[str, float]] = None, - **ignored, -) -> Tuple[List[str], torch.Tensor, torch.Tensor]: - # Warn user of unused labels - if ignored: - logging.warn(f"ignored labels in target {target_label}: {ignored}") - - n_bins = len(thresholds) - 1 - numeric_vals = torch.tensor(pd.to_numeric(clini_df[target_label]).values).reshape( - -1, 1 - ) - - # Map each value to a class index as follows: - # 1. If the value is NaN or less than the left-most threshold, use class - # index 0 - # 2. If it is between the left-most and the right-most threshold, set it to - # the bin number (starting from 1) - # 3. 
If it is larger than the right-most threshold, set it to N_bins + 1 - bin_index = ( - (numeric_vals > torch.tensor(thresholds).reshape(1, -1)).count_nonzero(1) - # For the first bucket, we have to include the lower threshold - + (numeric_vals.reshape(-1) == thresholds[0]) - ) - # One hot encode and discard nan columns (first and last col) - one_hot = F.one_hot(bin_index, num_classes=n_bins + 2)[:, 1:-1] - - # Class weights - categories = [ - f"[{thresholds[0]:+1.2e};{thresholds[1]:+1.2e}]", - *( - f"({lower:+1.2e};{upper:+1.2e}]" - for lower, upper in zip(thresholds[1:-1], thresholds[2:], strict=True) - ), - ] - - if class_weights is not None: - weight = torch.tensor([class_weights[c] for c in categories]) - else: - # No class weights given; use normalized inverse frequency - counts = one_hot.sum(0) - weight = (w := (np.divide(counts.sum(), counts, where=counts > 0))) / w.sum() - - return categories, one_hot, weight - - -def decode_targets( - encoded: torch.Tensor, - *, - target_labels: Sequence[str], - targets: Dict[str, Any], - version: str = "barspoon-targets 2.0", - **ignored, -) -> List[np.ndarray]: - name, version = version.split(" ") - spec = Specifier("~=2.0") - - if not (name == "barspoon-targets" and spec.contains(version)): - raise ValueError( - f"incompatible target file: expected barspoon-targets{spec}, found `{name} {version}`" - ) - - # Warn user of unused labels - if ignored: - logging.warn(f"ignored parameters: {ignored}") - - decoded_targets = [] - curr_col = 0 - for target_label in target_labels: - info = targets[target_label] - - if (categories := info.get("categories")) is not None: - # Add another column which is one iff all the other values are zero - encoded_target = encoded[:, curr_col : curr_col + len(categories)] - is_none = ~encoded_target.any(dim=1).view(-1, 1) - encoded_target = torch.cat([encoded_target, is_none], dim=1) - - # Decode to class labels - representatives = np.array([c[0] for c in categories] + [None]) - category_index = 
encoded_target.argmax(dim=1) - decoded = representatives[category_index] - decoded_targets.append(decoded) - - curr_col += len(categories) - - elif (thresholds := info.get("thresholds")) is not None: - n_bins = len(thresholds) - 1 - encoded_target = encoded[:, curr_col : curr_col + n_bins] - is_none = ~encoded_target.any(dim=1).view(-1, 1) - encoded_target = torch.cat([encoded_target, is_none], dim=1) - - bin_edges = [-np.inf, *thresholds, np.inf] - representatives = np.array( - [ - f"[{lower:+1.2e};{upper:+1.2e})" - for lower, upper in zip(bin_edges[:-1], bin_edges[1:]) - ] - ) - decoded = representatives[encoded_target.argmax(dim=1)] - - decoded_targets.append(decoded) - - curr_col += n_bins - - else: - raise ValueError(f"cannot decode {target_label}: no target info") - - return decoded_targets - - -def build_targets( - *, - clini_tables: Sequence[Path], - categorical_labels: Sequence[str], - category_min_count: int = 32, - quantize: Sequence[tuple[str, int]] = (), -) -> Dict[str, EncodedTarget]: - clini_df = pd.concat([read_table(c) for c in clini_tables]) - encoded_targets: Dict[str, EncodedTarget] = {} - - # categorical targets - for target_label in categorical_labels: - counts = clini_df[target_label].value_counts() - well_supported = counts[counts >= category_min_count] - - if len(well_supported) <= 1: - continue - - categories = [[str(cat)] for cat in well_supported.index] - - weights = well_supported.sum() / well_supported - weights /= weights.sum() - - representatives, encoded, weight = encode_category( - clini_df=clini_df, - target_label=target_label, - categories=categories, - class_weights=weights.to_dict(), - ) - - encoded_targets[target_label] = EncodedTarget( - categories=representatives, - encoded=encoded, - weight=weight, - ) - - # quantized targets - for target_label, bincount in quantize: - vals = pd.to_numeric(clini_df[target_label]).dropna() - - if vals.empty: - continue - - vals_clamped = vals.replace( - { - -np.inf: vals[vals != 
-np.inf].min(), - np.inf: vals[vals != np.inf].max(), - } - ) - - thresholds = np.array( - [ - -np.inf, - *np.quantile(vals_clamped, q=np.linspace(0, 1, bincount + 1))[1:-1], - np.inf, - ], - dtype=float, - ) - - representatives, encoded, weight = encode_quantize( - clini_df=clini_df, - target_label=target_label, - thresholds=thresholds, - ) - - if encoded.shape[1] <= 1: - continue - - encoded_targets[target_label] = EncodedTarget( - categories=representatives, - encoded=encoded, - weight=weight, - ) - - return encoded_targets - - -if __name__ == "__main__": - encoded = build_targets( - clini_tables=[ - Path( - "/mnt/bulk-neptune/nguyenmin/stamp-dev/experiments/survival_prediction/TCGA-CRC-DX_CLINI.xlsx" - ) - ], - categorical_labels=["BRAF", "KRAS", "NRAS"], - category_min_count=32, - quantize=[], - ) - for name, enc in encoded.items(): - print(name, enc.encoded.shape) diff --git a/tests/random_data.py b/tests/random_data.py index c7c36880..b79c9f42 100644 --- a/tests/random_data.py +++ b/tests/random_data.py @@ -74,13 +74,13 @@ def create_random_dataset( clini_df = pd.DataFrame( patient_to_ground_truth.items(), - columns=["patient", "ground-truth"], # pyright: ignore[reportArgumentType] + columns=["patient", "ground-truth"], ) clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( slide_path_to_patient.items(), - columns=["slide_path", "patient"], # pyright: ignore[reportArgumentType] + columns=["slide_path", "patient"], ) slide_df.to_csv(slide_path, index=False) @@ -130,7 +130,7 @@ def create_random_regression_dataset( # --- Write clini + slide tables --- clini_df = pd.DataFrame(patient_to_target, columns=["patient", "target"]) - clini_df["target"] = clini_df["target"].astype(float) # ✅ ensure numeric dtype + clini_df["target"] = clini_df["target"].astype(float) # ensure numeric dtype clini_df.to_csv(clini_path, index=False) slide_df = pd.DataFrame( diff --git a/tests/test_data.py b/tests/test_data.py index 3c86a931..a60d830e 100644 --- 
a/tests/test_data.py +++ b/tests/test_data.py @@ -3,6 +3,7 @@ from pathlib import Path import h5py +import pandas as pd import pytest import torch from random_data import ( @@ -21,6 +22,7 @@ PatientFeatureDataset, filter_complete_patient_data_, get_coords, + patient_to_ground_truth_from_clini_table_, slide_to_patient_from_slide_table_, ) from stamp.types import ( @@ -85,6 +87,31 @@ def test_get_cohort_df(tmp_path: Path) -> None: } +def test_patient_to_ground_truth_multi_target(tmp_path: Path) -> None: + """Verify multi-target clini parsing returns dicts and drops rows missing all targets.""" + df = pd.DataFrame( + { + "patient": ["p1", "p2", "p3", "p4"], + "subtype": ["A", None, "B", None], + "grade": ["1", "2", None, None], + } + ) + df.to_csv(tmp_path / "clini.csv", index=False) + + result = patient_to_ground_truth_from_clini_table_( + clini_table_path=tmp_path / "clini.csv", + patient_label="patient", + ground_truth_label=["subtype", "grade"], + ) + + # p4 has both targets missing → dropped + assert "p4" not in result + + assert result["p1"] == {"subtype": "A", "grade": "1"} + assert result["p2"] == {"subtype": None, "grade": "2"} + assert result["p3"] == {"subtype": "B", "grade": None} + + @pytest.mark.parametrize( "feature_file_creator", [make_feature_file, make_old_feature_file], diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 4e1570cc..7d1d6589 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -167,7 +167,7 @@ def test_to_prediction_df(task: str) -> None: ) if task == "classification": preds_df = _to_prediction_df( - categories=list(model.categories), # type: ignore + categories=list(cast(list, model.categories)), patient_to_ground_truth={ PatientId("pat5"): GroundTruth("foo"), PatientId("pat6"): None, @@ -196,13 +196,13 @@ def test_to_prediction_df(task: str) -> None: # Check if no loss / target is given for targets with missing ground truths no_ground_truth = preds_df[preds_df["patient"].isin(["pat6"])] - assert 
no_ground_truth["target"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert no_ground_truth["loss"].isna().all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert no_ground_truth["target"].isna().all() + assert no_ground_truth["loss"].isna().all() # Check if loss / target is given for targets with ground truths with_ground_truth = preds_df[preds_df["patient"].isin(["pat5", "pat7"])] - assert (~with_ground_truth["target"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] - assert (~with_ground_truth["loss"].isna()).all() # pyright: ignore[reportGeneralTypeIssues,reportAttributeAccessIssue] + assert (~with_ground_truth["target"].isna()).all() + assert (~with_ground_truth["loss"].isna()).all() elif task == "regression": patient_to_ground_truth = {} From 4a453aea08b238f0db2b89de721a6fc65ef4cf4b Mon Sep 17 00:00:00 2001 From: Minh Duc Nguyen <37109868+mducducd@users.noreply.github.com> Date: Tue, 3 Mar 2026 10:29:23 +0100 Subject: [PATCH 5/5] Remove unused import from survival.py Removed unused import for add_at_risk_counts. --- src/stamp/statistics/survival.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stamp/statistics/survival.py b/src/stamp/statistics/survival.py index 6ff75b61..87415a7a 100644 --- a/src/stamp/statistics/survival.py +++ b/src/stamp/statistics/survival.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd from lifelines import KaplanMeierFitter -from lifelines.plotting import add_at_risk_counts from lifelines.statistics import logrank_test from lifelines.utils import concordance_index