diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 5731118c7..daf6e1b37 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -56,7 +56,7 @@ def compute_forces_virials( create_graph=training, # Create graph for second derivative allow_unused=True, ) - stress = torch.zeros_like(displacement) + stress = None if compute_stress and virials is not None: cell = cell.view(-1, 3, 3) volume = torch.linalg.det(cell).abs().unsqueeze(-1) diff --git a/mace/tools/tables_utils.py b/mace/tools/tables_utils.py index 04ff64014..fd0417ad3 100644 --- a/mace/tools/tables_utils.py +++ b/mace/tools/tables_utils.py @@ -148,6 +148,7 @@ def create_error_table( ) elif ( table_type == "PerAtomRMSEstressvirials" + and "rmse_stress" in metrics and metrics["rmse_stress"] is not None ): table.add_row( @@ -161,6 +162,7 @@ def create_error_table( ) elif ( table_type == "PerAtomRMSEstressvirials" + and "rmse_virials" in metrics and metrics["rmse_virials"] is not None ): table.add_row( @@ -174,6 +176,7 @@ def create_error_table( ) elif ( table_type == "PerAtomMAEstressvirials" + and "mae_stress" in metrics and metrics["mae_stress"] is not None ): table.add_row( @@ -187,6 +190,7 @@ def create_error_table( ) elif ( table_type == "PerAtomMAEstressvirials" + and "mae_virials" in metrics and metrics["mae_virials"] is not None ): table.add_row( diff --git a/mace/tools/train.py b/mace/tools/train.py index c7c17e136..67972f1cd 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -68,6 +68,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" + and "rmse_stress" in eval_metrics and eval_metrics["rmse_stress"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 @@ -78,6 +79,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" + and "rmse_virials_per_atom" in eval_metrics and eval_metrics["rmse_virials_per_atom"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 @@ -88,6 +90,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomMAEstressvirials" + and "mae_stress_per_atom" in eval_metrics and eval_metrics["mae_stress_per_atom"] is not None ): error_e = eval_metrics["mae_e_per_atom"] * 1e3 @@ -98,6 +101,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomMAEstressvirials" + and "mae_virials_per_atom" in eval_metrics and eval_metrics["mae_virials_per_atom"] is not None ): error_e = eval_metrics["mae_e_per_atom"] * 1e3