From 2c625ba06fdf1a4017625b06539475256ad6f121 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Tue, 15 Apr 2025 00:44:22 +0100 Subject: [PATCH 1/2] fix for: returning zero stress instead of None caused incorrect valid logging --- mace/modules/utils.py | 2 +- mace/tools/train.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mace/modules/utils.py b/mace/modules/utils.py index 5731118c7..daf6e1b37 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -56,7 +56,7 @@ def compute_forces_virials( create_graph=training, # Create graph for second derivative allow_unused=True, ) - stress = torch.zeros_like(displacement) + stress = None if compute_stress and virials is not None: cell = cell.view(-1, 3, 3) volume = torch.linalg.det(cell).abs().unsqueeze(-1) diff --git a/mace/tools/train.py b/mace/tools/train.py index c7c17e136..3209650e1 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -68,7 +68,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_stress"] is not None + and "rmse_stress" in eval_metrics and eval_metrics["rmse_stress"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 @@ -78,7 +78,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" - and eval_metrics["rmse_virials_per_atom"] is not None + and "rmse_virials_per_atom" in eval_metrics and eval_metrics["rmse_virials_per_atom"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 @@ -88,7 +88,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_stress_per_atom"] is not None + and "mae_stress_per_atom" in eval_metrics and eval_metrics["mae_stress_per_atom"] is not None ): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 @@ -98,7 +98,7 @@ def valid_err_log( ) elif ( log_errors == "PerAtomMAEstressvirials" - and eval_metrics["mae_virials_per_atom"] is not None + and "mae_virials_per_atom" in eval_metrics and eval_metrics["mae_virials_per_atom"] is not None ): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 From 73f4a1d959ea569004ebf2fedd4b2169f4d6c517 Mon Sep 17 00:00:00 2001 From: vue1999 Date: Wed, 16 Apr 2025 01:07:56 +0100 Subject: [PATCH 2/2] print error tables for virials and formatting --- mace/tools/tables_utils.py | 4 ++++ mace/tools/train.py | 12 ++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mace/tools/tables_utils.py b/mace/tools/tables_utils.py index 04ff64014..fd0417ad3 100644 --- a/mace/tools/tables_utils.py +++ b/mace/tools/tables_utils.py @@ -148,6 +148,7 @@ def create_error_table( ) elif ( table_type == "PerAtomRMSEstressvirials" + and "rmse_stress" in metrics and metrics["rmse_stress"] is not None ): table.add_row( @@ -161,6 +162,7 @@ def create_error_table( ) elif ( table_type == "PerAtomRMSEstressvirials" + and "rmse_virials" in metrics and metrics["rmse_virials"] is not None ): table.add_row( @@ -174,6 +176,7 @@ def create_error_table( ) elif ( table_type == "PerAtomMAEstressvirials" + and "mae_stress" in metrics and metrics["mae_stress"] is not None ): table.add_row( @@ -187,6 +190,7 @@ def create_error_table( ) elif ( table_type == "PerAtomMAEstressvirials" + and "mae_virials" in metrics and metrics["mae_virials"] is not None ): table.add_row( diff --git a/mace/tools/train.py b/mace/tools/train.py index 3209650e1..67972f1cd 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -68,7 +68,8 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" - and "rmse_stress" in eval_metrics and eval_metrics["rmse_stress"] is not None + and "rmse_stress" in eval_metrics + and eval_metrics["rmse_stress"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 @@ -78,7 +79,8 @@ def valid_err_log( ) elif ( log_errors == "PerAtomRMSEstressvirials" - and "rmse_virials_per_atom" in eval_metrics and eval_metrics["rmse_virials_per_atom"] is not None + and "rmse_virials_per_atom" in eval_metrics + and eval_metrics["rmse_virials_per_atom"] is not None ): error_e = eval_metrics["rmse_e_per_atom"] * 1e3 error_f = eval_metrics["rmse_f"] * 1e3 @@ -88,7 +90,8 @@ def valid_err_log( ) elif ( log_errors == "PerAtomMAEstressvirials" - and "mae_stress_per_atom" in eval_metrics and eval_metrics["mae_stress_per_atom"] is not None + and "mae_stress_per_atom" in eval_metrics + and eval_metrics["mae_stress_per_atom"] is not None ): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3 @@ -98,7 +101,8 @@ def valid_err_log( ) elif ( log_errors == "PerAtomMAEstressvirials" - and "mae_virials_per_atom" in eval_metrics and eval_metrics["mae_virials_per_atom"] is not None + and "mae_virials_per_atom" in eval_metrics + and eval_metrics["mae_virials_per_atom"] is not None ): error_e = eval_metrics["mae_e_per_atom"] * 1e3 error_f = eval_metrics["mae_f"] * 1e3