diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index dfeecd2..f5f9cab 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -152,14 +152,7 @@ def run_full_pipeline( categorical_embedding_dims=model_params["categorical_embedding_dims"], num_classes=model_params["num_classes"], attention_config=attention_config, - label_attention_config=( - LabelAttentionConfig( - n_head=attention_config.n_head, - num_classes=model_params["num_classes"], - ) - if label_attention_enabled - else None - ), + n_heads_label_attention=attention_config.n_head, ) # Create training config