diff --git a/configs/train/letter.json b/configs/train/letter.json index 576106b7..3fbac882 100644 --- a/configs/train/letter.json +++ b/configs/train/letter.json @@ -1,6 +1,6 @@ { - "experiment_name": "letter_data", - "best_metric": "validation/ndcg@20", + "experiment_name": "letter_tiger", + "best_metric": "loss", "train_epochs_num": 100, "dataset": { "type": "letter_full", @@ -39,21 +39,15 @@ }, "model": { "type": "tiger", - "rqvae_train_config_path": "../configs/train/rqvae_train_config.json", - "rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth", - "embs_extractor_path": "../data/Beauty/rqvae/data_full.pt", "sequence_prefix": "item", - "predictions_prefix": "logits", - "positive_prefix": "labels", - "labels_prefix": "labels", "embedding_dim": 64, + "codebook_size": 256, + "num_positions": 200, "num_heads": 2, "num_encoder_layers": 2, "num_decoder_layers": 2, "dim_feedforward": 256, "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, "initializer_range": 0.02 }, "optimizer": { @@ -68,18 +62,28 @@ "type": "composite", "losses": [ { - "type": "ce", - "predictions_prefix": "logits", - "labels_prefix": "semantic.labels", + "type": "identity_map", + "predictions_prefix": "decoder_loss_1", "weight": 1.0, - "output_prefix": "semantic_loss" + "output_prefix": "decoder_loss_1" }, { - "type": "ce", - "predictions_prefix": "dedup.logits", - "labels_prefix": "dedup.labels", + "type": "identity_map", + "predictions_prefix": "decoder_loss_2", "weight": 1.0, - "output_prefix": "dedup_loss" + "output_prefix": "decoder_loss_2" + }, + { + "type": "identity_map", + "predictions_prefix": "decoder_loss_3", + "weight": 1.0, + "output_prefix": "decoder_loss_3" + }, + { + "type": "identity_map", + "predictions_prefix": "decoder_loss_4", + "weight": 1.0, + "output_prefix": "decoder_loss_4" } ], "output_prefix": "loss" @@ -91,94 +95,6 @@ "type": "metric", "on_step": 1, "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 1024, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 2048, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } } ] } diff --git a/modeling/loss/base.py b/modeling/loss/base.py index 8ec91326..ef51bbf7 100644 --- a/modeling/loss/base.py +++ b/modeling/loss/base.py @@ -141,7 +141,21 @@ def forward(self, inputs): inputs[self._output_prefix] = loss.cpu().item() return loss - + +class IdentityMapLoss(TorchLoss, config_name='identity_map'): + + def __init__(self, predictions_prefix, output_prefix=None): + super().__init__() + self._input_loss_key = 
predictions_prefix + self._output_prefix = output_prefix + + def forward(self, inputs): + loss = inputs[self._input_loss_key] + assert loss.dim() == 0, "Loss must be a scalar tensor" + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + return loss + class RqVaeLoss(TorchLoss, config_name='rqvae_loss'): def __init__(self, beta, output_prefix=None): diff --git a/modeling/models/base.py b/modeling/models/base.py index a1384384..736b9843 100644 --- a/modeling/models/base.py +++ b/modeling/models/base.py @@ -31,6 +31,13 @@ def _init_weights(self, initializer_range): a=-2 * initializer_range, b=2 * initializer_range ) + elif "bos_embedding" in key: + nn.init.trunc_normal_( + value.data, + std=initializer_range, + a=-2 * initializer_range, + b=2 * initializer_range, + ) else: raise ValueError(f'Unknown transformer weight: {key}') diff --git a/modeling/models/sasrec_semantic.py b/modeling/models/sasrec_semantic.py index 0230f41e..7cdca381 100644 --- a/modeling/models/sasrec_semantic.py +++ b/modeling/models/sasrec_semantic.py @@ -1,9 +1,10 @@ +import json + import torch -from .tiger import TigerModel -from models import SequentialTorchModel from torch import nn + +from models import SequentialTorchModel, RqVaeModel from utils import DEVICE, create_masked_tensor -from torch import nn class SasRecSemanticModel(SequentialTorchModel, config_name="sasrec_semantic"): @@ -67,9 +68,39 @@ def __init__( requires_grad=True, ) # len(self._codebook_sizes), codebook_size, embedding_dim + @classmethod + def init_rqvae(self, config): + rqvae_config = json.load(open(config["rqvae_train_config_path"])) + rqvae_config["model"]["should_init_codebooks"] = False + + rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]).to(DEVICE) + rqvae_model.load_state_dict( + torch.load(config["rqvae_checkpoint_path"], weights_only=True) + ) + rqvae_model.eval() + for param in rqvae_model.parameters(): + param.requires_grad = False + + codebook_sizes = rqvae_model.codebook_sizes + assert all([book_size == codebook_sizes[0] for book_size in codebook_sizes]) + + embs_extractor = torch.load(config["embs_extractor_path"], weights_only=False) + + embs_extractor = embs_extractor.sort_index() + + item_ids = embs_extractor.index.tolist() + assert item_ids == list(range(1, len(item_ids) + 1)) + + text_embeddings = torch.stack(embs_extractor["embeddings"].tolist()).to(DEVICE) + + semantic_ids, residuals = rqvae_model({"embeddings": text_embeddings}) + + return rqvae_model, semantic_ids, residuals, item_ids + + @classmethod def create_from_config(cls, config, **kwargs): - rqvae_model, semantic_ids, residuals, _ = TigerModel.init_rqvae(config) + rqvae_model, semantic_ids, residuals, _ = cls.init_rqvae(config) return cls( rqvae_model=rqvae_model, diff --git a/modeling/models/tiger.py b/modeling/models/tiger.py index 81d5825d..27db2442 100644 --- a/modeling/models/tiger.py +++ b/modeling/models/tiger.py @@ -1,489 +1,337 @@ -import json - import torch -from models.base import SequentialTorchModel -from rqvae_utils import CollisionSolver, SimplifiedTree -from torch import nn -from utils import DEVICE, create_masked_tensor, get_activation_function +import torch.nn as nn -from .rqvae import RqVaeModel +from models import TorchModel +from utils import get_activation_function, create_masked_tensor, DEVICE -class TigerModel(SequentialTorchModel, config_name="tiger"): +class TigerModel(TorchModel, config_name="tiger"): def __init__( - self, - rqvae_model, - item_id_to_semantic_id, - item_id_to_residual, - 
solver, - sequence_prefix, - pred_prefix, - positive_prefix, - labels_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_encoder_layers, - num_decoder_layers, - dim_feedforward, - dropout=0.0, - activation="relu", - layer_norm_eps=1e-9, - initializer_range=0.02, + self, + sequence_prefix, + embedding_dim, + codebook_size, + num_positions, + num_heads, + num_encoder_layers, + num_decoder_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_encoder_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) + super().__init__() self._sequence_prefix = sequence_prefix - self._pred_prefix = pred_prefix - self._positive_prefix = positive_prefix - self._labels_prefix = labels_prefix + self._embedding_dim = embedding_dim + self._codebook_size = codebook_size + self._num_positions = num_positions + self._num_heads = num_heads + self._num_encoder_layers = num_encoder_layers + self._num_decoder_layers = num_decoder_layers + self._dim_feedforward = dim_feedforward + self._dropout = nn.Dropout(dropout) + self._layer_norm_eps = layer_norm_eps + + self._sem_id_len = 4 + + self.position_embeddings = nn.Embedding(num_embeddings=self._num_positions, embedding_dim=self._embedding_dim, + device=DEVICE) + + self.sem_id_position_embeddings = nn.Embedding(num_embeddings=self._sem_id_len, + embedding_dim=self._embedding_dim, device=DEVICE) + + self.bos_embedding = nn.Parameter(torch.randn(self._embedding_dim, device=DEVICE)) + + self.codebook_embeddings = nn.Embedding(num_embeddings=(self._codebook_size * self._sem_id_len), + embedding_dim=self._embedding_dim, device=DEVICE) + + transformer_encoder_layer = nn.TransformerEncoderLayer( + d_model=self._embedding_dim, + nhead=self._num_heads, + dim_feedforward=self._dim_feedforward, + dropout=dropout, + activation=get_activation_function(activation), + layer_norm_eps=self._layer_norm_eps, + batch_first=True, + device=DEVICE + ) transformer_decoder_layer = nn.TransformerDecoderLayer( - d_model=embedding_dim, - nhead=num_heads, + d_model=self._embedding_dim, + nhead=self._num_heads, dim_feedforward=dim_feedforward, dropout=dropout, activation=get_activation_function(activation), - layer_norm_eps=layer_norm_eps, + layer_norm_eps=self._layer_norm_eps, batch_first=True, + device=DEVICE ) self._decoder = nn.TransformerDecoder( transformer_decoder_layer, num_decoder_layers ) + self._encoder = nn.TransformerEncoder(transformer_encoder_layer, num_encoder_layers) - self._decoder_layernorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self._decoder_dropout = nn.Dropout(dropout) - - self._solver: CollisionSolver = solver - - self._codebook_sizes = rqvae_model.codebook_sizes - self._bos_weight = nn.Parameter( - torch.nn.init.trunc_normal_( - torch.zeros(embedding_dim), - std=initializer_range, - a=-2 * initializer_range, - b=2 * initializer_range, - ), - requires_grad=True, # TODOPK added for bos - ) - - self._codebook_embeddings = nn.Embedding( - num_embeddings=len(self._codebook_sizes) + 2, embedding_dim=embedding_dim - ) # + 2 for bos token & residual + # self._layernorm = nn.LayerNorm(self._embedding_dim, eps=layer_norm_eps) self._init_weights(initializer_range) - self._codebook_item_embeddings_stacked = nn.Parameter( - torch.stack([codebook for 
codebook in rqvae_model.codebooks]), - requires_grad=True, - ) # TODOPK (ask is it ok to have separate codebooks and _item_id_to_semantic_embedding) - - self._item_id_to_semantic_id = item_id_to_semantic_id - self._item_id_to_residual = item_id_to_residual - - self._item_id_to_semantic_embedding = nn.Parameter( - self.get_init_item_embeddings(item_id_to_semantic_id, item_id_to_residual), - requires_grad=True, - ) - - self._trie = SimplifiedTree(self._codebook_item_embeddings_stacked) + def _embed_semantic_tokens(self, sem_ids: torch.LongTensor) -> torch.Tensor: + """ + sem_ids: (N,) + embeds: (N, embedding_dim) + """ + + positions = torch.arange(sem_ids.size(0), device=DEVICE) % self._sem_id_len # (N,) + offsets = positions * self._codebook_size + assert offsets.shape == sem_ids.shape + return self.codebook_embeddings(offsets + sem_ids) # (N , embedding_dim) + + def _get_position_embeddings(self, mask: torch.BoolTensor) -> torch.Tensor: + batch_size, seq_len = mask.shape + position_ids = torch.arange(seq_len, device=DEVICE).unsqueeze(0).expand(batch_size, -1) + pos_emb = self.position_embeddings(position_ids) # (batch_size, max_seq_len, embedding_dim) + pos_emb[~mask] = 0.0 + return pos_emb # (batch_size, max_seq_len, embedding_dim) + + def _get_sem_ids_position_embeddings(self, mask: torch.BoolTensor) -> torch.Tensor: + batch_size, seq_len = mask.shape + position_ids = torch.arange(seq_len, device=DEVICE).unsqueeze(0).expand(batch_size, -1) + sem_pos_ids = position_ids.remainder(self._sem_id_len) + sem_pos_emb = self.sem_id_position_embeddings(sem_pos_ids) # (batch_size, max_seq_len, embedding_dim) + sem_pos_emb[~mask] = 0.0 + return sem_pos_emb # (batch_size, max_seq_len, embedding_dim) + + def _get_last_sem_ids_mask(self, all_sample_lengths: torch.Tensor) -> torch.Tensor: + """Создает маску для последних sem_id_len токенов каждой последовательности""" + total_tokens = all_sample_lengths.sum().item() + tgt_end_idx = torch.cumsum(all_sample_lengths, dim=0) + tgt_start_idx = tgt_end_idx - self._sem_id_len + + mask_flat_extended = torch.zeros(total_tokens + 1, dtype=torch.int, device=DEVICE) + mask_flat_extended[tgt_start_idx] += 1 + mask_flat_extended[tgt_end_idx] -= 1 + + mask_flat = torch.cumsum(mask_flat_extended, dim=0)[:total_tokens] + return mask_flat.bool() + + def _prepare_sem_id_batch( + self, + embeddings_flat: torch.Tensor, # (total_tokens, embedding_dim) + lengths: torch.LongTensor # (batch_size,) + ): + batch_size = lengths.size(0) + sem_id_len = self._sem_id_len - self._trie.build_tree_structure( - item_id_to_semantic_id.to(DEVICE), - item_id_to_residual.to(DEVICE), - torch.arange(1, len(item_id_to_semantic_id) + 1).to(DEVICE), - sum_with_residuals=False, - ) + # маска последних sem_id_len каждого батча + decoder_mask_flat = self._get_last_sem_ids_mask(lengths) - @classmethod - def init_rqvae(cls, config): - rqvae_config = json.load(open(config["rqvae_train_config_path"])) - rqvae_config["model"]["should_init_codebooks"] = False + # разделение плоского тензора + encoder_emb_flat = embeddings_flat[~decoder_mask_flat] + decoder_emb_flat = embeddings_flat[decoder_mask_flat] - rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]).to(DEVICE) - rqvae_model.load_state_dict( - torch.load(config["rqvae_checkpoint_path"], weights_only=True) - ) - rqvae_model.eval() - for param in rqvae_model.parameters(): - param.requires_grad = False + # эмбеддинги уже с добавленной размерностью max_seq_len и позиционными + encoder_embeddings, encoder_mask = 
self._create_encoder_tensors(encoder_emb_flat, lengths, sem_id_len) + decoder_embeddings = self._create_decoder_tensors(decoder_emb_flat, batch_size, sem_id_len) - codebook_sizes = rqvae_model.codebook_sizes - assert all([book_size == codebook_sizes[0] for book_size in codebook_sizes]) + # BOS + encoder_embeddings, encoder_mask = self._add_bos_to_encoder(encoder_embeddings, encoder_mask, batch_size) + decoder_embeddings = self._add_bos_to_decoder(decoder_embeddings, batch_size) - embs_extractor = torch.load(config["embs_extractor_path"], weights_only=False) + return encoder_embeddings, encoder_mask, decoder_embeddings - embs_extractor = embs_extractor.sort_index() + def _create_encoder_tensors(self, encoder_emb_flat, lengths, sem_id_len): + encoder_lengths = lengths - sem_id_len + encoder_embeddings, encoder_mask = create_masked_tensor(encoder_emb_flat, encoder_lengths) - item_ids = embs_extractor.index.tolist() - assert item_ids == list(range(1, len(item_ids) + 1)) + # позиционные эмбеддинги + pos_emb = self._get_position_embeddings(encoder_mask) + sem_pos_emb = self._get_sem_ids_position_embeddings(encoder_mask) + encoder_embeddings += pos_emb + sem_pos_emb - text_embeddings = torch.stack(embs_extractor["embeddings"].tolist()).to(DEVICE) + return encoder_embeddings, encoder_mask - semantic_ids, residuals = rqvae_model({"embeddings": text_embeddings}) + def _create_decoder_tensors(self, decoder_emb_flat, batch_size, sem_id_len): + decoder_embeddings = decoder_emb_flat.view(batch_size, sem_id_len, -1) - return rqvae_model, semantic_ids, residuals, item_ids + # позиционные эмбеддинги (только семантические) + sem_pos_ids = torch.arange(sem_id_len, device=DEVICE).expand(batch_size, -1) + sem_pos_emb = self.sem_id_position_embeddings(sem_pos_ids) + decoder_embeddings += sem_pos_emb - @classmethod - def create_from_config(cls, config, **kwargs): - rqvae_model, semantic_ids, residuals, item_ids = cls.init_rqvae(config) + return decoder_embeddings - solver = CollisionSolver( - emb_dim=residuals.shape[1], - sem_id_len=len(rqvae_model.codebook_sizes), - codebook_size=rqvae_model.codebook_sizes[0], - ) - solver.create_query_candidates_dict( - torch.tensor(item_ids), semantic_ids, residuals - ) + def _add_bos_to_encoder(self, encoder_embeddings, encoder_mask, batch_size): + bos = self.bos_embedding.view(1, 1, -1).expand(batch_size, 1, -1) + new_encoder_embeddings = torch.cat([bos, encoder_embeddings], dim=1) + new_mask = torch.cat([ + torch.ones(batch_size, 1, dtype=torch.bool, device=DEVICE), + encoder_mask + ], dim=1) + return new_encoder_embeddings, new_mask - return cls( - rqvae_model=rqvae_model, - item_id_to_semantic_id=semantic_ids, - item_id_to_residual=residuals, - solver=solver, - sequence_prefix=config["sequence_prefix"], - pred_prefix=config["predictions_prefix"], - positive_prefix=config["positive_prefix"], - labels_prefix=config["labels_prefix"], - num_items=rqvae_model.codebook_sizes[0], # unused - max_sequence_length=kwargs["max_sequence_length"], - embedding_dim=config["embedding_dim"], - num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), - num_encoder_layers=config["num_encoder_layers"], - num_decoder_layers=config["num_decoder_layers"], - dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), - dropout=config.get("dropout", 0.0), - initializer_range=config.get("initializer_range", 0.02), - ) + def _add_bos_to_decoder(self, decoder_embeddings, batch_size): + bos = self.bos_embedding.view(1, 1, -1).expand(batch_size, 1, -1) + return 
torch.cat([bos, decoder_embeddings], dim=1) - # semantic ids come with dedup token def forward(self, inputs): all_sample_events = inputs[ - "{}.ids".format(self._sequence_prefix) + "semantic_{}.ids".format(self._sequence_prefix) ] # (all_batch_events) all_sample_lengths = inputs[ - "{}.length".format(self._sequence_prefix) - ] # (batch_size) - - encoder_embeddings, encoder_mask = self._apply_sequential_encoder( - all_sample_events, all_sample_lengths * (len(self._codebook_sizes) + 1) - ) # (batch_size, enc_seq_len, embedding_dim), (batch_size, enc_seq_len) - - if self.training: - label_events = inputs["{}.ids".format(self._positive_prefix)] - label_lengths = inputs["{}.length".format(self._positive_prefix)] - - tgt_embeddings = self.get_item_embeddings( - label_events - ) # (all_batch_events, embedding_dim) - - decoder_outputs = self._apply_decoder( - tgt_embeddings, - label_lengths * (len(self._codebook_sizes) + 1), - encoder_embeddings, - encoder_mask, - ) # (batch_size, label_len, embedding_dim) - - decoder_prefix_scores = torch.einsum( - "bsd,scd->bsc", - decoder_outputs[:, :-1, :], - self._codebook_item_embeddings_stacked, - ) - - decoder_output_residual = decoder_outputs[:, -1, :] - - semantic_ids = self._item_id_to_semantic_id[ - label_events - 1 - ] # len(events), len(codebook_sizes) - true_residuals = self._item_id_to_residual[label_events - 1] - - true_info = self._solver.get_true_dedup_tokens(semantic_ids, true_residuals) - pred_info = self._solver.get_pred_scores( - semantic_ids, decoder_output_residual - ) - - return { - "logits": decoder_prefix_scores.reshape( - -1, decoder_prefix_scores.shape[2] - ), - "semantic.labels.ids": semantic_ids.reshape(-1), - "dedup.logits": pred_info["pred_scores"], - "dedup.labels.ids": true_info["true_dedup_tokens"], - } - # else: - # semantic_ids, tgt_embeddings = self._apply_decoder_autoregressive( - # encoder_embeddings, encoder_mask - # ) # (batch_size, len(self._codebook_sizes) (bos, residual)), (batch_size, len(self._codebook_sizes) + 2 (bos, residual), embedding_dim) - # TODOPK - # # 1 4 6 -> lookup -> sum = emb (last embedding) # bs, embedding_dim - # # take all embedings (from stacked) # all_items, embedding_dim - # # take from sasrec eval (indices + 1) - # # guarantee that all items are in correct order - - # residuals = tgt_embeddings[:, -1, :] - # semantic_ids = semantic_ids.to(torch.int64) - - # item_ids = self._trie.query(semantic_ids, items_to_query=20) - - # return item_ids - # TODOPK - # uid -> hash (murmurhash32) -> modulo (2000) -> get_embedding -> prepend - # first iteration -> for each user get embedding - - else: # eval mode - semantic_ids, tgt_embeddings = self._apply_decoder_autoregressive( - encoder_embeddings, encoder_mask - ) # (batch_size, len(self._codebook_sizes)), (batch_size, len(self._codebook_sizes) + 2, embedding_dim) - - embs = [] - for semantic_id in semantic_ids: - cur_emb = [] - for idx, codebook_id in enumerate(semantic_id): - cur_emb.append( - self._codebook_item_embeddings_stacked[idx][codebook_id.item()] - ) - embs.append(torch.stack(cur_emb)) - - last_embeddings = torch.stack(embs).sum(dim=1) # batch_size, embedding_dim - - candidate_scores = torch.einsum( - "bd,nd->bn", - last_embeddings, - self._item_id_to_semantic_embedding.sum(dim=1), - ) # (batch_size, num_items) - - _, indices = torch.topk( - candidate_scores, k=20, dim=-1, largest=True - ) # (batch_size, 20) - - return indices + 1 # tensors are 0 indexed - - def _apply_decoder( - self, tgt_embeddings, label_lengths, encoder_embeddings, 
encoder_mask - ): - tgt_embeddings, tgt_mask = create_masked_tensor( - data=tgt_embeddings, lengths=label_lengths - ) # (batch_size, dec_seq_len, embedding_dim), (batch_size, dec_seq_len) - - batch_size = tgt_embeddings.shape[0] - bos_embeddings = self._bos_weight.unsqueeze(0).expand( - batch_size, 1, -1 - ) # (batch_size, 1, embedding_dim) - - tgt_embeddings = torch.cat( - [bos_embeddings, tgt_embeddings[:, :-1, :]], dim=1 - ) # remove residual by using :-1 - - label_len = tgt_mask.shape[1] - - assert label_len == len(self._codebook_sizes) + 1 - - position_embeddings = self._decoder_pos_embeddings(label_lengths, tgt_mask) - assert torch.allclose(position_embeddings[~tgt_mask], tgt_embeddings[~tgt_mask]) - - tgt_embeddings = tgt_embeddings + position_embeddings + "semantic_{}.length".format(self._sequence_prefix) + ] # (batch_size) + assert all_sample_events.shape[0] == sum(all_sample_lengths) + embeddings_flat = self._embed_semantic_tokens(all_sample_events) - # TODOPK remove layernorm & dropout (for inference) - # tgt_embeddings = self._decoder_layernorm( - # tgt_embeddings - # ) # (batch_size, dec_seq_len, embedding_dim) - # tgt_embeddings = self._decoder_dropout( - # tgt_embeddings - # ) # (batch_size, dec_seq_len, embedding_dim) + assert embeddings_flat.shape[0] == sum(all_sample_lengths) - tgt_embeddings[~tgt_mask] = 0 - - causal_mask = ( - torch.tril(torch.ones(label_len, label_len)).bool().to(DEVICE) - ) # (dec_seq_len, dec_seq_len) - - decoder_outputs = self._decoder( - tgt=tgt_embeddings, - memory=encoder_embeddings, - tgt_mask=~causal_mask, - memory_key_padding_mask=~encoder_mask, - ) # (batch_size, dec_seq_len, embedding_dim) - - return decoder_outputs - - def _decoder_pos_embeddings(self, lengths, mask): - def codebook_lambda(x): - non_bos = x < len(self._codebook_sizes) - x[non_bos] = (len(self._codebook_sizes) - 1) - x[non_bos] - return x # 3, 0, 1, 2, 3, 0, 1, 2 ... 
len(self._codebook_sizes) = 3 for bos - - codebook_embeddings = self._get_position_embeddings( - lengths, mask, codebook_lambda, self._codebook_embeddings + (encoder_input_emb, # (batch_size, seq_len - sem_id_len + 1, embedding_dim) + encoder_input_mask, # (batch_size, seq_len - sem_id_len + 1) + decoder_input_embs) = ( # (batch_size, sem_id_len + 1, embedding_dim) + self._prepare_sem_id_batch(embeddings_flat, all_sample_lengths) ) - return codebook_embeddings + after_encoder_emb, after_encoder_mask = self._apply_encoder(encoder_input_emb, encoder_input_mask) - def _apply_decoder_autoregressive(self, encoder_embeddings, encoder_mask): - batch_size = encoder_embeddings.shape[0] - embedding_dim = encoder_embeddings.shape[2] - - tgt_embeddings = ( - self._bos_weight.unsqueeze(0) - .unsqueeze(0) - .expand(batch_size, 1, embedding_dim) - ) - - semantic_ids = torch.tensor([], device=DEVICE, dtype=torch.int64) + if self.training: + # последние sem ids + target_tokens_mask = self._get_last_sem_ids_mask(all_sample_lengths) + target_tokens = all_sample_events[target_tokens_mask].view(-1, self._sem_id_len) # (batch_size, sem_id_len) - for step in range(len(self._codebook_sizes) + 1): # semantic_id_seq + residual - index = len(self._codebook_sizes) if step == 0 else step - 1 + # Подготовка входа декодера (BOS + первые 3 токена) + tgt = decoder_input_embs[:, :-1, :] # (batch_size, sem_id_len, embedding_dim) + tgt_mask = nn.Transformer.generate_square_subsequent_mask( + tgt.size(1), device=DEVICE + ) # (sem_id_len, sem_id_len) - last_position_embedding = self._codebook_embeddings( - torch.full((batch_size,), index, device=DEVICE) - ) + decoder_output = self._decoder( + tgt=tgt, + memory=after_encoder_emb, + tgt_mask=tgt_mask, + memory_key_padding_mask=~after_encoder_mask + # должно быть True для паддинга и False для реальных позиций + ) # (batch_size, sem_id_len, embedding_dim) + + losses = [] # [(batch_size, codebook_size) * sem_id_len] + scores = [] # [(batch_size, codebook_size) * sem_id_len] + argmaxes = [] # [(batch_size, ) * sem_id_len] + + for i in range(self._sem_id_len): + weights = self.codebook_embeddings.weight[i * self._codebook_size: (i + 1) * self._codebook_size] + logits = torch.matmul( + decoder_output[:, i, :], weights.t() + ) # (batch_size, codebook_size) + scores.append(logits) + + pred_tokens = torch.argmax(logits, dim=-1) # (batch_size,) + argmaxes.append(pred_tokens) + + loss = nn.functional.cross_entropy(logits, target_tokens[:, i]) + losses.append(loss) - assert last_position_embedding.shape == tgt_embeddings[:, -1, :].shape - assert tgt_embeddings.shape == torch.Size([batch_size, step + 1, embedding_dim]) + return { + "decoder_loss_1": losses[0], # (1, ) + "decoder_loss_2": losses[1], # (1, ) + "decoder_loss_3": losses[2], # (1, ) + "decoder_loss_4": losses[3], # (1, ) + + "decoder_scores_1": scores[0], # (batch_size, codebook_size) + "decoder_scores_2": scores[1], # (batch_size, codebook_size) + "decoder_scores_3": scores[2], # (batch_size, codebook_size) + "decoder_scores_4": scores[3], # (batch_size, codebook_size) + + "decoder_argmax_1": argmaxes[0], # (batch_size, ) + "decoder_argmax_2": argmaxes[1], # (batch_size, ) + "decoder_argmax_3": argmaxes[2], # (batch_size, ) + "decoder_argmax_4": argmaxes[3], # (batch_size, ) + } + else: + batch_size = encoder_input_emb.size(0) - curr_step_embeddings = tgt_embeddings.clone() - curr_step_embeddings[:, -1, :] = ( - tgt_embeddings[:, -1, :] + last_position_embedding - ) - assert torch.allclose(tgt_embeddings[:, :-1, :], 
curr_step_embeddings[:, :-1, :]) - tgt_embeddings = curr_step_embeddings + tgt = self.bos_embedding.view(1, 1, -1).expand(batch_size, 1, -1) # (batch_size, 1, embedding_dim) - # curr_embeddings[:, -1, :] = self._decoder_layernorm(curr_embeddings[:, -1, :]) - # curr_embeddings[:, -1, :] = self._decoder_dropout(curr_embeddings[:, -1, :]) + memory_key_padding_mask = ~after_encoder_mask - causal_mask = ( - torch.tril(torch.ones(step + 1, step + 1)).bool().to(DEVICE) - ) # (dec_seq_len, dec_seq_len) + argmaxes = [] + scores = [] - decoder_output = self._decoder( - tgt=tgt_embeddings, - memory=encoder_embeddings, - tgt_mask=~causal_mask, - memory_key_padding_mask=~encoder_mask, - ) - - # TODOPK add assert for all except last layer (check if only last layer changes) - # TODOPK check decoder output for several outputs - # TODOPK ASK it is not true? - # assert that prelast items don't change - # assert decoder changes only last index in dim = 1 - - next_token_embedding = decoder_output[ - :, -1, : - ] # batch_size x embedding_dim - - if step < len(self._codebook_sizes): - codebook = self._codebook_item_embeddings_stacked[ - step - ] # codebook_size x embedding_dim - closest_semantic_ids = torch.argmax( - torch.einsum("bd,cd->bc", next_token_embedding, codebook), dim=1 - ) # batch_size - semantic_ids = torch.cat( - [semantic_ids, closest_semantic_ids.unsqueeze(1)], dim=1 - ) # batch_size x (step + 1) - next_token_embedding = codebook[ - closest_semantic_ids - ] # batch_size x embedding_dim - - tgt_embeddings = torch.cat( - [tgt_embeddings, next_token_embedding.unsqueeze(1)], dim=1 - ) - - return semantic_ids, tgt_embeddings - - def get_item_embeddings(self, events): - embs = self._item_id_to_semantic_embedding[ - events - 1 - ] # len(events), len(self._codebook_sizes) + 1, embedding_dim - return embs.reshape(-1, self._embedding_dim) - - def get_init_item_embeddings(self, item_id_to_semantic_id, item_id_to_residual): - result = [] - for semantic_id in item_id_to_semantic_id: - item_repr = [] - for codebook_idx, codebook_id in enumerate(semantic_id): - item_repr.append( - self._codebook_item_embeddings_stacked[codebook_idx][codebook_id] - ) - result.append(torch.stack(item_repr)) - - semantic_embeddings = torch.stack( - result - ) # len(events), len(codebook_sizes), embedding_dim - - residual = item_id_to_residual.unsqueeze(1) - - # get true item embeddings - item_embeddings = torch.cat( - [semantic_embeddings, residual], dim=1 - ) # len(events), len(self._codebook_sizes) + 1, embedding_dim - - return item_embeddings - - def _encoder_pos_embeddings(self, lengths, mask): - def position_lambda(x): - return x // ( - len(self._codebook_sizes) + 1 - ) # 5 5 5 5 4 4 4 4 ..., +1 for residual - - position_embeddings = self._get_position_embeddings( - lengths, mask, position_lambda, self._position_embeddings - ) + for step in range(self._sem_id_len): + tgt_mask = nn.Transformer.generate_square_subsequent_mask( + tgt.size(1), device=DEVICE + ) # (L, L) - def codebook_lambda(x): - x = len(self._codebook_sizes) - x % (len(self._codebook_sizes) + 1) - x[x == len(self._codebook_sizes)] = len(self._codebook_sizes) + 1 - # 0 1 2 4 0 1 2 4 ... 
# len(self._codebook_sizes) + 1 = 4 for residual - return x + decoder_output = self._decoder( + tgt=tgt, + memory=after_encoder_emb, + tgt_mask=tgt_mask, + memory_key_padding_mask=memory_key_padding_mask + ) # (batch_size, L, embedding_dim) - codebook_embeddings = self._get_position_embeddings( - lengths, mask, codebook_lambda, self._codebook_embeddings - ) + last_output = decoder_output[:, -1:, :] # (batch_size, 1, embedding_dim) - return position_embeddings + codebook_embeddings + weights = self.codebook_embeddings.weight[step * self._codebook_size: (step + 1) * self._codebook_size] + logits = torch.matmul( + last_output, + weights.t() + ).squeeze(1) # (batch_size, codebook_size) - def _get_position_embeddings(self, lengths, mask, position_lambda, embedding_layer): - batch_size = mask.shape[0] - seq_len = mask.shape[1] + scores.append(logits) + pred_token = torch.argmax(logits, dim=-1) # (batch_size,) + argmaxes.append(pred_token) - positions = ( - torch.arange(start=seq_len - 1, end=-1, step=-1, device=DEVICE)[None] - .tile([batch_size, 1]) - .long() - ) # (batch_size, seq_len) - positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) + if step < self._sem_id_len - 1: + next_embed = self.codebook_embeddings( + step * self._codebook_size + pred_token) # (batch_size, embedding_dim) - positions = positions[positions_mask] # (all_batch_events) - # 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 7 6 5 4 3 2 1 0 ... + pos_emb = self.sem_id_position_embeddings( + torch.tensor([step], device=DEVICE) + ).expand(batch_size, -1) + next_embed += pos_emb - positions = position_lambda(positions) # (all_batch_events) + next_embed = next_embed.unsqueeze(1) # (batch_size, 1, embedding_dim) + tgt = torch.cat([tgt, next_embed], dim=1) - # print(f"{positions.tolist()[:20]=}") + return { + "decoder_scores_1": scores[0], + "decoder_scores_2": scores[1], + "decoder_scores_3": scores[2], + "decoder_scores_4": scores[3], + + "decoder_argmax_1": argmaxes[0], + "decoder_argmax_2": argmaxes[1], + "decoder_argmax_3": argmaxes[2], + "decoder_argmax_4": argmaxes[3], + } - assert (positions >= 0).all() and ( - positions < embedding_layer.num_embeddings - ).all() + def _apply_encoder( + self, + embeddings, # (batch_size, max_seq_len, embedding_dim) + mask, # (batch_size, max_seq_len) + ): - position_embeddings = embedding_layer( - positions - ) # (all_batch_events, embedding_dim) + assert embeddings.shape[0] == mask.shape[0] + assert embeddings.shape[1] == mask.shape[1] - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, lengths=lengths + embeddings = self._encoder( + src=embeddings, src_key_padding_mask=~mask ) # (batch_size, seq_len, embedding_dim) - return position_embeddings + return embeddings, mask + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + sequence_prefix=config["sequence_prefix"], + embedding_dim=config["embedding_dim"], + codebook_size=config["codebook_size"], + num_positions=config["num_positions"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_encoder_layers=config["num_encoder_layers"], + num_decoder_layers=config["num_decoder_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), + ) diff --git a/modeling/utils/NewTigerTest.py b/modeling/utils/NewTigerTest.py new file mode 100644 index 00000000..563a5141 --- /dev/null +++ b/modeling/utils/NewTigerTest.py @@ -0,0 
+1,195 @@ +import unittest + +import torch +from models.tiger import TigerModel +from utils import DEVICE, create_masked_tensor + + +def create_model(): + return TigerModel( + sequence_prefix="sequence", + embedding_dim=64, + codebook_size=256, + num_positions=200, + num_heads=1, + num_encoder_layers=1, + num_decoder_layers=1, + dim_feedforward=256 + ) + + +class MyTestCase(unittest.TestCase): + def test_get_last_sem_ids_mask(self): + model = create_model() + lengths = torch.tensor([5, 7, 3], device=DEVICE) + + mask = model._get_last_sem_ids_mask(lengths) + total_tokens = lengths.sum().item() + + assert mask.shape == (total_tokens,) + + expected_positions = [ + [1, 2, 3, 4], # Для длины 5 индексы 1-4 + [3, 4, 5, 6], # Для длины 7 индексы 3-6 + [0, 1, 2] # Для длины 3 все токены + ] + flat_expected = torch.tensor([idx for sublist in expected_positions for idx in sublist], device=DEVICE) + cum_lengths = torch.cat([ + torch.tensor([0], device=lengths.device), + lengths.cumsum(0)[:-1] + ]) + offsets = torch.repeat_interleave(cum_lengths, lengths) + indices = torch.arange(total_tokens, device=lengths.device) + flat_actual = indices - offsets + assert torch.sum(mask).item() == len(flat_expected) + assert torch.all(flat_actual[mask] == flat_expected) + + def test_embed_semantic_tokens(self): + model = create_model() + sem_ids = torch.tensor([1, 3, 5, 7, 2, 4, 6, 8], device=DEVICE) + embeddings = model._embed_semantic_tokens(sem_ids) + assert embeddings.shape == (8, model._embedding_dim) + + # разные кодбуки используются для разных позиций + for i in range(8): + full_index = (torch.arange(model._sem_id_len, device=DEVICE)[i % model._sem_id_len] * model._codebook_size + + sem_ids[i]) + expected_embed = model.codebook_embeddings(full_index) + assert torch.allclose(embeddings[i], expected_embed) + + def test_position_embeddings(self): + model = create_model() + mask = torch.BoolTensor([ + [True, True, False], + [True, False, False] + ]) + + pos_emb = model._get_position_embeddings(mask) + + assert pos_emb.shape == (2, 3, model._embedding_dim) + + assert torch.all(pos_emb[0, 2] == 0) + assert torch.all(pos_emb[1, 1:] == 0) + + def test_forward_training(self): + model = create_model() + model.train() + + inputs = { + "sequence.ids": torch.tensor([1, 2, 3, 4, 5, 1, 2, 3, 10, 12, 1, 16], device=DEVICE), + "sequence.length": torch.tensor([8, 4], device=DEVICE), + } + + outputs = model(inputs) + + assert "decoder_loss_1" in outputs + assert "decoder_scores_4" in outputs + assert "decoder_argmax_3" in outputs + + assert outputs["decoder_scores_1"].shape == (2, 256) + assert outputs["decoder_argmax_4"].shape == (2,) + + def test_autoregressive_decoder(self): + model = create_model() + model.eval() + + inputs = { + "sequence.ids": torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=DEVICE), + "sequence.length": torch.tensor([8], device=DEVICE), + } + + outputs = model(inputs) + + assert outputs["decoder_argmax_1"].shape == (1,) + assert outputs["decoder_scores_4"].shape == (1, 256) + + assert outputs["decoder_argmax_1"].item() in range(256) + assert outputs["decoder_argmax_4"].item() in range(256) + + def test_prepare_sem_id(self): + model = create_model() + sem_embs_flat = torch.randn(20, model._embedding_dim) + lengths = torch.tensor([12, 8]) + + encoder_emb, encoder_mask, decoder_emb = model._prepare_sem_id_batch(sem_embs_flat, lengths) + sem_embs, _ = create_masked_tensor(sem_embs_flat, lengths) + + # Проверка что позиционные эмбеддинги ставятся корректно для декодера + + assert torch.all(decoder_emb[:, 0] == 
model.bos_embedding)  # check that the first decoder token is BOS
+        assert torch.all(torch.isclose(decoder_emb[0, 0], model.bos_embedding))
+        assert torch.all(torch.isclose(decoder_emb[0, 1],
+                                       sem_embs[0, lengths[0] - 4] + model.sem_id_position_embeddings(torch.tensor(0))))
+        assert torch.all(torch.isclose(decoder_emb[0, 2],
+                                       sem_embs[0, lengths[0] - 3] + model.sem_id_position_embeddings(torch.tensor(1))))
+        assert torch.all(torch.isclose(decoder_emb[0, 3],
+                                       sem_embs[0, lengths[0] - 2] + model.sem_id_position_embeddings(torch.tensor(2))))
+        assert torch.all(torch.isclose(decoder_emb[0, 4],
+                                       sem_embs[0, lengths[0] - 1] + model.sem_id_position_embeddings(torch.tensor(3))))
+        assert torch.all(torch.isclose(decoder_emb[1, 0], model.bos_embedding))
+        assert torch.all(torch.isclose(decoder_emb[1, 1],
+                                       sem_embs[1, lengths[1] - 4] + model.sem_id_position_embeddings(torch.tensor(0))))
+        assert torch.all(torch.isclose(decoder_emb[1, 2],
+                                       sem_embs[1, lengths[1] - 3] + model.sem_id_position_embeddings(torch.tensor(1))))
+        assert torch.all(torch.isclose(decoder_emb[1, 3],
+                                       sem_embs[1, lengths[1] - 2] + model.sem_id_position_embeddings(torch.tensor(2))))
+        assert torch.all(torch.isclose(decoder_emb[1, 4],
+                                       sem_embs[1, lengths[1] - 1] + model.sem_id_position_embeddings(torch.tensor(3))))
+
+        # Check that positional embeddings are applied correctly for the encoder
+        assert torch.all(encoder_emb[:, 0] == model.bos_embedding)
+        assert torch.all(torch.isclose(encoder_emb[0, 0], model.bos_embedding))
+        assert torch.all(torch.isclose(encoder_emb[0, 1], sem_embs[0, 0] + model.position_embeddings(
+            torch.tensor(0)) + model.sem_id_position_embeddings(torch.tensor(0))))
+        assert torch.all(torch.isclose(encoder_emb[0, 2], sem_embs[0, 1] + model.position_embeddings(
+            torch.tensor(1)) + model.sem_id_position_embeddings(torch.tensor(1))))
+        assert torch.all(torch.isclose(encoder_emb[0, 3], sem_embs[0, 2] + model.position_embeddings(
+            torch.tensor(2)) + model.sem_id_position_embeddings(torch.tensor(2))))
+        assert torch.all(torch.isclose(encoder_emb[0, 4], sem_embs[0, 3] + model.position_embeddings(
+            torch.tensor(3)) + model.sem_id_position_embeddings(torch.tensor(3))))
+        assert torch.all(torch.isclose(encoder_emb[0, 5], sem_embs[0, 4] + model.position_embeddings(
+            torch.tensor(4)) + model.sem_id_position_embeddings(torch.tensor(0))))
+        assert torch.all(torch.isclose(encoder_emb[0, 6], sem_embs[0, 5] + model.position_embeddings(
+            torch.tensor(5)) + model.sem_id_position_embeddings(torch.tensor(1))))
+        assert torch.all(torch.isclose(encoder_emb[0, 7], sem_embs[0, 6] + model.position_embeddings(
+            torch.tensor(6)) + model.sem_id_position_embeddings(torch.tensor(2))))
+        assert torch.all(torch.isclose(encoder_emb[0, 8], sem_embs[0, 7] + model.position_embeddings(
+            torch.tensor(7)) + model.sem_id_position_embeddings(torch.tensor(3))))
+        assert torch.all(torch.isclose(encoder_emb[1, 0], model.bos_embedding))
+        assert torch.all(torch.isclose(encoder_emb[1, 1], sem_embs[1, 0] + model.position_embeddings(
+            torch.tensor(0)) + model.sem_id_position_embeddings(torch.tensor(0))))
+        assert torch.all(torch.isclose(encoder_emb[1, 2], sem_embs[1, 1] + model.position_embeddings(
+            torch.tensor(1)) + model.sem_id_position_embeddings(torch.tensor(1))))
+        assert torch.all(torch.isclose(encoder_emb[1, 3], sem_embs[1, 2] + model.position_embeddings(
+            torch.tensor(2)) + model.sem_id_position_embeddings(torch.tensor(2))))
+        assert torch.all(torch.isclose(encoder_emb[1, 4], sem_embs[1, 3] + model.position_embeddings(
+            torch.tensor(3)) + model.sem_id_position_embeddings(torch.tensor(3))))
+        # Padded encoder positions of the shorter sequence must stay zero
+        assert torch.all(torch.isclose(encoder_emb[1, 5], torch.zeros(64)))
+        assert torch.all(torch.isclose(encoder_emb[1, 6], torch.zeros(64)))
+        assert torch.all(torch.isclose(encoder_emb[1, 7], torch.zeros(64)))
+        assert torch.all(torch.isclose(encoder_emb[1, 8], torch.zeros(64)))
+
+    def test_only_decoder_data(self):
+        # Sequences of exactly sem_id_len tokens: the encoder input should contain
+        # only the BOS embedding, while the decoder input should be BOS followed by
+        # all four semantic-id embeddings with their semantic-position embeddings.
+        batch_size = 64
+        model = create_model()
+        expanded_bos = model.bos_embedding.unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1)
+
+        only_decoder_embs_flat = torch.randn(model._sem_id_len * batch_size, model._embedding_dim)
+        only_decoder_lengths = torch.tensor([model._sem_id_len for _ in range(batch_size)])
+
+        only_decoder_sem_embs, _ = create_masked_tensor(only_decoder_embs_flat, only_decoder_lengths)
+
+        sem_pos_ids = torch.arange(model._sem_id_len, device=DEVICE).expand(batch_size, -1)
+        sem_pos_emb = model.sem_id_position_embeddings(sem_pos_ids)
+        only_decoder_sem_embs += sem_pos_emb
+
+        only_decoder_sem_embs_with_bos = torch.cat([expanded_bos, only_decoder_sem_embs], dim=1)
+        only_decoder_encoder_emb, only_decoder_encoder_mask, only_decoder_decoder_emb = model._prepare_sem_id_batch(
+            only_decoder_embs_flat, only_decoder_lengths)
+
+        assert torch.all(only_decoder_encoder_emb == expanded_bos).item()
+        assert torch.all(only_decoder_decoder_emb == only_decoder_sem_embs_with_bos).item()
+
+
+if __name__ == '__main__':
+    unittest.main()
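
Illustrative sketch (not part of the diff): with the updated configs/train/letter.json, the composite loss only passes through the four scalar decoder losses that TigerModel already computes in training mode. The snippet below mirrors that wiring in plain PyTorch with a hypothetical batch (keys follow the "item" sequence_prefix and NewTigerTest); whether the composite loss sums or otherwise combines its weighted terms is an assumption here.

import torch

from models.tiger import TigerModel
from utils import DEVICE

# Same hyperparameters as the "model" block in configs/train/letter.json.
model = TigerModel(
    sequence_prefix="item",
    embedding_dim=64,
    codebook_size=256,
    num_positions=200,
    num_heads=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=256,
    dropout=0.3,
)
model.train()

# Hypothetical batch: two users with 8 and 4 semantic-id tokens (sem_id_len = 4).
inputs = {
    "semantic_item.ids": torch.tensor([1, 2, 3, 4, 5, 1, 2, 3, 10, 12, 1, 16], device=DEVICE),
    "semantic_item.length": torch.tensor([8, 4], device=DEVICE),
}
outputs = model(inputs)

# The model returns one scalar cross-entropy per codebook position; with the
# identity_map entries at weight 1.0 this is assumed to reduce to their sum,
# reported under the "loss" prefix that best_metric now tracks.
loss = sum(outputs[f"decoder_loss_{i}"] for i in range(1, 5))
loss.backward()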