diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 64d88bc8b..7d992d1ca 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -82,6 +82,8 @@ Solvers
     DeepEnsembleSupervisedSolver
     ReducedOrderModelSolver
     GAROM
+    AutoregressiveSolverInterface
+    AutoregressiveSolver
 
 Models
 
diff --git a/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst
new file mode 100644
index 000000000..4cde8d1b9
--- /dev/null
+++ b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver.rst
@@ -0,0 +1,7 @@
+Autoregressive Solver
+======================
+.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver
+
+.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver.AutoregressiveSolver
+    :members:
+    :show-inheritance:
\ No newline at end of file
diff --git a/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst
new file mode 100644
index 000000000..516409bd1
--- /dev/null
+++ b/docs/source/_rst/solver/autoregressive_solver/autoregressive_solver_interface.rst
@@ -0,0 +1,7 @@
+Autoregressive Solver Interface
+=================================
+.. currentmodule:: pina.solver.autoregressive_solver.autoregressive_solver_interface
+
+.. autoclass:: pina._src.solver.autoregressive_solver.autoregressive_solver_interface.AutoregressiveSolverInterface
+    :members:
+    :show-inheritance:
\ No newline at end of file
diff --git a/pina/_src/solver/autoregressive_solver/__init__.py b/pina/_src/solver/autoregressive_solver/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py
new file mode 100644
index 000000000..e0b92af3d
--- /dev/null
+++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver.py
@@ -0,0 +1,398 @@
+"""Module for the Autoregressive Solver."""
+
+import torch
+from pina._src.solver.autoregressive_solver.autoregressive_solver_interface import (
+    AutoregressiveSolverInterface,
+)
+from pina._src.solver.solver import SingleSolverInterface
+from pina._src.loss.loss_interface import LossInterface
+from pina._src.core.utils import check_consistency
+
+
+class AutoregressiveSolver(
+    AutoregressiveSolverInterface, SingleSolverInterface
+):
+    r"""
+    The Autoregressive Solver for learning dynamical systems.
+
+    This solver learns a one-step transition function
+    :math:`\mathcal{M}: \mathbb{R}^n \rightarrow \mathbb{R}^n` that maps
+    a state :math:`\mathbf{y}_t` to the next state :math:`\mathbf{y}_{t+1}`.
+
+    During training, the model is unrolled over multiple time steps to
+    learn long-term dynamics. Given an initial state :math:`\mathbf{y}_0`,
+    the model generates predictions recursively:
+
+    .. math::
+        \hat{\mathbf{y}}_{t+1} = \mathcal{M}(\hat{\mathbf{y}}_t),
+        \quad \hat{\mathbf{y}}_0 = \mathbf{y}_0
+
+    The loss is computed over the entire unroll window:
+
+    .. math::
+        \mathcal{L} = \sum_{t=1}^{T} w_t \|\hat{\mathbf{y}}_t - \mathbf{y}_t\|^2
+
+    where :math:`w_t` are exponential weights that down-weight later
+    predictions to stabilize training.
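+
+    Example (an illustrative sketch; ``trajectories``, ``problem`` and
+    ``model`` are assumed to be defined by the user):
+
+    >>> windows = AutoregressiveSolver.unroll(
+    ...     data=trajectories, unroll_length=3, n_unrolls=4
+    ... )  # windows are used as the input of a DataCondition
+    >>> solver = AutoregressiveSolver(problem=problem, model=model)
+    >>> trainer = Trainer(solver=solver, max_epochs=10)
+    >>> trainer.train()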
+    """
+
+    def __init__(
+        self,
+        problem,
+        model,
+        loss=None,
+        optimizer=None,
+        scheduler=None,
+        weighting=None,
+        use_lt=False,
+        reset_weights_at_epoch_start=True,
+    ):
+        """
+        Initialization of the :class:`AutoregressiveSolver` class.
+
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The neural network model to be used.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
+            Default is ``None``.
+        :param Optimizer optimizer: The optimizer to be used.
+            If ``None``, the :class:`torch.optim.Adam` optimizer is used.
+            Default is ``None``.
+        :param Scheduler scheduler: Learning rate scheduler.
+            If ``None``, the :class:`torch.optim.lr_scheduler.ConstantLR`
+            scheduler is used. Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: Whether to use LabelTensors. Default is ``False``.
+        :param bool reset_weights_at_epoch_start: If ``True``, the running
+            averages used for adaptive weighting are reset at the start of
+            each epoch. Setting this parameter to ``False`` can improve
+            training stability, especially when data are scarce.
+            Default is ``True``.
+        :raise ValueError: If the provided loss function is not compatible.
+        :raise ValueError: If ``reset_weights_at_epoch_start`` is not a
+            boolean.
+        """
+        super().__init__(
+            problem=problem,
+            model=model,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            weighting=weighting,
+            use_lt=use_lt,
+        )
+
+        # Check consistency
+        loss = loss or torch.nn.MSELoss()
+        check_consistency(
+            loss, (LossInterface, torch.nn.modules.loss._Loss), subclass=False
+        )
+        check_consistency(reset_weights_at_epoch_start, bool)
+
+        # Initialization
+        self._loss_fn = loss
+        self.reset_weights_at_epoch_start = reset_weights_at_epoch_start
+        self._running_avg = {}
+        self._step_count = {}
+
+    def on_train_epoch_start(self):
+        """
+        Clean up the running averages at the start of each epoch if
+        ``reset_weights_at_epoch_start`` is ``True``.
+        """
+        if self.reset_weights_at_epoch_start:
+            self._running_avg.clear()
+            self._step_count.clear()
+
+    def optimization_cycle(self, batch):
+        """
+        The optimization cycle for autoregressive solvers.
+
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is
+            a tuple containing a condition name and a dictionary of points.
+        :return: The losses computed for all conditions in the batch.
+        :rtype: dict
+        """
+        # Store losses for each condition in the batch
+        condition_loss = {}
+
+        # Loop through each condition and compute the autoregressive loss
+        for condition_name, points in batch:
+            # TODO: remove setting once AutoregressiveCondition is implemented
+            # TODO: pass a temporal weighting schema in the __init__
+            if hasattr(self.problem.conditions[condition_name], "settings"):
+                settings = self.problem.conditions[condition_name].settings
+                eps = settings.get("eps", None)
+                kwargs = settings.get("kwargs", {})
+            else:
+                eps = None
+                kwargs = {}
+
+            loss = self.loss_autoregressive(
+                points["input"],
+                condition_name=condition_name,
+                eps=eps,
+                **kwargs,
+            )
+            condition_loss[condition_name] = loss
+        return condition_loss
+
+    def loss_autoregressive(
+        self,
+        input,
+        condition_name,
+        eps=None,
+        aggregation_strategy=None,
+        **kwargs,
+    ):
+        """
+        Compute the loss for each autoregressive condition.
+
+        :param input: The input tensor containing unroll windows.
+        :type input: torch.Tensor | LabelTensor
+        :param str condition_name: The name of the current condition, used to
+            cache the running averages of the step losses.
+        :param float eps: The parameter of the exponential temporal weighting.
+            If ``None``, uniform weights are used. Default is ``None``.
+        :param aggregation_strategy: The function used to aggregate the
+            weighted step losses into a scalar. If ``None``,
+            :func:`torch.mean` is used. Default is ``None``.
+        :type aggregation_strategy: Callable
+        :param dict kwargs: Additional keyword arguments for loss computation.
+        :raise ValueError: If ``input`` has less than 4 dimensions.
+        :return: The scalar loss value for the given batch.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        # Check input dimensionality
+        if input.dim() < 4:
+            raise ValueError(
+                "The provided input tensor must have at least 4 dimensions:"
+                " [trajectories, windows, time_steps, *features]."
+                f" Got shape {input.shape}."
+            )
+
+        # Initialize current state and loss list
+        current_state = input[:, :, 0]
+        losses = []
+
+        # Iterate through the unroll window and compute the loss at each step
+        for step in range(1, input.shape[2]):
+
+            # Predict
+            processed_input = self.preprocess_step(current_state, **kwargs)
+            output = self.forward(processed_input)
+            predicted_state = self.postprocess_step(output, **kwargs)
+
+            # Compute step loss
+            target_state = input[:, :, step]
+            step_loss = self._loss_fn(predicted_state, target_state, **kwargs)
+            losses.append(step_loss)
+
+            # Update current state for the next step
+            current_state = predicted_state
+
+        # Stack step losses into a tensor of shape [time_steps - 1]
+        step_losses = torch.stack(losses).as_subclass(torch.Tensor)
+
+        # Compute adaptive weights based on running averages of step losses
+        with torch.no_grad():
+            condition_name = condition_name or "default"
+            weights = self._get_weights(condition_name, step_losses, eps)
+
+        # Aggregate the weighted step losses into a single scalar loss value
+        if aggregation_strategy is None:
+            aggregation_strategy = torch.mean
+
+        return aggregation_strategy(step_losses * weights)
+
+    def preprocess_step(self, current_state, **kwargs):
+        """
+        Pre-process the current state before passing it to the model's
+        forward.
+
+        :param current_state: The current state to be preprocessed.
+        :type current_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for pre-processing.
+        :return: The preprocessed state for the given step.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return current_state
+
+    def postprocess_step(self, predicted_state, **kwargs):
+        """
+        Post-process the state predicted by the model.
+
+        :param predicted_state: The predicted state tensor from the model.
+        :type predicted_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for post-processing.
+        :return: The post-processed predicted state tensor.
+        :rtype: torch.Tensor | LabelTensor
+        """
+        return predicted_state
+
+    def _get_weights(self, condition_name, step_losses, eps):
+        """
+        Return the cached weights or compute new ones.
+
+        :param str condition_name: The name of the current condition.
+        :param torch.Tensor step_losses: The tensor of per-step losses.
+        :param float eps: The weighting parameter.
+        :return: The weights tensor.
+        :rtype: torch.Tensor
+        """
+        # Determine the key for caching based on the condition name
+        key = condition_name or "default"
+
+        # Initialize the key if not in the running averages
+        if key not in self._running_avg:
+            self._running_avg[key] = step_losses.detach().clone()
+            self._step_count[key] = 1
+
+        # Update running averages and counts
+        else:
+            self._step_count[key] += 1
+            value = step_losses.detach() - self._running_avg[key]
+            self._running_avg[key] += value / self._step_count[key]
+
+        return self._compute_adaptive_weights(self._running_avg[key], eps)
+
+    def _compute_adaptive_weights(self, step_losses, eps):
+        """
+        Compute the temporal adaptive weights.
+
+        :param torch.Tensor step_losses: The tensor of per-step losses.
+        :param float eps: The weighting parameter. If ``None``, uniform
+            weights are returned.
+        :return: The weights tensor.
+        :rtype: torch.Tensor
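+
+        Example (illustrative): with ``eps=0.1`` and running-average step
+        losses ``[1.0, 2.0, 3.0]``, the cumulative sums are
+        ``[1.0, 3.0, 6.0]``, so later steps receive exponentially smaller
+        weights:
+
+        >>> losses = torch.tensor([1.0, 2.0, 3.0])
+        >>> torch.exp(-0.1 * torch.cumsum(losses, dim=0))
+        tensor([0.9048, 0.7408, 0.5488])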
+        """
+        # If eps is None, return uniform weights
+        if eps is None:
+            return torch.ones_like(step_losses)
+
+        # Compute the cumulative loss and apply exponential weighting
+        cumulative_loss = -eps * torch.cumsum(step_losses, dim=0)
+
+        return torch.exp(cumulative_loss)
+
+    def predict(self, initial_state, n_steps, **kwargs):
+        """
+        Generate predictions by recursively calling the model's forward.
+
+        :param initial_state: The initial state from which to start the
+            prediction. The initial state must be of shape
+            ``[trajectories, 1, *features]``.
+        :type initial_state: torch.Tensor | LabelTensor
+        :param int n_steps: The number of autoregressive steps to predict.
+        :param dict kwargs: Additional keyword arguments.
+        :raise ValueError: If the provided ``initial_state`` tensor has less
+            than 3 dimensions.
+        :return: The predicted trajectory, including the initial state. It
+            has shape ``[trajectories, n_steps + 1, *features]``, where the
+            first step corresponds to the initial state.
+        :rtype: torch.Tensor | LabelTensor
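+
+        Example (an illustrative sketch; ``solver`` is assumed to be a
+        trained :class:`AutoregressiveSolver` acting on two state features):
+
+        >>> initial_state = torch.rand(5, 1, 2)
+        >>> trajectory = solver.predict(initial_state, n_steps=20)
+        >>> trajectory.shape
+        torch.Size([5, 21, 2])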
+        """
+        # Set the model to evaluation mode for prediction
+        self.eval()
+
+        # Check initial state dimensionality
+        if initial_state.dim() < 3:
+            raise ValueError(
+                "The provided initial_state tensor must have at least 3"
+                " dimensions: [trajectories, time_steps, *features]."
+                f" Got shape {initial_state.shape}."
+            )
+
+        # Initialize the list of predictions with the initial state
+        predictions = [initial_state]
+
+        # Generate predictions recursively for n_steps
+        with torch.no_grad():
+            for _ in range(n_steps):
+                input = self.preprocess_step(predictions[-1], **kwargs)
+                output = self.forward(input)
+                next_state = self.postprocess_step(output, **kwargs)
+                predictions.append(next_state)
+
+        # Concatenate along the time dimension to match the documented shape
+        return torch.cat(predictions, dim=1)
+
+    # TODO: integrate in the Autoregressive Condition once implemented
+    @staticmethod
+    def unroll(data, unroll_length, n_unrolls=None, randomize=True):
+        """
+        Create unrolling time windows from temporal data.
+
+        This function takes as input a tensor of shape
+        ``[trajectories, time_steps, *features]`` and produces a tensor of
+        shape ``[trajectories, windows, unroll_length, *features]``.
+        Each window contains a sequence of subsequent states used for
+        computing the multi-step loss during training.
+
+        :param data: The temporal data tensor to be unrolled.
+        :type data: torch.Tensor | LabelTensor
+        :param int unroll_length: The number of time steps in each window.
+        :param int n_unrolls: The maximum number of windows to return.
+            If ``None``, all valid windows are returned. Default is ``None``.
+        :param bool randomize: If ``True``, the starting indices are randomly
+            permuted before applying ``n_unrolls``. Default is ``True``.
+        :raise ValueError: If the input ``data`` has less than 3 dimensions.
+        :raise ValueError: If ``unroll_length`` is greater than the number of
+            time steps in ``data``.
+        :return: A tensor of unrolled windows.
+        :rtype: torch.Tensor | LabelTensor
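+
+        Example (illustrative): five trajectories of ten steps with two
+        features yield, e.g., four windows of three steps each:
+
+        >>> data = torch.rand(5, 10, 2)
+        >>> windows = AutoregressiveSolver.unroll(
+        ...     data, unroll_length=3, n_unrolls=4
+        ... )
+        >>> windows.shape
+        torch.Size([5, 4, 3, 2])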
+        """
+        # Check input dimensionality
+        if data.dim() < 3:
+            raise ValueError(
+                "The provided data tensor must have at least 3 dimensions:"
+                " [trajectories, time_steps, *features]."
+                f" Got shape {data.shape}."
+            )
+
+        # Determine valid starting indices for the unroll windows
+        start_idx = AutoregressiveSolver._get_start_idx(
+            n_steps=data.shape[1],
+            unroll_length=unroll_length,
+            n_unrolls=n_unrolls,
+            randomize=randomize,
+        )
+
+        # Create unroll windows by slicing the data at the starting indices
+        windows = [data[:, s : s + unroll_length] for s in start_idx]
+
+        return torch.stack(windows, dim=1)
+
+    @staticmethod
+    def _get_start_idx(n_steps, unroll_length, n_unrolls=None, randomize=True):
+        """
+        Determine the starting indices for the unroll windows.
+
+        :param int n_steps: The total number of time steps in the data.
+        :param int unroll_length: The number of time steps in each window.
+        :param int n_unrolls: The maximum number of windows to return.
+            If ``None``, all valid windows are returned. Default is ``None``.
+        :param bool randomize: If ``True``, the starting indices are randomly
+            permuted before applying ``n_unrolls``. Default is ``True``.
+        :raise ValueError: If ``unroll_length`` is greater than ``n_steps``.
+        :return: A tensor of starting indices for the unroll windows.
+        :rtype: torch.Tensor
+        """
+        # Calculate the last valid starting index for the unroll windows
+        last_idx = n_steps - unroll_length
+
+        # Raise an error if no valid windows can be created
+        if last_idx < 0:
+            raise ValueError(
+                f"Cannot create unroll windows: unroll_length ({unroll_length})"
+                " cannot be greater than the number of time_steps"
+                f" ({n_steps})."
+            )
+
+        # Generate the ordered starting indices for the unroll windows
+        indices = torch.arange(last_idx + 1)
+
+        # Permute the indices if randomization is enabled
+        if randomize:
+            indices = indices[torch.randperm(len(indices))]
+
+        # Limit the number of windows if n_unrolls is specified
+        if n_unrolls is not None and n_unrolls < len(indices):
+            indices = indices[:n_unrolls]
+
+        return indices
+
+    @property
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
+        return self._loss_fn
diff --git a/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py b/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py
new file mode 100644
index 000000000..7029995fd
--- /dev/null
+++ b/pina/_src/solver/autoregressive_solver/autoregressive_solver_interface.py
@@ -0,0 +1,82 @@
+"""Module for the Autoregressive Solver Interface."""
+
+from abc import abstractmethod
+from pina._src.condition.data_condition import DataCondition
+from pina._src.solver.solver import SolverInterface
+
+
+class AutoregressiveSolverInterface(SolverInterface):
+    # TODO: fix once the AutoregressiveCondition is implemented.
+    """
+    Abstract interface for all autoregressive solvers.
+
+    Solvers implementing this interface are designed to learn dynamical
+    systems in an autoregressive manner. The solver should handle conditions
+    of type :class:`~pina.condition.data_condition.DataCondition`.
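+
+    Example (a minimal sketch of a custom solver overriding the pre- and
+    post-processing hooks of a concrete subclass such as
+    :class:`~pina.solver.AutoregressiveSolver`; the identity operations
+    below are placeholders):
+
+    .. code-block:: python
+
+        class MySolver(AutoregressiveSolver):
+
+            def preprocess_step(self, current_state, **kwargs):
+                # e.g., normalize the state before the forward pass
+                return current_state
+
+            def postprocess_step(self, predicted_state, **kwargs):
+                # e.g., map the network output back to physical units
+                return predicted_state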
+    """
+
+    accepted_conditions_types = (DataCondition,)
+
+    @abstractmethod
+    def preprocess_step(self, current_state, **kwargs):
+        """
+        Pre-process the current state before passing it to the model's
+        forward.
+
+        :param current_state: The current state to be preprocessed.
+        :type current_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for pre-processing.
+        :return: The preprocessed state for the given step.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    @abstractmethod
+    def postprocess_step(self, predicted_state, **kwargs):
+        """
+        Post-process the state predicted by the model.
+
+        :param predicted_state: The predicted state tensor from the model.
+        :type predicted_state: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for post-processing.
+        :return: The post-processed predicted state tensor.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    # TODO: remove once the AutoregressiveCondition is implemented.
+    @abstractmethod
+    def loss_autoregressive(self, input, **kwargs):
+        """
+        Compute the loss for each autoregressive condition.
+
+        :param input: The input tensor containing unroll windows.
+        :type input: torch.Tensor | LabelTensor
+        :param dict kwargs: Additional keyword arguments for loss computation.
+        :return: The scalar loss value for the given batch.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    @abstractmethod
+    def predict(self, initial_state, n_steps, **kwargs):
+        """
+        Generate predictions by recursively applying the model.
+
+        :param initial_state: The initial state from which to start the
+            prediction. The initial state must be of shape
+            ``[trajectories, 1, *features]``, where the trajectory dimension
+            can be used for batching.
+        :type initial_state: torch.Tensor | LabelTensor
+        :param int n_steps: The number of autoregressive steps to predict.
+        :param dict kwargs: Additional keyword arguments.
+        :return: The predicted trajectory, including the initial state. It
+            has shape ``[trajectories, n_steps + 1, *features]``, where the
+            first step corresponds to the initial state.
+        :rtype: torch.Tensor | LabelTensor
+        """
+
+    @property
+    @abstractmethod
+    def loss(self):
+        """
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
+        """
diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py
index a93914099..619e59d04 100644
--- a/pina/solver/__init__.py
+++ b/pina/solver/__init__.py
@@ -27,6 +27,8 @@
     "DeepEnsembleSupervisedSolver",
     "DeepEnsemblePINN",
     "GAROM",
+    "AutoregressiveSolver",
+    "AutoregressiveSolverInterface",
 ]
 
 from pina._src.solver.solver import (
@@ -64,3 +66,10 @@
 )
 
 from pina._src.solver.garom import GAROM
+
+from pina._src.solver.autoregressive_solver.autoregressive_solver import (
+    AutoregressiveSolver,
+)
+from pina._src.solver.autoregressive_solver.autoregressive_solver_interface import (
+    AutoregressiveSolverInterface,
+)
diff --git a/tests/test_solver/test_autoregressive_solver.py b/tests/test_solver/test_autoregressive_solver.py
new file mode 100644
index 000000000..5dba8879c
--- /dev/null
+++ b/tests/test_solver/test_autoregressive_solver.py
@@ -0,0 +1,203 @@
+import shutil
+import pytest
+import torch
+from torch._dynamo.eval_frame import OptimizedModule
+
+from pina import Condition, Trainer, LabelTensor
+from pina.solver import AutoregressiveSolver
+from pina.condition import DataCondition
+from pina.problem import AbstractProblem
+from pina.model import FeedForward
+
+
+# Hyperparameters and settings
+n_traj = 5
+t_steps = 10
+n_feats = 2
+unroll_length = 3
+n_unrolls = 4
+
+
+# TODO: test this in AutoregressiveCondition once it's implemented
+# Utility function to create synthetic data for testing
+def create_data(n_traj, t_steps, n_feats, unroll_length, n_unrolls, use_lt):
+
+    init_state = torch.rand(n_traj, n_feats)
+    traj = torch.stack([0.95**i * init_state for i in range(t_steps)], dim=1)
+
+    data = AutoregressiveSolver.unroll(
+        data=traj,
+        unroll_length=unroll_length,
+        n_unrolls=n_unrolls,
+        randomize=True,
+    )
+    if not use_lt:
+        return data
+    labels = [f"feat_{i}" for i in range(n_feats)]
+    return LabelTensor(data, labels=labels)
+
+
+# Data
+data = create_data(
+    n_traj=n_traj,
+    t_steps=t_steps,
+    n_feats=n_feats,
+    unroll_length=unroll_length,
+    n_unrolls=n_unrolls,
+    use_lt=True,
+)
+
+
+# Problem
+class Problem(AbstractProblem):
+
+    input_variables = [f"feat_{i}" for i in range(n_feats)]
+    output_variables = [f"feat_{i}" for i in range(n_feats)]
+    conditions = {}
+
+    def __init__(self, data):
+        super().__init__()
+        self.data = data
+        self.conditions = {"autoregressive": Condition(input=self.data)}
+        # TODO: remove once the autoregressive condition is implemented
+        self.conditions_settings = {"autoregressive": {"eps": 0.1}}
+
+
+problem = Problem(data)
+model = FeedForward(n_feats, n_feats, 128, 2)
+
+
+@pytest.mark.parametrize("use_lt", [True, False])
+@pytest.mark.parametrize("bool_value", [True, False])
+def test_constructor(use_lt, bool_value):
+
+    solver = AutoregressiveSolver(
+        problem=problem,
+        model=model,
+        reset_weights_at_epoch_start=bool_value,
+        use_lt=use_lt,
+    )
+
+    # TODO: update once the AutoregressiveCondition is implemented
+    assert solver.accepted_conditions_types == (DataCondition,)
+
+
+@pytest.mark.parametrize("use_lt", [True, False])
+@pytest.mark.parametrize("batch_size", [None, 1, 2, 5])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize("bool_value", [True, False])
+def test_solver_train(use_lt, batch_size, compile, bool_value):
+    solver = AutoregressiveSolver(
+        model=model,
+        problem=problem,
+        reset_weights_at_epoch_start=bool_value,
+        use_lt=use_lt,
+    )
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=2,
+        accelerator="cpu",
+        batch_size=batch_size,
+        train_size=1.0,
+        val_size=0.0,
+        test_size=0.0,
+        compile=compile,
+    )
+    trainer.train()
+
+
+@pytest.mark.parametrize("use_lt", [True, False])
+@pytest.mark.parametrize("batch_size", [None, 1, 2, 5])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize("bool_value", [True, False])
+def test_solver_validation(use_lt, batch_size, compile, bool_value):
+    solver = AutoregressiveSolver(
+        model=model,
+        problem=problem,
+        reset_weights_at_epoch_start=bool_value,
+        use_lt=use_lt,
+    )
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=2,
+        accelerator="cpu",
+        batch_size=batch_size,
+        train_size=0.9,
+        val_size=0.1,
+        test_size=0.0,
+        compile=compile,
+    )
+    trainer.train()
+    if trainer.compile:
+        assert isinstance(solver.model, OptimizedModule)
+
+
+@pytest.mark.parametrize("use_lt", [True, False])
+@pytest.mark.parametrize("batch_size", [None, 1, 2, 5])
+@pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize("bool_value", [True, False])
+def test_solver_test(use_lt, batch_size, compile, bool_value):
+    solver = AutoregressiveSolver(
+        model=model,
+        problem=problem,
+        reset_weights_at_epoch_start=bool_value,
+        use_lt=use_lt,
+    )
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=2,
+        accelerator="cpu",
+        batch_size=batch_size,
+        train_size=0.7,
+        val_size=0.2,
+        test_size=0.1,
+        compile=compile,
+    )
+    trainer.test()
+
+
+@pytest.mark.parametrize("use_lt", [True, False])
+def test_train_load_restore(use_lt):
+    tmp_dir = "tests/test_solver/tmp"
+    solver = AutoregressiveSolver(
+        model=model,
+        problem=problem,
+        reset_weights_at_epoch_start=False,
+        use_lt=use_lt,
+    )
+    trainer = Trainer(
+        solver=solver,
+        max_epochs=5,
+        accelerator="cpu",
+        batch_size=None,
+        train_size=0.7,
+        val_size=0.2,
+        test_size=0.1,
+        default_root_dir=tmp_dir,
+    )
+    trainer.train()
+
+    # restore
+    new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
+    new_trainer.train(
+        ckpt_path=f"{tmp_dir}/lightning_logs/version_0/checkpoints/"
+        + "epoch=4-step=5.ckpt"
+    )
+
+    # loading
+    new_solver = AutoregressiveSolver.load_from_checkpoint(
+        f"{tmp_dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt",
+        problem=problem,
+        model=model,
+    )
+
+    test_pts = LabelTensor(
+        torch.rand(n_traj, t_steps, n_feats), problem.input_variables
+    )
+    assert new_solver.forward(test_pts).shape == (n_traj, t_steps, n_feats)
+    assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
+    torch.testing.assert_close(
+        new_solver.forward(test_pts), solver.forward(test_pts)
+    )
+
+    shutil.rmtree(tmp_dir)