diff --git a/ScaFFold/worker.py b/ScaFFold/worker.py index 431c2eb..fcd1394 100644 --- a/ScaFFold/worker.py +++ b/ScaFFold/worker.py @@ -241,10 +241,14 @@ def main(kwargs_dict: dict = {}): total_train_time = train_data["epoch_duration"].sum() epochs = np.atleast_1d(train_data["epoch"]) total_epochs = int(epochs[-1]) - log.info( - f"Benchmark run at scale {config.problem_scale} complete. \n\ - Trained to >= 0.95 validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." - ) + if config.epochs == -1: + extra_msg = f"Trained to >= {config.target_dice} validation dice score in {total_train_time:.2f} seconds, {total_epochs} epochs." + else: + extra_msg = ( + f"Completed in {total_train_time:.2f} seconds, {total_epochs} epochs." + ) + + log.info(f"Benchmark run at scale {config.problem_scale} complete. \n{extra_msg}") # # Generate plots