Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 22 additions & 106 deletions configs/train/letter.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"experiment_name": "letter_data",
"best_metric": "validation/ndcg@20",
"experiment_name": "letter_tiger",
"best_metric": "loss",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

А почему тут лосс?

"train_epochs_num": 100,
"dataset": {
"type": "letter_full",
Expand Down Expand Up @@ -39,21 +39,15 @@
},
"model": {
"type": "tiger",
"rqvae_train_config_path": "../configs/train/rqvae_train_config.json",
"rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth",
"embs_extractor_path": "../data/Beauty/rqvae/data_full.pt",
"sequence_prefix": "item",
"predictions_prefix": "logits",
"positive_prefix": "labels",
"labels_prefix": "labels",
"embedding_dim": 64,
"codebook_size": 256,
"num_positions": 200,
"num_heads": 2,
"num_encoder_layers": 2,
"num_decoder_layers": 2,
"dim_feedforward": 256,
"dropout": 0.3,
"activation": "gelu",
"layer_norm_eps": 1e-9,
"initializer_range": 0.02
},
"optimizer": {
Expand All @@ -68,18 +62,28 @@
"type": "composite",
"losses": [
{
"type": "ce",
"predictions_prefix": "logits",
"labels_prefix": "semantic.labels",
"type": "identity_map",
"predictions_prefix": "decoder_loss_1",
"weight": 1.0,
"output_prefix": "semantic_loss"
"output_prefix": "decoder_loss_1"
},
{
"type": "ce",
"predictions_prefix": "dedup.logits",
"labels_prefix": "dedup.labels",
"type": "identity_map",
"predictions_prefix": "decoder_loss_2",
"weight": 1.0,
"output_prefix": "dedup_loss"
"output_prefix": "decoder_loss_2"
},
{
"type": "identity_map",
"predictions_prefix": "decoder_loss_3",
"weight": 1.0,
"output_prefix": "decoder_loss_3"
},
{
"type": "identity_map",
"predictions_prefix": "decoder_loss_4",
"weight": 1.0,
"output_prefix": "decoder_loss_4"
}
],
"output_prefix": "loss"
Expand All @@ -91,94 +95,6 @@
"type": "metric",
"on_step": 1,
"loss_prefix": "loss"
},
{
"type": "validation",
"on_step": 1024,
"pred_prefix": "logits",
"labels_prefix": "labels",
"metrics": {
"ndcg@5": {
"type": "ndcg",
"k": 5
},
"ndcg@10": {
"type": "ndcg",
"k": 10
},
"ndcg@20": {
"type": "ndcg",
"k": 20
},
"recall@5": {
"type": "recall",
"k": 5
},
"recall@10": {
"type": "recall",
"k": 10
},
"recall@20": {
"type": "recall",
"k": 20
},
"coverage@5": {
"type": "coverage",
"k": 5
},
"coverage@10": {
"type": "coverage",
"k": 10
},
"coverage@20": {
"type": "coverage",
"k": 20
}
}
},
{
"type": "eval",
"on_step": 2048,
"pred_prefix": "logits",
"labels_prefix": "labels",
"metrics": {
"ndcg@5": {
"type": "ndcg",
"k": 5
},
"ndcg@10": {
"type": "ndcg",
"k": 10
},
"ndcg@20": {
"type": "ndcg",
"k": 20
},
"recall@5": {
"type": "recall",
"k": 5
},
"recall@10": {
"type": "recall",
"k": 10
},
"recall@20": {
"type": "recall",
"k": 20
},
"coverage@5": {
"type": "coverage",
"k": 5
},
"coverage@10": {
"type": "coverage",
"k": 10
},
"coverage@20": {
"type": "coverage",
"k": 20
}
}
}
]
}
Expand Down
16 changes: 15 additions & 1 deletion modeling/loss/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,21 @@ def forward(self, inputs):
inputs[self._output_prefix] = loss.cpu().item()

return loss

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Отступ 2 строки


class IdentityMapLoss(TorchLoss, config_name='identity_map'):
    """Pass-through loss that forwards a precomputed scalar loss from the inputs dict.

    Useful when the model computes its own loss terms (e.g. per-decoder losses)
    and the training loop only needs to aggregate and log them.
    """

    def __init__(self, predictions_prefix, output_prefix=None):
        """
        Args:
            predictions_prefix: key in the inputs dict holding the scalar loss tensor.
            output_prefix: optional key under which the detached float value is
                written back into the inputs dict (for logging); skipped if None.
        """
        super().__init__()
        self._input_loss_key = predictions_prefix
        self._output_prefix = output_prefix

    def forward(self, inputs):
        """Fetch and return the scalar loss tensor stored under `predictions_prefix`.

        Raises:
            ValueError: if the fetched tensor is not a 0-dim (scalar) tensor.
        """
        loss = inputs[self._input_loss_key]
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently let non-scalar tensors through.
        if loss.dim() != 0:
            raise ValueError("Loss must be a scalar tensor")
        if self._output_prefix is not None:
            # Store a plain Python float so logging does not keep the graph alive.
            inputs[self._output_prefix] = loss.cpu().item()
        return loss

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Отступ 2 строки

class RqVaeLoss(TorchLoss, config_name='rqvae_loss'):

def __init__(self, beta, output_prefix=None):
Expand Down
7 changes: 7 additions & 0 deletions modeling/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ def _init_weights(self, initializer_range):
a=-2 * initializer_range,
b=2 * initializer_range
)
elif "bos_embedding" in key:
nn.init.trunc_normal_(
value.data,
std=initializer_range,
a=-2 * initializer_range,
b=2 * initializer_range,
)
else:
raise ValueError(f'Unknown transformer weight: {key}')

Expand Down
39 changes: 35 additions & 4 deletions modeling/models/sasrec_semantic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json

import torch
from .tiger import TigerModel
from models import SequentialTorchModel
from torch import nn

from models import SequentialTorchModel, RqVaeModel
from utils import DEVICE, create_masked_tensor
from torch import nn


class SasRecSemanticModel(SequentialTorchModel, config_name="sasrec_semantic"):
Expand Down Expand Up @@ -67,9 +68,39 @@ def __init__(
requires_grad=True,
) # len(self._codebook_sizes), codebook_size, embedding_dim

@classmethod
def init_rqvae(cls, config):
    """Load a frozen pretrained RQ-VAE and compute semantic ids for all items.

    Args:
        config: dict with keys
            - "rqvae_train_config_path": path to the RQ-VAE training config (JSON);
            - "rqvae_checkpoint_path": path to the trained RQ-VAE state dict;
            - "embs_extractor_path": path to a torch-saved frame indexed by item
              id with an "embeddings" column of text-embedding tensors.

    Returns:
        Tuple of (rqvae_model, semantic_ids, residuals, item_ids) where
        rqvae_model is frozen (eval mode, no grads) and item_ids is the
        contiguous 1..N list of item identifiers.
    """
    # Context manager so the config file handle is closed deterministically
    # (the original `json.load(open(...))` leaked it).
    with open(config["rqvae_train_config_path"]) as config_file:
        rqvae_config = json.load(config_file)
    # Codebooks come from the checkpoint, so skip their initialization.
    rqvae_config["model"]["should_init_codebooks"] = False

    rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]).to(DEVICE)
    rqvae_model.load_state_dict(
        torch.load(config["rqvae_checkpoint_path"], weights_only=True)
    )
    # Freeze the RQ-VAE: it is only used as a fixed semantic-id tokenizer here.
    rqvae_model.eval()
    for param in rqvae_model.parameters():
        param.requires_grad = False

    codebook_sizes = rqvae_model.codebook_sizes
    # Downstream code assumes one uniform codebook size across all levels.
    assert all(book_size == codebook_sizes[0] for book_size in codebook_sizes)

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # trusted files here.
    embs_extractor = torch.load(config["embs_extractor_path"], weights_only=False)

    embs_extractor = embs_extractor.sort_index()

    # Item ids must form a contiguous 1..N range so row i maps to item id i+1.
    item_ids = embs_extractor.index.tolist()
    assert item_ids == list(range(1, len(item_ids) + 1))

    text_embeddings = torch.stack(embs_extractor["embeddings"].tolist()).to(DEVICE)

    semantic_ids, residuals = rqvae_model({"embeddings": text_embeddings})

    return rqvae_model, semantic_ids, residuals, item_ids


@classmethod
def create_from_config(cls, config, **kwargs):
rqvae_model, semantic_ids, residuals, _ = TigerModel.init_rqvae(config)
rqvae_model, semantic_ids, residuals, _ = cls.init_rqvae(config)

return cls(
rqvae_model=rqvae_model,
Expand Down
Loading