diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py
index a4f2c55..e38bb97 100644
--- a/torchTextClassifiers/torchTextClassifiers.py
+++ b/torchTextClassifiers/torchTextClassifiers.py
@@ -50,11 +50,11 @@ class ModelConfig:
     """Base configuration class for text classifiers."""
 
     embedding_dim: int
+    num_classes: int
     categorical_vocabulary_sizes: Optional[List[int]] = None
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
-    num_classes: Optional[int] = None
     attention_config: Optional[AttentionConfig] = None
-    label_attention_config: Optional[LabelAttentionConfig] = None
+    n_heads_label_attention: Optional[int] = None
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -142,7 +142,7 @@ def __init__(
         self.embedding_dim = model_config.embedding_dim
         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
         self.num_classes = model_config.num_classes
-        self.enable_label_attention = model_config.label_attention_config is not None
+        self.enable_label_attention = model_config.n_heads_label_attention is not None
 
         if self.tokenizer.output_vectorized:
             self.text_embedder = None
@@ -156,7 +156,10 @@ def __init__(
                 embedding_dim=self.embedding_dim,
                 padding_idx=tokenizer.padding_idx,
                 attention_config=model_config.attention_config,
-                label_attention_config=model_config.label_attention_config,
+                label_attention_config=LabelAttentionConfig(
+                    n_head=model_config.n_heads_label_attention,
+                    num_classes=model_config.num_classes,
+                ),
             )
             self.text_embedder = TextEmbedder(
                 text_embedder_config=text_embedder_config,
@@ -697,10 +700,6 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
         # Reconstruct model_config
         model_config = ModelConfig.from_dict(metadata["model_config"])
 
-        if isinstance(model_config.label_attention_config, dict):
-            model_config.label_attention_config = LabelAttentionConfig(
-                **model_config.label_attention_config
-            )
 
         # Create instance
         instance = cls(