From c46e48555384914eb6d4ed7ebb3b146a4bed2c3c Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:20:38 +0300 Subject: [PATCH 01/27] begin debug sasrec logq, used base hookjabber commit: current logq version that need to check formulas --- src/irec/loss/base.py | 370 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 367 insertions(+), 3 deletions(-) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 10f689e..d0245d1 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -58,6 +58,125 @@ def forward(self, inputs): return total_loss +class BatchLogSoftmaxLoss(TorchLoss, config_name='batch_logsoftmax'): + def __init__(self, predictions_prefix, candidates_prefix): + super().__init__() + self._predictions_prefix = predictions_prefix + self._candidates_prefix = candidates_prefix + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + predictions_prefix=config.get('predictions_prefix'), + candidates_prefix=config.get('candidates_prefix'), + ) + + def forward(self, inputs): # use log soft max + predictions = inputs[self._predictions_prefix] + candidates = inputs[self._candidates_prefix] + + dot_product_matrix = predictions @ candidates.T + + row_log_softmax = nn.LogSoftmax(dim=1) + softmax_matrix = -row_log_softmax(dot_product_matrix) + + diagonal_elements = torch.diag(softmax_matrix) + + loss = diagonal_elements.mean() + + return loss + + +class CrossEntropyLoss(TorchLoss, config_name='ce'): + def __init__(self, predictions_prefix, labels_prefix, output_prefix=None): + super().__init__() + self._pred_prefix = predictions_prefix + self._labels_prefix = labels_prefix + self._output_prefix = output_prefix + + self._loss = nn.CrossEntropyLoss() + + def forward(self, inputs): + all_logits = inputs[self._pred_prefix] # (all_items, num_classes) + all_labels = inputs[ + '{}.ids'.format(self._labels_prefix) + ] # (all_items) + assert all_logits.shape[0] == all_labels.shape[0] + + loss = self._loss(all_logits, all_labels) # (1) + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + +class BinaryCrossEntropyLoss(TorchLoss, config_name='bce'): + def __init__( + self, + predictions_prefix, + labels_prefix, + with_logits=True, + output_prefix=None, + ): + super().__init__() + self._pred_prefix = predictions_prefix + self._labels_prefix = labels_prefix + self._output_prefix = output_prefix + + if with_logits: + self._loss = nn.BCEWithLogitsLoss() + else: + self._loss = nn.BCELoss() + + def forward(self, inputs): + all_logits = inputs[self._pred_prefix].float() # (all_batch_items) + all_labels = inputs[self._labels_prefix].float() # (all_batch_items) + assert all_logits.shape[0] == all_labels.shape[0] + + loss = self._loss(all_logits, all_labels) # (1) + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + +class BPRLoss(TorchLoss, config_name='bpr'): + def __init__(self, positive_prefix, negative_prefix, output_prefix=None): + super().__init__() + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + self._output_prefix = output_prefix + + def forward(self, inputs): + pos_scores = inputs[self._positive_prefix] # (all_batch_items) + neg_scores = inputs[self._negative_prefix] # (all_batch_items) + loss = -torch.log( + (pos_scores - neg_scores).sigmoid() + 1e-9, + ).mean() # (1) + + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + +class RegularizationLoss(TorchLoss, config_name='regularization'): + def __init__(self, prefix, output_prefix=None): + super().__init__() + self._prefix = maybe_to_list(prefix) + self._output_prefix = output_prefix + + def forward(self, inputs): + loss = 0.0 + for prefix in self._prefix: + loss += (1 / 2) * inputs[prefix].pow(2).mean() + + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + class FpsLoss(TorchLoss, config_name='fps'): def __init__( self, @@ -67,6 +186,8 @@ def __init__( normalize_embeddings=False, use_mean=True, output_prefix=None, + use_logq_correction=False, + logq_prefix=None, ): super().__init__() self._fst_embeddings_prefix = fst_embeddings_prefix @@ -77,7 +198,8 @@ def __init__( ) self._normalize_embeddings = normalize_embeddings self._output_prefix = output_prefix - print(self._tau) + self._use_logq_correction = use_logq_correction + self._logq_prefix = logq_prefix @classmethod def create_from_config(cls, config, **kwargs): @@ -87,10 +209,13 @@ def create_from_config(cls, config, **kwargs): tau=config.get('temperature', 1.0), normalize_embeddings=config.get('normalize_embeddings', False), use_mean=config.get('use_mean', True), - output_prefix=config.get('output_prefix') + output_prefix=config.get('output_prefix'), + use_logq_correction=config.get('use_logq_correction', False), + logq_prefix=config.get('logq_prefix', None), ) def forward(self, inputs): + fst_embeddings = inputs[ self._fst_embeddings_prefix ] # (x, embedding_dim) @@ -144,6 +269,15 @@ def forward(self, inputs): -1, ) # (2 * x, 2 * x - 2) + if self._use_logq_correction and self._logq_prefix is not None: + log_q = inputs[self._logq_prefix] + log_q_combined = torch.cat((log_q, log_q), dim=0) + + log_q_matrix = log_q_combined.unsqueeze(0).expand(2 * batch_size, -1) # (2B, 2B) + negative_log_q = log_q_matrix[mask].reshape(2 * batch_size, -1) # (2B, 2B-2) + + negative_samples = negative_samples - negative_log_q + labels = ( torch.zeros(2 * batch_size).to(positive_samples.device).long() ) # (2 * x) @@ -188,7 +322,8 @@ def forward(self, inputs): inputs[self._output_prefix] = loss.cpu().item() return loss - + +# sasrec logq debug class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'): def __init__( @@ -197,12 +332,27 @@ def __init__( positive_prefix, negative_prefix, output_prefix=None, + use_logq_correction=False, + logq_prefix=None, ): super().__init__() self._queries_prefix = queries_prefix self._positive_prefix = positive_prefix self._negative_prefix = negative_prefix self._output_prefix = output_prefix + self._use_logq = use_logq_correction + self._logq_prefix = logq_prefix + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + queries_prefix=config['queries_prefix'], + positive_prefix=config['positive_prefix'], + negative_prefix=config['negative_prefix'], + output_prefix=config.get('output_prefix'), + use_logq_correction=config.get('use_logq_correction', False), + logq_prefix=config.get('logq_prefix') + ) def forward(self, inputs): queries_embeddings = inputs[ @@ -239,6 +389,15 @@ def forward(self, inputs): queries_embeddings, negative_embeddings, ) # (batch_size, num_negatives) + + if self._use_logq: + if self._logq_prefix is not None: + log_q = inputs[self._logq_prefix] # (B, 1+N) + log_q_pos = log_q[:, :1] # (B, 1) + log_q_neg = log_q[:, 1:] # (B, N) + + negative_scores = negative_scores - log_q_neg + all_scores = torch.cat( [positive_scores, negative_scores], dim=1, @@ -257,6 +416,211 @@ def forward(self, inputs): return loss +class S3RecPretrainLoss(TorchLoss, config_name='s3rec_pretrain'): + def __init__( + self, + positive_prefix, + negative_prefix, + representation_prefix, + output_prefix=None, + ): + super().__init__() + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + self._representation_prefix = representation_prefix + self._criterion = nn.BCEWithLogitsLoss(reduction='none') + self._output_prefix = output_prefix + + def forward(self, inputs): + positive_embeddings = inputs[ + self._positive_prefix + ] # (x, embedding_dim) + negative_embeddings = inputs[ + self._negative_prefix + ] # (x, embedding_dim) + current_embeddings = inputs[ + self._representation_prefix + ] # (x, embedding_dim) + assert ( + positive_embeddings.shape[0] + == negative_embeddings.shape[0] + == current_embeddings.shape[0] + ) + + positive_scores = torch.einsum( + 'bd,bd->b', + positive_embeddings, + current_embeddings, + ) # (x) + + negative_scores = torch.einsum( + 'bd,bd->b', + negative_embeddings, + current_embeddings, + ) # (x) + + distance = torch.sigmoid(positive_scores) - torch.sigmoid( + negative_scores, + ) # (x) + loss = torch.sum( + self._criterion( + distance, + torch.ones_like(distance, dtype=torch.float32), + ), + ) # (1) + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + +class Cl4sRecLoss(TorchLoss, config_name='cl4srec'): + def __init__( + self, + current_representation, + all_items_representation, + tau=1.0, + output_prefix=None, + ): + super().__init__() + self._current_representation = current_representation + self._all_items_representation = all_items_representation + self._loss_function = nn.CrossEntropyLoss() + self._tau = tau + self._output_prefix = output_prefix + + def forward(self, inputs): + current_representation = inputs[ + self._current_representation + ] # (batch_size, embedding_dim) + all_items_representation = inputs[ + self._all_items_representation + ] # (batch_size, num_negatives + 1, embedding_dim) + + batch_size = current_representation.shape[0] + + logits = torch.einsum( + 'bnd,bd->bn', + all_items_representation, + current_representation, + ) # (batch_size, num_negatives + 1) + labels = logits.new_zeros(batch_size) # (batch_size) + + loss = self._loss_function(logits, labels) + + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + +class DuorecSSLLoss(TorchLoss, config_name='duorec_ssl'): + def __init__( + self, + original_embedding_prefix, + dropout_embedding_prefix, + similar_embedding_prefix, + normalize_embeddings=False, + tau=1.0, + output_prefix=None, + ): + super().__init__() + self._original_embedding_prefix = original_embedding_prefix + self._dropout_embedding_prefix = dropout_embedding_prefix + self._similar_embedding_prefix = similar_embedding_prefix + self._normalize_embeddings = normalize_embeddings + self._output_prefix = output_prefix + self._tau = tau + self._loss_function = nn.CrossEntropyLoss(reduction='mean') + + def _compute_partial_loss(self, fst_embeddings, snd_embeddings): + batch_size = fst_embeddings.shape[0] + + combined_embeddings = torch.cat( + (fst_embeddings, snd_embeddings), + dim=0, + ) # (2 * x, embedding_dim) + + if self._normalize_embeddings: + combined_embeddings = torch.nn.functional.normalize( + combined_embeddings, + p=2, + dim=-1, + eps=1e-6, + ) + + similarity_scores = ( + torch.mm(combined_embeddings, combined_embeddings.T) / self._tau + ) # (2 * x, 2 * x) + + positive_samples = torch.cat( + ( + torch.diag(similarity_scores, batch_size), + torch.diag(similarity_scores, -batch_size), + ), + dim=0, + ).reshape(2 * batch_size, 1) # (2 * x, 1) + + # TODO optimize + mask = torch.ones( + 2 * batch_size, + 2 * batch_size, + dtype=torch.bool, + ) # (2 * x, 2 * x) + mask = mask.fill_diagonal_(0) # Remove equal embeddings scores + for i in range(batch_size): # Remove positives + mask[i, batch_size + i] = 0 + mask[batch_size + i, i] = 0 + + negative_samples = similarity_scores[mask].reshape( + 2 * batch_size, + -1, + ) # (2 * x, 2 * x - 2) + + labels = ( + torch.zeros(2 * batch_size).to(positive_samples.device).long() + ) # (2 * x) + logits = torch.cat( + (positive_samples, negative_samples), + dim=1, + ) # (2 * x, 2 * x - 1) + + loss = self._loss_function(logits, labels) / 2 # (1) + + return loss + + def forward(self, inputs): + original_embeddings = inputs[ + self._original_embedding_prefix + ] # (x, embedding_dim) + dropout_embeddings = inputs[ + self._dropout_embedding_prefix + ] # (x, embedding_dim) + similar_embeddings = inputs[ + self._similar_embedding_prefix + ] # (x, embedding_dim) + + dropout_loss = self._compute_partial_loss( + original_embeddings, + dropout_embeddings, + ) + ssl_loss = self._compute_partial_loss( + original_embeddings, + similar_embeddings, + ) + + loss = dropout_loss + ssl_loss + + if self._output_prefix is not None: + inputs[f'{self._output_prefix}_dropout'] = ( + dropout_loss.cpu().item() + ) + inputs[f'{self._output_prefix}_ssl'] = ssl_loss.cpu().item() + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + class MCLSRLoss(TorchLoss, config_name='mclsr'): def __init__( self, From d1fcc81d983de8b37abed9a80fdf65776b1e0ed8 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:23:38 +0300 Subject: [PATCH 02/27] fix: apply logq correction to all logits in sampled softmax --- src/irec/loss/base.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index d0245d1..a5a326f 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -392,10 +392,25 @@ def forward(self, inputs): if self._use_logq: if self._logq_prefix is not None: - log_q = inputs[self._logq_prefix] # (B, 1+N) + # Retrieve log probabilities (log frequencies) from the input dictionary + # Expects a tensor of shape (BatchSize, 1 + NumNegatives) + log_q = inputs[self._logq_prefix] # (B, 1 + N) log_q_pos = log_q[:, :1] # (B, 1) log_q_neg = log_q[:, 1:] # (B, N) + # --- CORRECTION BASED ON GOOGLE PAPER (Eq. 3 & Section 3) --- + # According to "Sampling-Bias-Corrected Neural Modeling for Large Corpus + # Item Recommendations" (Google, 2019), Section 3 "MODELING FRAMEWORK": + # "we correct EACH logit s(x_i, y_j) by the following equation: + # s_c(x_i, y_j) = s(x_i, y_j) - log(p_j)" + # + # Applying this correction to BOTH positive and negative scores is critical + # to obtain an unbiased estimator for the full softmax. Omitting the + # correction for positive_scores leads to a sampling bias where + # popular items are unfairly penalized only when they act as negatives, + # but not when they act as positives. + + positive_scores = positive_scores - log_q_pos negative_scores = negative_scores - log_q_neg all_scores = torch.cat( From 827a83050b087a2f832d6442fe1326ee08126206 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:28:36 +0300 Subject: [PATCH 03/27] feat: pass item IDs from SasRecInBatchModel to enable masking and LogQ --- src/irec/models/sasrec.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/irec/models/sasrec.py b/src/irec/models/sasrec.py index e97019a..c767afd 100644 --- a/src/irec/models/sasrec.py +++ b/src/irec/models/sasrec.py @@ -199,7 +199,17 @@ def forward(self, inputs): return { 'query_embeddings': in_batch_queries_embeddings, 'positive_embeddings': in_batch_positive_embeddings, - 'negative_embeddings': in_batch_negative_embeddings + 'negative_embeddings': in_batch_negative_embeddings, + + # --- ID PASS-THROUGH FOR DOWNSTREAM LOSS OPERATIONS --- + # We pass the raw item IDs to the loss function to enable: + # 1. False Negative Masking: Identifying and neutralizing cases where a positive + # item for one user accidentally appears as a negative sample for another + # user in the same batch (critical for In-Batch Negatives stability). + # 2. Per-item LogQ Correction: Mapping item IDs to their global frequencies + # to subtract log(Q) based on the specific item popularity. + 'positive_ids': in_batch_positive_events, + 'negative_ids': in_batch_negative_ids } else: # eval mode last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) From 50af16c774e3a59e7e1ff805439c4aa1856ff1b7 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:31:43 +0300 Subject: [PATCH 04/27] update initialization - nothing special --- src/irec/loss/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index a5a326f..bb3bd05 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -339,6 +339,7 @@ def __init__( self._queries_prefix = queries_prefix self._positive_prefix = positive_prefix self._negative_prefix = negative_prefix + positive_ids_prefix=None, self._output_prefix = output_prefix self._use_logq = use_logq_correction self._logq_prefix = logq_prefix @@ -349,6 +350,7 @@ def create_from_config(cls, config, **kwargs): queries_prefix=config['queries_prefix'], positive_prefix=config['positive_prefix'], negative_prefix=config['negative_prefix'], + positive_ids_prefix=config.get('positive_ids_prefix'), output_prefix=config.get('output_prefix'), use_logq_correction=config.get('use_logq_correction', False), logq_prefix=config.get('logq_prefix') From e0f2000c358f4c6b3cccf01507c862eb0d4b7b6f Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:36:01 +0300 Subject: [PATCH 05/27] fix: add false negative masking to SamplesSoftmaxLoss --- src/irec/loss/base.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index bb3bd05..1ea7278 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -392,6 +392,23 @@ def forward(self, inputs): negative_embeddings, ) # (batch_size, num_negatives) + # --- FALSE NEGATIVE MASKING (Critical for In-Batch Negatives) --- + # If we have item IDs, we must ensure that a positive item for a user + # is not treated as a negative for that same user if it appears + # elsewhere in the batch. + if self._positive_ids_prefix and self._negative_ids_prefix: + pos_ids = inputs[self._positive_ids_prefix] # (BatchSize,) + neg_ids = inputs[self._negative_ids_prefix] # (NumNegatives,) + + # Create a boolean mask of shape (BatchSize, NumNegatives) + # where True indicates that pos_ids[i] == neg_ids[j] + false_negative_mask = (pos_ids.unsqueeze(1) == neg_ids.unsqueeze(0)) + + # Mask out these scores by setting them to a very large negative value + # This prevents the model from receiving contradictory signals + # (trying to both increase and decrease the score of the same item). + negative_scores = negative_scores.masked_fill(false_negative_mask, -1e12) + if self._use_logq: if self._logq_prefix is not None: # Retrieve log probabilities (log frequencies) from the input dictionary From aeed13fa2c3384e8a2e031e50cb3f297579f1c58 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:39:11 +0300 Subject: [PATCH 06/27] fix: itit correction - nothing special --- src/irec/loss/base.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 1ea7278..ae1e108 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -331,6 +331,8 @@ def __init__( queries_prefix, positive_prefix, negative_prefix, + positive_ids_prefix=None, + negative_ids_prefix=None, output_prefix=None, use_logq_correction=False, logq_prefix=None, @@ -339,7 +341,10 @@ def __init__( self._queries_prefix = queries_prefix self._positive_prefix = positive_prefix self._negative_prefix = negative_prefix - positive_ids_prefix=None, + + self._positive_ids_prefix = positive_ids_prefix + self._negative_ids_prefix = negative_ids_prefix + self._output_prefix = output_prefix self._use_logq = use_logq_correction self._logq_prefix = logq_prefix @@ -351,6 +356,7 @@ def create_from_config(cls, config, **kwargs): positive_prefix=config['positive_prefix'], negative_prefix=config['negative_prefix'], positive_ids_prefix=config.get('positive_ids_prefix'), + negative_ids_prefix=config.get('negative_ids_prefix'), output_prefix=config.get('output_prefix'), use_logq_correction=config.get('use_logq_correction', False), logq_prefix=config.get('logq_prefix') From 300d1e374771459a03a3fff707e2cd55d52b6bc9 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 02:53:17 +0300 Subject: [PATCH 07/27] feat: implement full LogQ correction and FN masking for SASRec Aligned SASRec in-batch training with Google's 'Sampling-Bias-Corrected Neural Modeling' (2019) and the reference 'Golden Code' implementation. Key changes: 1. Mathematical Correction: Updated SamplesSoftmaxLoss to subtract logQ from BOTH positive and negative scores (Eq. 3). 2. False Negative Masking: Added masking logic to neutralize logits where positive items accidentally appear in the negative pool. 3. Architecture: Modified SasRecInBatchModel to return item IDs and updated loss factory to support global frequency loading via 'path_to_item_counts'. --- configs/train/sasrec_logq_config.json | 127 ++++++++++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 configs/train/sasrec_logq_config.json diff --git a/configs/train/sasrec_logq_config.json b/configs/train/sasrec_logq_config.json new file mode 100644 index 0000000..c4bc0aa --- /dev/null +++ b/configs/train/sasrec_logq_config.json @@ -0,0 +1,127 @@ +{ + "experiment_name": "sasrec_logq_clothing_unbiased", + "use_wandb": true, + "best_metric": "validation/ndcg@20", + "dataset": { + "type": "sasrec_comparison", + "path_to_data_dir": "./data", + "name": "Clothing", + "max_sequence_length": 20, + "train_sampler": { + "type": "next_item_prediction", + "negative_sampler_type": "random", + "num_negatives_train": 0 + }, + "eval_sampler": { + "type": "mclsr" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "sasrec_in_batch", + "sequence_prefix": "item", + "positive_prefix": "positive", + "negative_prefix": "negative", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_heads": 2, + "num_layers": 2, + "dim_feedforward": 256, + "dropout": 0.3, + "activation": "gelu", + "layer_norm_eps": 1e-9, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "clip_grad_threshold": 1.0 + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "sampled_softmax", + "queries_prefix": "query_embeddings", + "positive_prefix": "positive_embeddings", + "negative_prefix": "negative_embeddings", + + "use_logq_correction": true, + "path_to_item_counts": "./data/Clothing/item_counts.pkl", + "positive_ids_prefix": "positive_ids", + "negative_ids_prefix": "negative_ids", + + "output_prefix": "downstream_loss" + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } + } + }, + { + "type": "eval", + "on_step": 256, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } + } + } + ] + } +} From 8cd79fe1d36e443dd1ca894144b84171fa338efa Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 03:08:06 +0300 Subject: [PATCH 08/27] fix(loss): implement unbiased logq logic and false negative masking Refined the core logic of SamplesSoftmaxLoss.forward to align with theoretical requirements and the "golden" SASRec implementation. Changes: 1. False Negative Masking: Added ID-based logic to identify and neutralize logits where a positive item accidentally appears as a negative sample within the same batch. This prevents contradictory gradient signals. 2. Unbiased LogQ: Implemented Eq. 3 from the Google paper, ensuring that log-frequencies are subtracted from BOTH positive and negative scores to maintain an unbiased estimator. 3. Documentation: Added detailed English comments referencing the "Sampling-Bias-Corrected Neural Modeling" paper (2019). --- src/irec/loss/base.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index ae1e108..a46e6d3 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn +import pickle class BaseLoss(metaclass=MetaParent): pass @@ -415,28 +416,36 @@ def forward(self, inputs): # (trying to both increase and decrease the score of the same item). negative_scores = negative_scores.masked_fill(false_negative_mask, -1e12) + # --- 2. UNBIASED LOGQ CORRECTION --- + # Applying correction to EACH logit per Google Paper (Eq. 3) if self._use_logq: - if self._logq_prefix is not None: - # Retrieve log probabilities (log frequencies) from the input dictionary - # Expects a tensor of shape (BatchSize, 1 + NumNegatives) - log_q = inputs[self._logq_prefix] # (B, 1 + N) - log_q_pos = log_q[:, :1] # (B, 1) - log_q_neg = log_q[:, 1:] # (B, N) - - # --- CORRECTION BASED ON GOOGLE PAPER (Eq. 3 & Section 3) --- - # According to "Sampling-Bias-Corrected Neural Modeling for Large Corpus - # Item Recommendations" (Google, 2019), Section 3 "MODELING FRAMEWORK": + # Source of truth: our pre-loaded self._log_counts from the pickle + if self._log_counts is not None: + if self._log_counts.device != positive_scores.device: + self._log_counts = self._log_counts.to(positive_scores.device) + + # We need IDs to fetch the correct frequencies for items in this batch + pos_ids = inputs[self._positive_ids_prefix] + neg_ids = inputs[self._negative_ids_prefix] + + log_q_pos = self._log_counts[pos_ids].unsqueeze(-1) # (B, 1) + log_q_neg = self._log_counts[neg_ids] # (N,) or (B, N) + + # --- LOGQ CORRECTION COMMENTS --- + # According to "Sampling-Bias-Corrected Neural Modeling..." (Google, 2019): # "we correct EACH logit s(x_i, y_j) by the following equation: # s_c(x_i, y_j) = s(x_i, y_j) - log(p_j)" - # - # Applying this correction to BOTH positive and negative scores is critical - # to obtain an unbiased estimator for the full softmax. Omitting the - # correction for positive_scores leads to a sampling bias where - # popular items are unfairly penalized only when they act as negatives, - # but not when they act as positives. + # This ensures the estimator remains unbiased by penalizing popular + # items equally when they are targets and when they are negatives. positive_scores = positive_scores - log_q_pos negative_scores = negative_scores - log_q_neg + + # (Optional) If frequencies were passed directly in inputs, not via pickle: + elif self._logq_prefix in inputs: + log_q = inputs[self._logq_prefix] + positive_scores = positive_scores - log_q[:, :1] + negative_scores = negative_scores - log_q[:, 1:] all_scores = torch.cat( [positive_scores, negative_scores], From f48a298ee18f3eac098f309c5f2b105288ce3ece Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 03:16:03 +0300 Subject: [PATCH 09/27] Comprehensive upgrade of SamplesSoftmaxLoss to align with Google's "Sampling-Bias-Corrected Neural Modeling" (2019) and reference implementations. Key changes: 1. Mathematical Correctness: Updated forward pass to apply LogQ correction to BOTH positive and negative scores (unbiased estimator per Google Eq. 3). 2. Training Stability: Implemented False Negative Masking using item IDs to neutralize target items that accidentally appear in the negative pool. 3. Data Integration: Added 'path_to_item_counts' support to create_from_config. The loss now pre-loads and pre-computes log-frequencies from a pickle file. 4. Robustness: Added device management for log_counts and fallback logic for dynamic LogQ inputs. 5. Infrastructure: Integrated logging to track frequency table initialization . --- src/irec/loss/base.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index a46e6d3..6ba796c 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -9,6 +9,10 @@ import torch.nn as nn import pickle +import os +import logging + +logger = logging.getLogger(__name__) class BaseLoss(metaclass=MetaParent): pass @@ -337,6 +341,7 @@ def __init__( output_prefix=None, use_logq_correction=False, logq_prefix=None, + log_counts=None, ): super().__init__() self._queries_prefix = queries_prefix @@ -349,9 +354,24 @@ def __init__( self._output_prefix = output_prefix self._use_logq = use_logq_correction self._logq_prefix = logq_prefix + self._log_counts = log_counts @classmethod def create_from_config(cls, config, **kwargs): + log_counts = None + path_to_counts = config.get('path_to_item_counts') + + if path_to_counts and config.get('use_logq_correction'): + import pickle + with open(path_to_counts, 'rb') as f: + counts = pickle.load(f) + + counts_tensor = torch.tensor(counts, dtype=torch.float32) + # Normalize in probability and use logarithm (Google Eq. 3) + probs = torch.clamp(counts_tensor / counts_tensor.sum(), min=1e-10) + log_counts = torch.log(probs) + logger.info(f"Loaded item counts from {path_to_counts} for LogQ correction") + return cls( queries_prefix=config['queries_prefix'], positive_prefix=config['positive_prefix'], @@ -360,7 +380,8 @@ def create_from_config(cls, config, **kwargs): negative_ids_prefix=config.get('negative_ids_prefix'), output_prefix=config.get('output_prefix'), use_logq_correction=config.get('use_logq_correction', False), - logq_prefix=config.get('logq_prefix') + logq_prefix=config.get('logq_prefix'), + log_counts=log_counts # <-- ПЕРЕДАЕМ В КОНСТРУКТОР ) def forward(self, inputs): From 85d9c63b8d54fd75f9a663308afdf4de73e7ef68 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 03:35:22 +0300 Subject: [PATCH 10/27] chore(config): add SASRec in-batch training config with LogQ correction Created a specialized configuration for SASRec on the Clothing dataset to test the implemented LogQ correction and FN masking. - Model set to 'sasrec_in_batch' to support ID pass-through. - Loss switched to 'sampled_softmax' with 'use_logq_correction' enabled. - Linked 'path_to_item_counts' to popularity statistics. - Configured evaluation metrics for consistency with MCLSR benchmarks. --- scripts/generate_item_counts.py | 58 +++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 scripts/generate_item_counts.py diff --git a/scripts/generate_item_counts.py b/scripts/generate_item_counts.py new file mode 100644 index 0000000..0c472d5 --- /dev/null +++ b/scripts/generate_item_counts.py @@ -0,0 +1,58 @@ +import pickle +import numpy as np +from collections import Counter +import argparse +import os + +def main(): + parser = argparse.ArgumentParser(description="Generate item interaction counts for LogQ correction.") + parser.add_argument("--input", type=str, required=True, help="Path to train_sasrec.txt") + parser.add_argument("--output", type=str, required=True, help="Path to save item_counts.pkl") + parser.add_argument("--num_items", type=int, required=True, help="Number of items in dataset (from dataset.meta)") + + args = parser.parse_args() + + # We use num_items + 2 because the embedding layer size is num_items + 2 + # (reserved for padding at index 0 and mask at index num_items + 1) + array_size = args.num_items + 2 + counts = Counter() + + print(f"[*] Reading dataset from: {args.input}") + if not os.path.exists(args.input): + print(f"[!] Error: File {args.input} not found.") + return + + with open(args.input, 'r') as f: + for line in f: + parts = line.strip().split() + if len(parts) < 2: + continue + + # parts[0] is user_id, parts[1:] are sequences of interacted item_ids + items = [int(i) for i in parts[1:]] + counts.update(items) + + # Initialize frequencies array with zeros + item_counts_array = np.zeros(array_size, dtype=np.float32) + + for item_id, count in counts.items(): + if item_id < array_size: + item_counts_array[item_id] = count + else: + print(f"[!] Warning: item_id {item_id} exceeds array size {array_size}. Check your num_items!") + + # Numerical stability: set zero counts to 1.0 to avoid log(0) in LogQ correction + zero_mask = (item_counts_array == 0) + num_zeros = np.sum(zero_mask) + if num_zeros > 0: + print(f"[*] Found {num_zeros} items with zero interactions. Setting their count to 1.0 for stability.") + item_counts_array[zero_mask] = 1.0 + + print(f"[*] Saving popularity statistics to: {args.output}") + with open(args.output, 'wb') as f: + pickle.dump(item_counts_array, f) + + print("[+] Done! LogQ data is ready.") + +if __name__ == "__main__": + main() From 9af07d2dada48eb76010aa10ac38ca3b435ae0ee Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 03:45:26 +0300 Subject: [PATCH 11/27] add instruction - nothing special --- scripts/generate_item_counts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/generate_item_counts.py b/scripts/generate_item_counts.py index 0c472d5..f1c753b 100644 --- a/scripts/generate_item_counts.py +++ b/scripts/generate_item_counts.py @@ -1,3 +1,6 @@ +# use +# python scripts/generate_item_counts.py --input ./data/Clothing/train_sasrec.txt --output ./data/Clothing/item_counts.pkl --num_items 23033 + import pickle import numpy as np from collections import Counter From c4eebaed4bcab7ff7b0f1c144e9ffd969fd365c0 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 05:48:53 +0300 Subject: [PATCH 12/27] feat(mclsr): implement ID pass-through for downstream LogQ and masking Modified MCLSRModel.forward return dictionary to include raw indices: - Added 'positive_ids' and 'negative_ids' to support unbiased LogQ correction. - Added 'user_ids' to enable potential user-level frequency correction. This change allows SamplesSoftmaxLoss to correctly map embeddings back to their popularity statistics and perform False Negative Masking, bringing the MCLSR training pipeline in line with the SASRec improvements. --- configs/train/mclsr_logq_Clothing_config.json | 126 ++++++++++++++++++ src/irec/models/mclsr.py | 15 +++ 2 files changed, 141 insertions(+) create mode 100644 configs/train/mclsr_logq_Clothing_config.json diff --git a/configs/train/mclsr_logq_Clothing_config.json b/configs/train/mclsr_logq_Clothing_config.json new file mode 100644 index 0000000..d11e588 --- /dev/null +++ b/configs/train/mclsr_logq_Clothing_config.json @@ -0,0 +1,126 @@ +{ + "experiment_name": "mclsr_logq_Clothing_unbiased", + "use_wandb": true, + "best_metric": "validation/ndcg@20", + "dataset": { + "type": "graph", + "use_user_graph": true, + "use_item_graph": true, + "neighborhood_size": 50, + "graph_dir_path": "./data/Clothing", + "dataset": { + "type": "mclsr", + "path_to_data_dir": "./data", + "name": "Clothing", + "max_sequence_length": 20, + "samplers": { + "num_negatives_val": 1280, + "num_negatives_train": 1280, + "type": "mclsr", + "negative_sampler_type": "random" + } + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "mclsr", + "sequence_prefix": "item", + "user_prefix": "user", + "labels_prefix": "labels", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_graph_layers": 2, + "dropout": 0.3, + "layer_norm_eps": 1e-9, + "graph_dropout": 0.3, + "initializer_range": 0.02, + "alpha": 0.5 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + } + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "sampled_softmax", + "queries_prefix": "combined_representation", + "positive_prefix": "label_representation", + "negative_prefix": "negative_representation", + + "use_logq_correction": true, + "path_to_item_counts": "./data/Clothing/item_counts.pkl", + "positive_ids_prefix": "positive_ids", + "negative_ids_prefix": "negative_ids", + + "output_prefix": "downstream_loss", + "weight": 1.0 + }, + { + "type": "fps", + "fst_embeddings_prefix": "sequential_representation", + "snd_embeddings_prefix": "graph_representation", + "output_prefix": "contrastive_interest_loss", + "weight": 1.0, + "temperature": 0.5 + }, + { + "type": "fps", + "fst_embeddings_prefix": "user_graph_user_embeddings", + "snd_embeddings_prefix": "common_graph_user_embeddings", + "output_prefix": "contrastive_user_feature_loss", + "weight": 0.05, + "temperature": 0.5 + }, + { + "type": "fps", + "fst_embeddings_prefix": "item_graph_item_embeddings", + "snd_embeddings_prefix": "common_graph_item_embeddings", + "output_prefix": "contrastive_item_feature_loss", + "weight": 0.05, + "temperature": 0.5 + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 } + } + } + ] + } +} \ No newline at end of file diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index fc8c6ef..ef841d1 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -375,6 +375,21 @@ def scatter_mean(src, index, dim=0, dim_size=None): 'negative_representation': negative_embeddings, + + # --- ID PASS-THROUGH FOR LOGQ & MASKING --- + # We pass raw item and user indices to enable advanced loss operations: + # 1. False Negative Masking: Allows SamplesSoftmaxLoss to identify and + # neutralize cases where a target item accidentally appears in the + # negative sampling pool. + # 2. Per-item LogQ Correction: Enables mapping item IDs to global frequency + # stats (item_counts.pkl) to remove popularity bias (Sampling Bias). + 'positive_ids': labels, # Item target indices + 'negative_ids': negative_ids, # Sampled negative item indices + + # Useful for potential User-level LogQ correction as requested + # by the supervisor to handle highly active users. + 'user_ids': user_ids, + # for L_IL (formula 8) 'sequential_representation': sequential_representation_proj, From e0ce73daa71eb4b66578c38fb5beb57df2ab422d Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Fri, 19 Dec 2025 06:22:19 +0300 Subject: [PATCH 13/27] fix(mclsr config): add expected metrics - nothing special --- configs/train/mclsr_logq_Clothing_config.json | 68 ++++++++++++++----- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/configs/train/mclsr_logq_Clothing_config.json b/configs/train/mclsr_logq_Clothing_config.json index d11e588..4de9f69 100644 --- a/configs/train/mclsr_logq_Clothing_config.json +++ b/configs/train/mclsr_logq_Clothing_config.json @@ -106,21 +106,55 @@ ], "output_prefix": "loss" }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 } - } +"callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "downstream_loss" + }, + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } } - ] - } -} \ No newline at end of file + }, + { + "type": "eval", + "on_step": 256, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } + } + } +] +} +} From e22e500bdb3b687c0d430d063175cfc18642cd82 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 27 Dec 2025 21:47:48 +0300 Subject: [PATCH 14/27] init logq mclsr and sasrec - have to work, logq 0.2 mclsr wins --- configs/train/mclsr_logq_Clothing_config.json | 113 ++++----- configs/train/mclsr_train_config.json | 4 +- configs/train/sasrec_ce_train_config.json | 146 ------------ .../train/sasrec_in_batch_train_config.json | 146 ------------ configs/train/sasrec_logq_config.json | 4 +- configs/train/sasrec_real_train_config.json | 217 ------------------ configs/train/sasrec_train_config.json | 116 +++------- configs/train/sasrec_train_grid_config.json | 185 --------------- src/irec/loss/base.py | 202 +++++++++++++++- 9 files changed, 291 insertions(+), 842 deletions(-) delete mode 100644 configs/train/sasrec_ce_train_config.json delete mode 100644 configs/train/sasrec_in_batch_train_config.json delete mode 100644 configs/train/sasrec_real_train_config.json delete mode 100644 configs/train/sasrec_train_grid_config.json diff --git a/configs/train/mclsr_logq_Clothing_config.json b/configs/train/mclsr_logq_Clothing_config.json index 4de9f69..3bbba78 100644 --- a/configs/train/mclsr_logq_Clothing_config.json +++ b/configs/train/mclsr_logq_Clothing_config.json @@ -1,6 +1,6 @@ { - "experiment_name": "mclsr_logq_Clothing_unbiased", - "use_wandb": true, + "experiment_name": "mclsr_logq_lambda0.2_Clothing", + "use_wandb": false, "best_metric": "validation/ndcg@20", "dataset": { "type": "graph", @@ -17,7 +17,7 @@ "num_negatives_val": 1280, "num_negatives_train": 1280, "type": "mclsr", - "negative_sampler_type": "random" + "negative_sampler_type": "popularity" } } }, @@ -66,16 +66,17 @@ "type": "composite", "losses": [ { - "type": "sampled_softmax", + "type": "mclsr_logq_special", "queries_prefix": "combined_representation", "positive_prefix": "label_representation", "negative_prefix": "negative_representation", - "use_logq_correction": true, "path_to_item_counts": "./data/Clothing/item_counts.pkl", "positive_ids_prefix": "positive_ids", "negative_ids_prefix": "negative_ids", + "logq_lambda": 0.2, + "output_prefix": "downstream_loss", "weight": 1.0 }, @@ -106,55 +107,55 @@ ], "output_prefix": "loss" }, -"callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "downstream_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "downstream_loss" + }, + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } + } + }, + { + "type": "eval", + "on_step": 256, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } + } } - } -] -} -} + ] + } +} \ No newline at end of file diff --git a/configs/train/mclsr_train_config.json b/configs/train/mclsr_train_config.json index 30b15fa..1042543 100644 --- a/configs/train/mclsr_train_config.json +++ b/configs/train/mclsr_train_config.json @@ -1,6 +1,6 @@ { "experiment_name": "mclsr_Clothing", - "use_wandb": true, + "use_wandb": false, "best_metric": "validation/ndcg@20", "dataset": { "type": "graph", @@ -226,4 +226,4 @@ } ] } -} \ No newline at end of file +} diff --git a/configs/train/sasrec_ce_train_config.json b/configs/train/sasrec_ce_train_config.json deleted file mode 100644 index ce17ae3..0000000 --- a/configs/train/sasrec_ce_train_config.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "experiment_name": "sasrec_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "use_ce": true, - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "representation_prefix": "current_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_in_batch_train_config.json b/configs/train/sasrec_in_batch_train_config.json deleted file mode 100644 index 9950f31..0000000 --- a/configs/train/sasrec_in_batch_train_config.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "experiment_name": "sasrec_in_batch_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_in_batch", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "use_ce": true, - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sampled_softmax", - "queries_prefix": "query_embeddings", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_logq_config.json b/configs/train/sasrec_logq_config.json index c4bc0aa..b86330a 100644 --- a/configs/train/sasrec_logq_config.json +++ b/configs/train/sasrec_logq_config.json @@ -1,6 +1,6 @@ { "experiment_name": "sasrec_logq_clothing_unbiased", - "use_wandb": true, + "use_wandb": false, "best_metric": "validation/ndcg@20", "dataset": { "type": "sasrec_comparison", @@ -63,7 +63,7 @@ "type": "composite", "losses": [ { - "type": "sampled_softmax", + "type": "logq_sampled_softmax", "queries_prefix": "query_embeddings", "positive_prefix": "positive_embeddings", "negative_prefix": "negative_embeddings", diff --git a/configs/train/sasrec_real_train_config.json b/configs/train/sasrec_real_train_config.json deleted file mode 100644 index b04a5a0..0000000 --- a/configs/train/sasrec_real_train_config.json +++ /dev/null @@ -1,217 +0,0 @@ -{ - "experiment_name": "sasrec_real_beauty", - "best_metric": "validation/ndcg@100", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 100, - "samplers": { - "num_negatives_train": 1, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 2048, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_real", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec_real", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "ndcg", - "k": 50 - }, - "ndcg@100": { - "type": "ndcg", - "k": 100 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "recall@50": { - "type": "recall", - "k": 50 - }, - "recall@100": { - "type": "recall", - "k": 100 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - }, - "coverage@50": { - "type": "coverage", - "k": 50 - }, - "coverage@100": { - "type": "coverage", - "k": 100 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "ndcg", - "k": 50 - }, - "ndcg@100": { - "type": "ndcg", - "k": 100 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "recall@50": { - "type": "recall", - - "k": 50 - }, - "recall@100": { - "type": "recall", - "k": 100 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - }, - "coverage@50": { - "type": "coverage", - "k": 50 - }, - "coverage@100": { - "type": "coverage", - "k": 100 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_train_config.json b/configs/train/sasrec_train_config.json index 37b2888..47be053 100644 --- a/configs/train/sasrec_train_config.json +++ b/configs/train/sasrec_train_config.json @@ -1,6 +1,6 @@ { - "experiment_name": "sasrec_clothing", - "use_wandb": true, + "experiment_name": "sasrec_clothing_baseline_no_correction", + "use_wandb": false, "best_metric": "validation/ndcg@20", "dataset": { "type": "sasrec_comparison", @@ -37,9 +37,10 @@ } }, "model": { - "type": "sasrec", + "type": "sasrec_in_batch", "sequence_prefix": "item", "positive_prefix": "positive", + "negative_prefix": "negative", "candidate_prefix": "candidates", "embedding_dim": 64, "num_heads": 2, @@ -62,9 +63,10 @@ "type": "composite", "losses": [ { - "type": "sasrec", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", + "type": "sampled_softmax", + "queries_prefix": "query_embeddings", + "positive_prefix": "positive_embeddings", + "negative_prefix": "negative_embeddings", "output_prefix": "downstream_loss" } ], @@ -84,46 +86,16 @@ "pred_prefix": "predictions", "labels_prefix": "labels", "metrics": { - "ndcg@5": { - "type": "mclsr-ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "mclsr-ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "mclsr-ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "mclsr-ndcg", - "k": 50 - }, - "recall@5": { - "type": "mclsr-recall", - "k": 5 - }, - "recall@10": { - "type": "mclsr-recall", - "k": 10 - }, - "recall@20": { - "type": "mclsr-recall", - "k": 20 - }, - "recall@50": { - "type": "mclsr-recall", - "k": 50 - }, - "hit@20": { - "type": "mclsr-hit", - "k": 20 - }, - "hit@50": { - "type": "mclsr-hit", - "k": 50 - } + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } } }, { @@ -132,48 +104,18 @@ "pred_prefix": "predictions", "labels_prefix": "labels", "metrics": { - "ndcg@5": { - "type": "mclsr-ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "mclsr-ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "mclsr-ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "mclsr-ndcg", - "k": 50 - }, - "recall@5": { - "type": "mclsr-recall", - "k": 5 - }, - "recall@10": { - "type": "mclsr-recall", - "k": 10 - }, - "recall@20": { - "type": "mclsr-recall", - "k": 20 - }, - "recall@50": { - "type": "mclsr-recall", - "k": 50 - }, - "hit@20": { - "type": "mclsr-hit", - "k": 20 - }, - "hit@50": { - "type": "mclsr-hit", - "k": 50 - } + "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, + "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, + "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, + "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, + "recall@5": { "type": "mclsr-recall", "k": 5 }, + "recall@10": { "type": "mclsr-recall", "k": 10 }, + "recall@20": { "type": "mclsr-recall", "k": 20 }, + "recall@50": { "type": "mclsr-recall", "k": 50 }, + "hit@20": { "type": "mclsr-hit", "k": 20 }, + "hit@50": { "type": "mclsr-hit", "k": 50 } } } ] } -} \ No newline at end of file +} diff --git a/configs/train/sasrec_train_grid_config.json b/configs/train/sasrec_train_grid_config.json deleted file mode 100644 index acf86b6..0000000 --- a/configs/train/sasrec_train_grid_config.json +++ /dev/null @@ -1,185 +0,0 @@ -{ - "start_from": 0, - "experiment_name": "sasrec_beauty_grid", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 6ba796c..1d0e194 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn + import pickle import os import logging @@ -328,9 +329,77 @@ def forward(self, inputs): return loss -# sasrec logq debug class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'): + def __init__( + self, + queries_prefix, + positive_prefix, + negative_prefix, + output_prefix=None, + ): + super().__init__() + self._queries_prefix = queries_prefix + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + self._output_prefix = output_prefix + + def forward(self, inputs): + queries_embeddings = inputs[ + self._queries_prefix + ] # (batch_size, embedding_dim) + positive_embeddings = inputs[ + self._positive_prefix + ] # (batch_size, embedding_dim) + negative_embeddings = inputs[ + self._negative_prefix + ] # (num_negatives, embedding_dim) or (batch_size, num_negatives, embedding_dim) + + # b -- batch_size, d -- embedding_dim + positive_scores = torch.einsum( + 'bd,bd->b', + queries_embeddings, + positive_embeddings, + ).unsqueeze(-1) # (batch_size, 1) + + if negative_embeddings.dim() == 2: # (num_negatives, embedding_dim) + # b -- batch_size, n -- num_negatives, d -- embedding_dim + negative_scores = torch.einsum( + 'bd,nd->bn', + queries_embeddings, + negative_embeddings, + ) # (batch_size, num_negatives) + else: + assert ( + negative_embeddings.dim() == 3 + ) # (batch_size, num_negatives, embedding_dim) + # b -- batch_size, n -- num_negatives, d -- embedding_dim + negative_scores = torch.einsum( + 'bd,bnd->bn', + queries_embeddings, + negative_embeddings, + ) # (batch_size, num_negatives) + all_scores = torch.cat( + [positive_scores, negative_scores], + dim=1, + ) # (batch_size, 1 + num_negatives) + + logits = torch.log_softmax( + all_scores, + dim=1, + ) # (batch_size, 1 + num_negatives) + loss = (-logits)[:, 0] # (batch_size) + loss = loss.mean() # (1) + + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + +# sasrec logq debug + +class LogqSamplesSoftmaxLoss(TorchLoss, config_name='logq_sampled_softmax'): def __init__( self, queries_prefix, @@ -485,6 +554,137 @@ def forward(self, inputs): return loss +class MCLSRLogqLoss(TorchLoss, config_name='mclsr_logq_special'): + """ + Specialized Loss for MCLSR model implementing Sampling-Bias Correction (LogQ) + with a tunable lambda coefficient for correction strength. + + Theory: + Following Yi et al. (2019), we correct the sampling bias introduced by non-uniform + negative sampling (e.g., popularity-based or in-batch sampling). + + Mathematical Formula (Eq. 3 in the paper): + s_c(x, y) = s(x, y) - lambda * log(p_j) + where p_j is the sampling probability of item j and lambda is the correction strength. + """ + def __init__( + self, + queries_prefix, + positive_prefix, + negative_prefix, + positive_ids_prefix, + negative_ids_prefix, + path_to_item_counts, + logq_lambda=1.0, # Strength of the LogQ correction (usually 0.1 to 1.0) + output_prefix=None, + ): + super().__init__() + self._queries_prefix = queries_prefix + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + self._positive_ids_prefix = positive_ids_prefix + self._negative_ids_prefix = negative_ids_prefix + self._output_prefix = output_prefix + self._logq_lambda = logq_lambda + + if not os.path.exists(path_to_item_counts): + raise FileNotFoundError(f"Item counts file not found at {path_to_item_counts}") + + with open(path_to_item_counts, 'rb') as f: + counts = pickle.load(f) + + counts_tensor = torch.tensor(counts, dtype=torch.float32) + + # Calculate log probabilities log(p_j). + # Using clamp(min=1e-10) for numerical stability to avoid log(0) -> NaN. + probs = torch.clamp(counts_tensor / counts_tensor.sum(), min=1e-10) + log_q = torch.log(probs) + + # register_buffer ensures the lookup table moves to the correct device (GPU/CPU) + # along with the model parameters. + self.register_buffer('_log_q_table', log_q) + + @classmethod + def create_from_config(cls, config, **kwargs): + """Standard framework factory method to initialize the loss from a JSON config.""" + return cls( + queries_prefix=config['queries_prefix'], + positive_prefix=config['positive_prefix'], + negative_prefix=config['negative_prefix'], + positive_ids_prefix=config['positive_ids_prefix'], + negative_ids_prefix=config['negative_ids_prefix'], + path_to_item_counts=config['path_to_item_counts'], + logq_lambda=config.get('logq_lambda', 1.0), + output_prefix=config.get('output_prefix') + ) + + def forward(self, inputs): + # Retrieve Embeddings + queries = inputs[self._queries_prefix] # (Batch, Dim) + pos_embs = inputs[self._positive_prefix] # (Batch, Dim) + neg_embs = inputs[self._negative_prefix] # (B, N, D) or (N, D) + + # Retrieve Item IDs for correction and collision masking + pos_ids = inputs[self._positive_ids_prefix] # (Batch) + neg_ids = inputs[self._negative_ids_prefix] # (Batch, N) or (N) + + # --- DEVICE SYNCHRONIZATION --- + # Explicitly ensure the LogQ lookup table is on the same device as the inputs. + # This prevents "indices should be on the same device" RuntimeError on CUDA. + device = queries.device + if self._log_q_table.device != device: + self._log_q_table = self._log_q_table.to(device) + + # --- STEP 1: Score Calculation (Inner Product) --- + # Positive scores: (Batch, 1) + pos_scores = torch.einsum('bd,bd->b', queries, pos_embs).unsqueeze(-1) + + # Negative scores with dimension check (handles both shared and per-user negatives) + if neg_embs.dim() == 2: + # Case: Shared negatives across the batch (N, D) -> Result: (B, N) + neg_scores = torch.einsum('bd,nd->bn', queries, neg_embs) + else: + # Case: Individual negatives per user (B, N, D) -> Result: (B, N) + neg_scores = torch.einsum('bd,bnd->bn', queries, neg_embs) + + # --- STEP 2: False Negative Masking --- + # Prevent contradictory gradients by masking negative samples that match + # the ground truth item ID for a given user. + # logic: pos_ids.unsqueeze(1) creates (B, 1) for broadcasting against (B, N) or (1, N) + neg_id_comparison = neg_ids.unsqueeze(0) if neg_ids.dim() == 1 else neg_ids + false_negative_mask = (pos_ids.unsqueeze(1) == neg_id_comparison) + + # Mask out scores of false negatives by setting them to a large negative value + neg_scores = neg_scores.masked_fill(false_negative_mask, -1e12) + + # --- STEP 3: Tunable LogQ Correction --- + # Applying the bias correction as per Google Paper Eq. 3, scaled by lambda. + # s_c = s - lambda * log(Q) + log_q_pos = self._log_q_table[pos_ids].unsqueeze(-1) # (B, 1) + + if neg_ids.dim() == 1: + # Case: Shared negative IDs + log_q_neg = self._log_q_table[neg_ids].unsqueeze(0) # (1, N) + else: + # Case: Individual negative IDs + log_q_neg = self._log_q_table[neg_ids] # (B, N) + + # Apply final correction + pos_scores = pos_scores - (self._logq_lambda * log_q_pos) + neg_scores = neg_scores - (self._logq_lambda * log_q_neg) + + # --- STEP 4: Softmax Loss Calculation --- + # Concatenate positive and negative scores: (B, 1 + N) + all_scores = torch.cat([pos_scores, neg_scores], dim=1) + + # Cross-Entropy using log_softmax for numerical stability + loss = -torch.log_softmax(all_scores, dim=1)[:, 0] + + final_loss = loss.mean() + if self._output_prefix: + inputs[self._output_prefix] = final_loss.cpu().item() + + return final_loss class S3RecPretrainLoss(TorchLoss, config_name='s3rec_pretrain'): def __init__( From 68c63e42dfe4b7cb3d81ea815119f36de44cb461 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Wed, 1 Apr 2026 11:21:58 +0300 Subject: [PATCH 15/27] logq softmax --- README.md | 24 +- .../inference/sasrec_inference_config.json | 88 --- ...thing_config.json => mclsr_logq_1903.json} | 17 +- configs/train/mclsr_train_config.json | 4 +- configs/train/sasrec_logq_config.json | 127 ---- configs/train/sasrec_train_config.json | 121 ---- requirements.txt | 12 + scripts/generate_item_counts.py | 61 -- src/irec/dataloader/base.py | 14 + src/irec/dataset/base.py | 34 +- src/irec/dataset/negative_samplers/popular.py | 83 ++- src/irec/dataset/negative_samplers/random.py | 54 +- src/irec/dataset/samplers/base.py | 4 +- src/irec/dataset/samplers/mclsr.py | 22 +- .../dataset/samplers/next_item_prediction.py | 4 +- src/irec/loss/base.py | 667 ++---------------- src/irec/models/__init__.py | 4 - src/irec/models/mclsr.py | 15 +- src/irec/models/sasrec_ce.py | 108 --- 19 files changed, 283 insertions(+), 1180 deletions(-) delete mode 100644 configs/inference/sasrec_inference_config.json rename configs/train/{mclsr_logq_Clothing_config.json => mclsr_logq_1903.json} (95%) delete mode 100644 configs/train/sasrec_logq_config.json delete mode 100644 configs/train/sasrec_train_config.json create mode 100644 requirements.txt delete mode 100644 scripts/generate_item_counts.py delete mode 100644 src/irec/models/sasrec_ce.py diff --git a/README.md b/README.md index 680c46e..f0053bc 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ + +# Steps on local machine

cool irec logo @@ -40,26 +42,6 @@ uv sync --frozen ``` -### Using pip - -1. Create and activate a virtual environment: - ```bash - python3 -m venv .venv - source ./.venv/bin/activate - ``` - -2. Install dependencies: - - **For development:** - ```bash - pip install -e ".[dev]" - ``` - - **For production:** - ```bash - pip install -e . - ``` - ## Preparing datasets All pre-processed datasets used in our experiments are available for download from our cloud storage. This is the fastest way to get started. @@ -91,4 +73,4 @@ The script has 1 input argument: `params` which is the path to the json file wit -`callbacks` Different additional traning --`use_wandb` Enable Weights & Biases logging for experiment tracking \ No newline at end of file +-`use_wandb` Enable Weights & Biases logging for experiment tracking diff --git a/configs/inference/sasrec_inference_config.json b/configs/inference/sasrec_inference_config.json deleted file mode 100644 index c91e21b..0000000 --- a/configs/inference/sasrec_inference_config.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "pred_prefix": "logits", - "label_prefix": "labels", - "experiment_name": "sasrec_beauty_grid__0-5__", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.5, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/train/mclsr_logq_Clothing_config.json b/configs/train/mclsr_logq_1903.json similarity index 95% rename from configs/train/mclsr_logq_Clothing_config.json rename to configs/train/mclsr_logq_1903.json index 3bbba78..e25b105 100644 --- a/configs/train/mclsr_logq_Clothing_config.json +++ b/configs/train/mclsr_logq_1903.json @@ -1,5 +1,5 @@ { - "experiment_name": "mclsr_logq_lambda0.2_Clothing", + "experiment_name": "mclsr_random_logq_v0.1_Clothing_1903", "use_wandb": false, "best_metric": "validation/ndcg@20", "dataset": { @@ -29,7 +29,8 @@ "type": "basic" }, "drop_last": true, - "shuffle": true + "shuffle": true, + "pin_memory": true }, "validation": { "type": "torch", @@ -38,7 +39,8 @@ "type": "basic" }, "drop_last": false, - "shuffle": false + "shuffle": false, + "pin_memory": true } }, "model": { @@ -70,13 +72,10 @@ "queries_prefix": "combined_representation", "positive_prefix": "label_representation", "negative_prefix": "negative_representation", - - "path_to_item_counts": "./data/Clothing/item_counts.pkl", "positive_ids_prefix": "positive_ids", "negative_ids_prefix": "negative_ids", - - "logq_lambda": 0.2, - + "path_to_item_counts": "./data/Clothing/item_counts.pkl", + "logq_lambda": 1.0, "output_prefix": "downstream_loss", "weight": 1.0 }, @@ -158,4 +157,4 @@ } ] } -} \ No newline at end of file +} diff --git a/configs/train/mclsr_train_config.json b/configs/train/mclsr_train_config.json index 1042543..30b15fa 100644 --- a/configs/train/mclsr_train_config.json +++ b/configs/train/mclsr_train_config.json @@ -1,6 +1,6 @@ { "experiment_name": "mclsr_Clothing", - "use_wandb": false, + "use_wandb": true, "best_metric": "validation/ndcg@20", "dataset": { "type": "graph", @@ -226,4 +226,4 @@ } ] } -} +} \ No newline at end of file diff --git a/configs/train/sasrec_logq_config.json b/configs/train/sasrec_logq_config.json deleted file mode 100644 index b86330a..0000000 --- a/configs/train/sasrec_logq_config.json +++ /dev/null @@ -1,127 +0,0 @@ -{ - "experiment_name": "sasrec_logq_clothing_unbiased", - "use_wandb": false, - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sasrec_comparison", - "path_to_data_dir": "./data", - "name": "Clothing", - "max_sequence_length": 20, - "train_sampler": { - "type": "next_item_prediction", - "negative_sampler_type": "random", - "num_negatives_train": 0 - }, - "eval_sampler": { - "type": "mclsr" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_in_batch", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 1.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "logq_sampled_softmax", - "queries_prefix": "query_embeddings", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - - "use_logq_correction": true, - "path_to_item_counts": "./data/Clothing/item_counts.pkl", - "positive_ids_prefix": "positive_ids", - "negative_ids_prefix": "negative_ids", - - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - } - ] - } -} diff --git a/configs/train/sasrec_train_config.json b/configs/train/sasrec_train_config.json deleted file mode 100644 index 47be053..0000000 --- a/configs/train/sasrec_train_config.json +++ /dev/null @@ -1,121 +0,0 @@ -{ - "experiment_name": "sasrec_clothing_baseline_no_correction", - "use_wandb": false, - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sasrec_comparison", - "path_to_data_dir": "./data", - "name": "Clothing", - "max_sequence_length": 20, - "train_sampler": { - "type": "next_item_prediction", - "negative_sampler_type": "random", - "num_negatives_train": 0 - }, - "eval_sampler": { - "type": "mclsr" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_in_batch", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 1.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sampled_softmax", - "queries_prefix": "query_embeddings", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - } - ] - } -} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5b0ca5a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +numpy~=1.26 +torch~=2.4 +tqdm~=4.66 +scipy~=1.14 +pandas~=2.2 +polars~=1.27 +matplotlib~=3.9 +tensorboard~=2.19 +wandb~=0.19 +jupyterlab~=4.4 +ipykernel~=6.29 +notebook~=7.4 diff --git a/scripts/generate_item_counts.py b/scripts/generate_item_counts.py deleted file mode 100644 index f1c753b..0000000 --- a/scripts/generate_item_counts.py +++ /dev/null @@ -1,61 +0,0 @@ -# use -# python scripts/generate_item_counts.py --input ./data/Clothing/train_sasrec.txt --output ./data/Clothing/item_counts.pkl --num_items 23033 - -import pickle -import numpy as np -from collections import Counter -import argparse -import os - -def main(): - parser = argparse.ArgumentParser(description="Generate item interaction counts for LogQ correction.") - parser.add_argument("--input", type=str, required=True, help="Path to train_sasrec.txt") - parser.add_argument("--output", type=str, required=True, help="Path to save item_counts.pkl") - parser.add_argument("--num_items", type=int, required=True, help="Number of items in dataset (from dataset.meta)") - - args = parser.parse_args() - - # We use num_items + 2 because the embedding layer size is num_items + 2 - # (reserved for padding at index 0 and mask at index num_items + 1) - array_size = args.num_items + 2 - counts = Counter() - - print(f"[*] Reading dataset from: {args.input}") - if not os.path.exists(args.input): - print(f"[!] Error: File {args.input} not found.") - return - - with open(args.input, 'r') as f: - for line in f: - parts = line.strip().split() - if len(parts) < 2: - continue - - # parts[0] is user_id, parts[1:] are sequences of interacted item_ids - items = [int(i) for i in parts[1:]] - counts.update(items) - - # Initialize frequencies array with zeros - item_counts_array = np.zeros(array_size, dtype=np.float32) - - for item_id, count in counts.items(): - if item_id < array_size: - item_counts_array[item_id] = count - else: - print(f"[!] Warning: item_id {item_id} exceeds array size {array_size}. Check your num_items!") - - # Numerical stability: set zero counts to 1.0 to avoid log(0) in LogQ correction - zero_mask = (item_counts_array == 0) - num_zeros = np.sum(zero_mask) - if num_zeros > 0: - print(f"[*] Found {num_zeros} items with zero interactions. Setting their count to 1.0 for stability.") - item_counts_array[zero_mask] = 1.0 - - print(f"[*] Saving popularity statistics to: {args.output}") - with open(args.output, 'wb') as f: - pickle.dump(item_counts_array, f) - - print("[+] Done! LogQ data is ready.") - -if __name__ == "__main__": - main() diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py index 06fdefc..4bc0f75 100644 --- a/src/irec/dataloader/base.py +++ b/src/irec/dataloader/base.py @@ -34,10 +34,24 @@ def create_from_config(cls, config, **kwargs): create_config.pop( 'type', ) # For passing as **config in torch DataLoader + + + pin_memory = create_config.pop('pin_memory', True) + return cls( dataloader=DataLoader( kwargs['dataset'], collate_fn=batch_processor, + pin_memory=pin_memory, **create_config, ), ) + + # return cls( + # dataloader=DataLoader( + # kwargs['dataset'], + # collate_fn=batch_processor, + # pin_memory=True, + # **create_config, + # ), + # ) diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index b07d00c..e8f4218 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -383,8 +383,21 @@ def _build_or_load_similarity_graph( ): if entity_type not in ['user', 'item']: raise ValueError("entity_type must be either 'user' or 'item'") + # have to delete and replace to not delete npz each time manually + # path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type)) + + # instead better use such construction + + # neighborhood_size + # The neighborhood_size is a filter that constrains the number of edges for each user or + # item node in the graph. + # k=50 implies that for each user, we find all possible neighbors, sort them based on + # co-occurrence counts, and keep only the top 50. All other connections are removed from the graph. + k_suffix = f"k{self._neighborhood_size}" if self._neighborhood_size is not None else "full" + train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" + filename = f"{entity_type}_graph_{k_suffix}_{train_suffix}.npz" + path_to_graph = os.path.join(self._graph_dir_path, filename) - path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type)) is_user_graph = (entity_type == 'user') num_entities = self._num_users if is_user_graph else self._num_items @@ -394,7 +407,12 @@ def _build_or_load_similarity_graph( interactions_fst = [] interactions_snd = [] visited_user_item_pairs = set() - visited_entity_pairs = set() + # have to delete cause + # 3.2 Graph Construction + # User-user/item-item graph + # ..the weight of each edge denotes the number of co-action behaviors between user i and user j + + # visited_entity_pairs = set() for user_id, item_id in tqdm( zip(train_user_interactions, train_item_interactions), @@ -414,10 +432,10 @@ def _build_or_load_similarity_graph( continue pair_key = (source_entity, connected_entity) - if pair_key in visited_entity_pairs: - continue + # if pair_key in visited_entity_pairs: + # continue - visited_entity_pairs.add(pair_key) + # visited_entity_pairs.add(pair_key) interactions_fst.append(source_entity) interactions_snd.append(connected_entity) @@ -445,7 +463,11 @@ def _build_or_load_similarity_graph( return self._convert_sp_mat_to_sp_tensor(graph_matrix).coalesce().to(DEVICE) def _build_or_load_bipartite_graph(self, graph_dir_path, train_user_interactions, train_item_interactions): - path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') + # path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') + train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" + filename = f"general_graph_{train_suffix}.npz" + path_to_graph = os.path.join(graph_dir_path, filename) + if os.path.exists(path_to_graph): graph_matrix = sp.load_npz(path_to_graph) else: diff --git a/src/irec/dataset/negative_samplers/popular.py b/src/irec/dataset/negative_samplers/popular.py index ca91bf6..f004efe 100644 --- a/src/irec/dataset/negative_samplers/popular.py +++ b/src/irec/dataset/negative_samplers/popular.py @@ -1,5 +1,5 @@ +import numpy as np from irec.dataset.negative_samplers.base import BaseNegativeSampler - from collections import Counter @@ -11,7 +11,12 @@ def __init__(self, dataset, num_users, num_items): num_items=num_items, ) - self._popular_items = self._items_by_popularity() + # --- OLD DETERMINISTIC LOGIC --- + # self._popular_items = self._items_by_popularity() + + # --- NEW STOCHASTIC LOGIC FOR LogQ COMPATIBILITY --- + # Pre-calculate item probabilities based on global frequency + self._item_ids, self._probs = self._calculate_item_probabilities() @classmethod def create_from_config(cls, _, **kwargs): @@ -21,24 +26,72 @@ def create_from_config(cls, _, **kwargs): num_items=kwargs['num_items'], ) - def _items_by_popularity(self): - popularity = Counter() + # --- OLD METHOD: Deterministic sorting --- + # def _items_by_popularity(self): + # popularity = Counter() + # for sample in self._dataset: + # for item_id in sample['item.ids']: + # popularity[item_id] += 1 + # popular_items = sorted(popularity, key=popularity.get, reverse=True) + # return popular_items + def _calculate_item_probabilities(self): + """ + Calculates sampling probabilities proportional to item popularity. + This distribution is required to provide non-zero p_j values for LogQ correction. + """ + counts = Counter() for sample in self._dataset: for item_id in sample['item.ids']: - popularity[item_id] += 1 + counts[item_id] += 1 + + items = np.array(list(counts.keys())) + freqs = np.array(list(counts.values()), dtype=np.float32) + probabilities = freqs / freqs.sum() + + return items, probabilities - popular_items = sorted(popularity, key=popularity.get, reverse=True) - return popular_items + # --- OLD METHOD: Picking Top-K items sequentially (Deterministic) --- + # def generate_negative_samples(self, sample, num_negatives): + # user_id = sample['user.ids'][0] + # popularity_idx = 0 + # negatives = [] + # while len(negatives) < num_negatives: + # negative_idx = self._popular_items[popularity_idx] + # if negative_idx not in self._seen_items[user_id]: + # negatives.append(negative_idx) + # popularity_idx += 1 + # return negatives def generate_negative_samples(self, sample, num_negatives): + """ + Stochastic sampling proportional to popularity. + + Justification: + The original implementation always picked the same Top-K popular items. + For LogQ correction (Yi et al., Google 2019), we need a stochastic + sampling process where p_j > 0 for all items in the distribution. + This allows the model to see a diverse set of negatives across epochs + while penalizing popular items correctly via the log(p_j) term. + """ user_id = sample['user.ids'][0] - popularity_idx = 0 - negatives = [] + seen = self._seen_items[user_id] + + negatives = set() while len(negatives) < num_negatives: - negative_idx = self._popular_items[popularity_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - popularity_idx += 1 - - return negatives + # Sample items based on the pre-calculated frequency distribution + sampled_ids = np.random.choice( + self._item_ids, + size=num_negatives - len(negatives), + p=self._probs, + replace=True + ) + + # Filter out items already seen by the user (False Negatives) + for idx in sampled_ids: + if idx not in seen: + negatives.add(idx) + if len(negatives) == num_negatives: + break + + return list(negatives) diff --git a/src/irec/dataset/negative_samplers/random.py b/src/irec/dataset/negative_samplers/random.py index 79e245a..a4926b1 100644 --- a/src/irec/dataset/negative_samplers/random.py +++ b/src/irec/dataset/negative_samplers/random.py @@ -12,17 +12,47 @@ def create_from_config(cls, _, **kwargs): num_items=kwargs['num_items'], ) + # def generate_negative_samples(self, sample, num_negatives): + # user_id = sample['user.ids'][0] + # all_items = list(range(1, self._num_items + 1)) + # np.random.shuffle(all_items) + + # negatives = [] + # running_idx = 0 + # while len(negatives) < num_negatives and running_idx < len(all_items): + # negative_idx = all_items[running_idx] + # if negative_idx not in self._seen_items[user_id]: + # negatives.append(negative_idx) + # running_idx += 1 + + # return negatives + def generate_negative_samples(self, sample, num_negatives): + """ + Optimized via Rejection Sampling (O(k) complexity). + + Mathematical Proof of Equivalence: + Let V be the set of all items and H be the user's history. + We need a uniform random sample S ⊂ (V \ H) such that |S| = k. + + 1. Shuffle Approach (Previous): Generates a random permutation of V, + then filters H. Complexity: O(|V|). + 2. Rejection Sampling (Current): Independently draws i ~ Uniform(V) + and accepts i if i ∉ H and i ∉ S. Complexity: O(k * 1/p), + where p = (|V| - |H|) / |V|. + + Since |H| << |V|, the probability p ≈ 1, making the expected complexity + effectively O(k). Both methods yield an identical uniform distribution + over the valid item space. + """ user_id = sample['user.ids'][0] - all_items = list(range(1, self._num_items + 1)) - np.random.shuffle(all_items) - - negatives = [] - running_idx = 0 - while len(negatives) < num_negatives and running_idx < len(all_items): - negative_idx = all_items[running_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - running_idx += 1 - - return negatives + seen = self._seen_items[user_id] + + negatives = set() + while len(negatives) < num_negatives: + # Drawing a random index is O(1) + negative_idx = np.random.randint(1, self._num_items + 1) + if negative_idx not in seen: + negatives.add(negative_idx) + + return list(negatives) diff --git a/src/irec/dataset/samplers/base.py b/src/irec/dataset/samplers/base.py index 158bede..7ceca78 100644 --- a/src/irec/dataset/samplers/base.py +++ b/src/irec/dataset/samplers/base.py @@ -33,7 +33,9 @@ def __len__(self): return len(self._dataset) def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) + # sample = copy.deepcopy(self._dataset[index]) + # yes, it's safe + sample = self._dataset[index] item_sequence = sample['item.ids'][:-1] next_item = sample['item.ids'][-1] diff --git a/src/irec/dataset/samplers/mclsr.py b/src/irec/dataset/samplers/mclsr.py index b957cde..bf0fef6 100644 --- a/src/irec/dataset/samplers/mclsr.py +++ b/src/irec/dataset/samplers/mclsr.py @@ -36,9 +36,27 @@ def __getitem__(self, index): user_seen = self._user_to_all_seen_items[user_id] - unseen_items = list(self._all_items_set - user_seen) + # unseen_items = list(self._all_items_set - user_seen) + # negatives = random.sample(unseen_items, self._num_negatives) + + # --- OPTIMIZATION: Rejection Sampling --- + # Mathematically equivalent to: random.sample(list(all_items - user_seen), k) + # Logic: Instead of allocating a huge list of unseen items (O(N) memory), + # we draw random indices until we have the required number of unique negatives. + # Since |user_seen| << |num_items|, the collision probability is near zero. + negatives = set() + while len(negatives) < self._num_negatives: + # Draw a random item index from the range [1, num_items] + candidate = random.randint(1, self._num_items) + + # Rejection step: Only accept if the user has never interacted with it. + # This ensures we only sample from the "unseen" pool. + if candidate not in user_seen: + negatives.add(candidate) - negatives = random.sample(unseen_items, self._num_negatives) + # Convert back to list to match the expected format for BatchProcessor + negatives = list(negatives) + # ---------------------------------------- return { diff --git a/src/irec/dataset/samplers/next_item_prediction.py b/src/irec/dataset/samplers/next_item_prediction.py index e263062..0db7eb7 100644 --- a/src/irec/dataset/samplers/next_item_prediction.py +++ b/src/irec/dataset/samplers/next_item_prediction.py @@ -39,7 +39,9 @@ def create_from_config(cls, config, **kwargs): ) def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) + # sample = copy.deepcopy(self._dataset[index]) + # yes, it's safe + sample = self._dataset[index] item_sequence = sample['item.ids'][:-1] next_item_sequence = sample['item.ids'][1:] diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 1d0e194..6152e5d 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -1,19 +1,15 @@ -import copy - from irec.utils import ( MetaParent, maybe_to_list, ) +import copy import torch import torch.nn as nn - - import pickle import os import logging -logger = logging.getLogger(__name__) class BaseLoss(metaclass=MetaParent): pass @@ -64,125 +60,6 @@ def forward(self, inputs): return total_loss -class BatchLogSoftmaxLoss(TorchLoss, config_name='batch_logsoftmax'): - def __init__(self, predictions_prefix, candidates_prefix): - super().__init__() - self._predictions_prefix = predictions_prefix - self._candidates_prefix = candidates_prefix - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - predictions_prefix=config.get('predictions_prefix'), - candidates_prefix=config.get('candidates_prefix'), - ) - - def forward(self, inputs): # use log soft max - predictions = inputs[self._predictions_prefix] - candidates = inputs[self._candidates_prefix] - - dot_product_matrix = predictions @ candidates.T - - row_log_softmax = nn.LogSoftmax(dim=1) - softmax_matrix = -row_log_softmax(dot_product_matrix) - - diagonal_elements = torch.diag(softmax_matrix) - - loss = diagonal_elements.mean() - - return loss - - -class CrossEntropyLoss(TorchLoss, config_name='ce'): - def __init__(self, predictions_prefix, labels_prefix, output_prefix=None): - super().__init__() - self._pred_prefix = predictions_prefix - self._labels_prefix = labels_prefix - self._output_prefix = output_prefix - - self._loss = nn.CrossEntropyLoss() - - def forward(self, inputs): - all_logits = inputs[self._pred_prefix] # (all_items, num_classes) - all_labels = inputs[ - '{}.ids'.format(self._labels_prefix) - ] # (all_items) - assert all_logits.shape[0] == all_labels.shape[0] - - loss = self._loss(all_logits, all_labels) # (1) - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class BinaryCrossEntropyLoss(TorchLoss, config_name='bce'): - def __init__( - self, - predictions_prefix, - labels_prefix, - with_logits=True, - output_prefix=None, - ): - super().__init__() - self._pred_prefix = predictions_prefix - self._labels_prefix = labels_prefix - self._output_prefix = output_prefix - - if with_logits: - self._loss = nn.BCEWithLogitsLoss() - else: - self._loss = nn.BCELoss() - - def forward(self, inputs): - all_logits = inputs[self._pred_prefix].float() # (all_batch_items) - all_labels = inputs[self._labels_prefix].float() # (all_batch_items) - assert all_logits.shape[0] == all_labels.shape[0] - - loss = self._loss(all_logits, all_labels) # (1) - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class BPRLoss(TorchLoss, config_name='bpr'): - def __init__(self, positive_prefix, negative_prefix, output_prefix=None): - super().__init__() - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._output_prefix = output_prefix - - def forward(self, inputs): - pos_scores = inputs[self._positive_prefix] # (all_batch_items) - neg_scores = inputs[self._negative_prefix] # (all_batch_items) - loss = -torch.log( - (pos_scores - neg_scores).sigmoid() + 1e-9, - ).mean() # (1) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class RegularizationLoss(TorchLoss, config_name='regularization'): - def __init__(self, prefix, output_prefix=None): - super().__init__() - self._prefix = maybe_to_list(prefix) - self._output_prefix = output_prefix - - def forward(self, inputs): - loss = 0.0 - for prefix in self._prefix: - loss += (1 / 2) * inputs[prefix].pow(2).mean() - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - class FpsLoss(TorchLoss, config_name='fps'): def __init__( self, @@ -192,8 +69,6 @@ def __init__( normalize_embeddings=False, use_mean=True, output_prefix=None, - use_logq_correction=False, - logq_prefix=None, ): super().__init__() self._fst_embeddings_prefix = fst_embeddings_prefix @@ -204,8 +79,7 @@ def __init__( ) self._normalize_embeddings = normalize_embeddings self._output_prefix = output_prefix - self._use_logq_correction = use_logq_correction - self._logq_prefix = logq_prefix + print(self._tau) @classmethod def create_from_config(cls, config, **kwargs): @@ -215,13 +89,10 @@ def create_from_config(cls, config, **kwargs): tau=config.get('temperature', 1.0), normalize_embeddings=config.get('normalize_embeddings', False), use_mean=config.get('use_mean', True), - output_prefix=config.get('output_prefix'), - use_logq_correction=config.get('use_logq_correction', False), - logq_prefix=config.get('logq_prefix', None), + output_prefix=config.get('output_prefix') ) def forward(self, inputs): - fst_embeddings = inputs[ self._fst_embeddings_prefix ] # (x, embedding_dim) @@ -275,15 +146,6 @@ def forward(self, inputs): -1, ) # (2 * x, 2 * x - 2) - if self._use_logq_correction and self._logq_prefix is not None: - log_q = inputs[self._logq_prefix] - log_q_combined = torch.cat((log_q, log_q), dim=0) - - log_q_matrix = log_q_combined.unsqueeze(0).expand(2 * batch_size, -1) # (2B, 2B) - negative_log_q = log_q_matrix[mask].reshape(2 * batch_size, -1) # (2B, 2B-2) - - negative_samples = negative_samples - negative_log_q - labels = ( torch.zeros(2 * batch_size).to(positive_samples.device).long() ) # (2 * x) @@ -328,7 +190,7 @@ def forward(self, inputs): inputs[self._output_prefix] = loss.cpu().item() return loss - + class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'): def __init__( @@ -396,176 +258,63 @@ def forward(self, inputs): return loss - -# sasrec logq debug -class LogqSamplesSoftmaxLoss(TorchLoss, config_name='logq_sampled_softmax'): +class MCLSRLoss(TorchLoss, config_name='mclsr'): def __init__( self, - queries_prefix, - positive_prefix, - negative_prefix, - positive_ids_prefix=None, - negative_ids_prefix=None, + all_scores_prefix, + mask_prefix, + normalize_embeddings=False, + tau=1.0, output_prefix=None, - use_logq_correction=False, - logq_prefix=None, - log_counts=None, ): super().__init__() - self._queries_prefix = queries_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - - self._positive_ids_prefix = positive_ids_prefix - self._negative_ids_prefix = negative_ids_prefix - + self._all_scores_prefix = all_scores_prefix + self._mask_prefix = mask_prefix + self._normalize_embeddings = normalize_embeddings self._output_prefix = output_prefix - self._use_logq = use_logq_correction - self._logq_prefix = logq_prefix - self._log_counts = log_counts - - @classmethod - def create_from_config(cls, config, **kwargs): - log_counts = None - path_to_counts = config.get('path_to_item_counts') - - if path_to_counts and config.get('use_logq_correction'): - import pickle - with open(path_to_counts, 'rb') as f: - counts = pickle.load(f) - - counts_tensor = torch.tensor(counts, dtype=torch.float32) - # Normalize in probability and use logarithm (Google Eq. 3) - probs = torch.clamp(counts_tensor / counts_tensor.sum(), min=1e-10) - log_counts = torch.log(probs) - logger.info(f"Loaded item counts from {path_to_counts} for LogQ correction") - - return cls( - queries_prefix=config['queries_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - positive_ids_prefix=config.get('positive_ids_prefix'), - negative_ids_prefix=config.get('negative_ids_prefix'), - output_prefix=config.get('output_prefix'), - use_logq_correction=config.get('use_logq_correction', False), - logq_prefix=config.get('logq_prefix'), - log_counts=log_counts # <-- ПЕРЕДАЕМ В КОНСТРУКТОР - ) + self._tau = tau def forward(self, inputs): - queries_embeddings = inputs[ - self._queries_prefix - ] # (batch_size, embedding_dim) - positive_embeddings = inputs[ - self._positive_prefix - ] # (batch_size, embedding_dim) - negative_embeddings = inputs[ - self._negative_prefix - ] # (num_negatives, embedding_dim) or (batch_size, num_negatives, embedding_dim) - - # b -- batch_size, d -- embedding_dim - positive_scores = torch.einsum( - 'bd,bd->b', - queries_embeddings, - positive_embeddings, - ).unsqueeze(-1) # (batch_size, 1) + all_scores = inputs[ + self._all_scores_prefix + ] # (batch_size, batch_size, seq_len) + mask = inputs[self._mask_prefix] # (batch_size) - if negative_embeddings.dim() == 2: # (num_negatives, embedding_dim) - # b -- batch_size, n -- num_negatives, d -- embedding_dim - negative_scores = torch.einsum( - 'bd,nd->bn', - queries_embeddings, - negative_embeddings, - ) # (batch_size, num_negatives) - else: - assert ( - negative_embeddings.dim() == 3 - ) # (batch_size, num_negatives, embedding_dim) - # b -- batch_size, n -- num_negatives, d -- embedding_dim - negative_scores = torch.einsum( - 'bd,bnd->bn', - queries_embeddings, - negative_embeddings, - ) # (batch_size, num_negatives) + batch_size = mask.shape[0] + seq_len = mask.shape[1] - # --- FALSE NEGATIVE MASKING (Critical for In-Batch Negatives) --- - # If we have item IDs, we must ensure that a positive item for a user - # is not treated as a negative for that same user if it appears - # elsewhere in the batch. - if self._positive_ids_prefix and self._negative_ids_prefix: - pos_ids = inputs[self._positive_ids_prefix] # (BatchSize,) - neg_ids = inputs[self._negative_ids_prefix] # (NumNegatives,) - - # Create a boolean mask of shape (BatchSize, NumNegatives) - # where True indicates that pos_ids[i] == neg_ids[j] - false_negative_mask = (pos_ids.unsqueeze(1) == neg_ids.unsqueeze(0)) - - # Mask out these scores by setting them to a very large negative value - # This prevents the model from receiving contradictory signals - # (trying to both increase and decrease the score of the same item). - negative_scores = negative_scores.masked_fill(false_negative_mask, -1e12) - - # --- 2. UNBIASED LOGQ CORRECTION --- - # Applying correction to EACH logit per Google Paper (Eq. 3) - if self._use_logq: - # Source of truth: our pre-loaded self._log_counts from the pickle - if self._log_counts is not None: - if self._log_counts.device != positive_scores.device: - self._log_counts = self._log_counts.to(positive_scores.device) - - # We need IDs to fetch the correct frequencies for items in this batch - pos_ids = inputs[self._positive_ids_prefix] - neg_ids = inputs[self._negative_ids_prefix] - - log_q_pos = self._log_counts[pos_ids].unsqueeze(-1) # (B, 1) - log_q_neg = self._log_counts[neg_ids] # (N,) or (B, N) - - # --- LOGQ CORRECTION COMMENTS --- - # According to "Sampling-Bias-Corrected Neural Modeling..." (Google, 2019): - # "we correct EACH logit s(x_i, y_j) by the following equation: - # s_c(x_i, y_j) = s(x_i, y_j) - log(p_j)" - # This ensures the estimator remains unbiased by penalizing popular - # items equally when they are targets and when they are negatives. - - positive_scores = positive_scores - log_q_pos - negative_scores = negative_scores - log_q_neg - - # (Optional) If frequencies were passed directly in inputs, not via pickle: - elif self._logq_prefix in inputs: - log_q = inputs[self._logq_prefix] - positive_scores = positive_scores - log_q[:, :1] - negative_scores = negative_scores - log_q[:, 1:] + positive_mask = torch.eye(batch_size, device=mask.device).bool() - all_scores = torch.cat( - [positive_scores, negative_scores], - dim=1, - ) # (batch_size, 1 + num_negatives) + positive_scores = all_scores[positive_mask] # (batch_size, seq_len) + negative_scores = torch.reshape( + all_scores[~positive_mask], + shape=(batch_size, batch_size - 1, seq_len), + ) # (batch_size, batch_size - 1, seq_len) + assert torch.allclose(all_scores[0, 1], negative_scores[0, 0]) + assert torch.allclose(all_scores[-1, -2], negative_scores[-1, -1]) + assert torch.allclose(all_scores[0, 0], positive_scores[0]) + assert torch.allclose(all_scores[-1, -1], positive_scores[-1]) - logits = torch.log_softmax( - all_scores, - dim=1, - ) # (batch_size, 1 + num_negatives) - loss = (-logits)[:, 0] # (batch_size) - loss = loss.mean() # (1) + # Maybe try mean over sequence TODO + loss = torch.sum( + torch.log( + torch.sigmoid(positive_scores.unsqueeze(1) - negative_scores), + ), + ) # (1) if self._output_prefix is not None: inputs[self._output_prefix] = loss.cpu().item() return loss - + class MCLSRLogqLoss(TorchLoss, config_name='mclsr_logq_special'): """ - Specialized Loss for MCLSR model implementing Sampling-Bias Correction (LogQ) - with a tunable lambda coefficient for correction strength. - - Theory: - Following Yi et al. (2019), we correct the sampling bias introduced by non-uniform - negative sampling (e.g., popularity-based or in-batch sampling). + LogQ-corrected Sampled Softmax Loss for MCLSR model. + Implements sampling-bias correction: s_c(x, y) = s(x, y) - lambda * log(p_j) - Mathematical Formula (Eq. 3 in the paper): - s_c(x, y) = s(x, y) - lambda * log(p_j) - where p_j is the sampling probability of item j and lambda is the correction strength. + This adjustment compensates for non-uniform negative sampling (e.g., popularity-based), + preventing the model from over-penalizing popular items. """ def __init__( self, @@ -575,7 +324,7 @@ def __init__( positive_ids_prefix, negative_ids_prefix, path_to_item_counts, - logq_lambda=1.0, # Strength of the LogQ correction (usually 0.1 to 1.0) + logq_lambda=1.0, output_prefix=None, ): super().__init__() @@ -587,6 +336,7 @@ def __init__( self._output_prefix = output_prefix self._logq_lambda = logq_lambda + # Load global item frequencies to calculate sampling probabilities (p_j) if not os.path.exists(path_to_item_counts): raise FileNotFoundError(f"Item counts file not found at {path_to_item_counts}") @@ -595,18 +345,18 @@ def __init__( counts_tensor = torch.tensor(counts, dtype=torch.float32) - # Calculate log probabilities log(p_j). - # Using clamp(min=1e-10) for numerical stability to avoid log(0) -> NaN. + # Calculate log-probabilities. + # Clamp used for numerical stability to avoid log(0) resulting in NaN. probs = torch.clamp(counts_tensor / counts_tensor.sum(), min=1e-10) log_q = torch.log(probs) - # register_buffer ensures the lookup table moves to the correct device (GPU/CPU) - # along with the model parameters. + # register_buffer ensures the lookup table is moved to the correct + # device (GPU/CPU) automatically during training. self.register_buffer('_log_q_table', log_q) @classmethod def create_from_config(cls, config, **kwargs): - """Standard framework factory method to initialize the loss from a JSON config.""" + """Factory method to initialize loss from JSON configuration.""" return cls( queries_prefix=config['queries_prefix'], positive_prefix=config['positive_prefix'], @@ -619,65 +369,39 @@ def create_from_config(cls, config, **kwargs): ) def forward(self, inputs): - # Retrieve Embeddings + # 1. Extract embeddings and item IDs queries = inputs[self._queries_prefix] # (Batch, Dim) pos_embs = inputs[self._positive_prefix] # (Batch, Dim) - neg_embs = inputs[self._negative_prefix] # (B, N, D) or (N, D) + neg_embs = inputs[self._negative_prefix] # (Batch, NumNegs, Dim) - # Retrieve Item IDs for correction and collision masking pos_ids = inputs[self._positive_ids_prefix] # (Batch) - neg_ids = inputs[self._negative_ids_prefix] # (Batch, N) or (N) - - # --- DEVICE SYNCHRONIZATION --- - # Explicitly ensure the LogQ lookup table is on the same device as the inputs. - # This prevents "indices should be on the same device" RuntimeError on CUDA. - device = queries.device - if self._log_q_table.device != device: - self._log_q_table = self._log_q_table.to(device) - - # --- STEP 1: Score Calculation (Inner Product) --- - # Positive scores: (Batch, 1) - pos_scores = torch.einsum('bd,bd->b', queries, pos_embs).unsqueeze(-1) - - # Negative scores with dimension check (handles both shared and per-user negatives) - if neg_embs.dim() == 2: - # Case: Shared negatives across the batch (N, D) -> Result: (B, N) - neg_scores = torch.einsum('bd,nd->bn', queries, neg_embs) - else: - # Case: Individual negatives per user (B, N, D) -> Result: (B, N) - neg_scores = torch.einsum('bd,bnd->bn', queries, neg_embs) - - # --- STEP 2: False Negative Masking --- - # Prevent contradictory gradients by masking negative samples that match - # the ground truth item ID for a given user. - # logic: pos_ids.unsqueeze(1) creates (B, 1) for broadcasting against (B, N) or (1, N) - neg_id_comparison = neg_ids.unsqueeze(0) if neg_ids.dim() == 1 else neg_ids - false_negative_mask = (pos_ids.unsqueeze(1) == neg_id_comparison) - - # Mask out scores of false negatives by setting them to a large negative value + neg_ids = inputs[self._negative_ids_prefix] # (Batch, NumNegs) + + # Device synchronization check + if self._log_q_table.device != queries.device: + self._log_q_table = self._log_q_table.to(queries.device) + + # 2. Compute raw scores (Dot Product) + # Using einsum for efficient multiplication of 2D queries and 3D negatives + pos_scores = torch.einsum('bd,bd->b', queries, pos_embs).unsqueeze(-1) # (B, 1) + neg_scores = torch.einsum('bd,bnd->bn', queries, neg_embs) # (B, N) + + # 3. False Negative Masking + # Neutralize cases where the sampled negative item is actually the target item + false_negative_mask = (pos_ids.unsqueeze(1) == neg_ids) neg_scores = neg_scores.masked_fill(false_negative_mask, -1e12) - # --- STEP 3: Tunable LogQ Correction --- - # Applying the bias correction as per Google Paper Eq. 3, scaled by lambda. - # s_c = s - lambda * log(Q) + # 4. Apply LogQ Correction + # Correction term: score = score - lambda * log(p_j) log_q_pos = self._log_q_table[pos_ids].unsqueeze(-1) # (B, 1) + log_q_neg = self._log_q_table[neg_ids] # (B, N) - if neg_ids.dim() == 1: - # Case: Shared negative IDs - log_q_neg = self._log_q_table[neg_ids].unsqueeze(0) # (1, N) - else: - # Case: Individual negative IDs - log_q_neg = self._log_q_table[neg_ids] # (B, N) - - # Apply final correction pos_scores = pos_scores - (self._logq_lambda * log_q_pos) neg_scores = neg_scores - (self._logq_lambda * log_q_neg) - # --- STEP 4: Softmax Loss Calculation --- - # Concatenate positive and negative scores: (B, 1 + N) - all_scores = torch.cat([pos_scores, neg_scores], dim=1) - - # Cross-Entropy using log_softmax for numerical stability + # 5. Final Softmax Reranking + # Concatenate scores and compute cross-entropy over the sampled items + all_scores = torch.cat([pos_scores, neg_scores], dim=1) # (B, 1+N) loss = -torch.log_softmax(all_scores, dim=1)[:, 0] final_loss = loss.mean() @@ -686,256 +410,3 @@ def forward(self, inputs): return final_loss -class S3RecPretrainLoss(TorchLoss, config_name='s3rec_pretrain'): - def __init__( - self, - positive_prefix, - negative_prefix, - representation_prefix, - output_prefix=None, - ): - super().__init__() - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._representation_prefix = representation_prefix - self._criterion = nn.BCEWithLogitsLoss(reduction='none') - self._output_prefix = output_prefix - - def forward(self, inputs): - positive_embeddings = inputs[ - self._positive_prefix - ] # (x, embedding_dim) - negative_embeddings = inputs[ - self._negative_prefix - ] # (x, embedding_dim) - current_embeddings = inputs[ - self._representation_prefix - ] # (x, embedding_dim) - assert ( - positive_embeddings.shape[0] - == negative_embeddings.shape[0] - == current_embeddings.shape[0] - ) - - positive_scores = torch.einsum( - 'bd,bd->b', - positive_embeddings, - current_embeddings, - ) # (x) - - negative_scores = torch.einsum( - 'bd,bd->b', - negative_embeddings, - current_embeddings, - ) # (x) - - distance = torch.sigmoid(positive_scores) - torch.sigmoid( - negative_scores, - ) # (x) - loss = torch.sum( - self._criterion( - distance, - torch.ones_like(distance, dtype=torch.float32), - ), - ) # (1) - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class Cl4sRecLoss(TorchLoss, config_name='cl4srec'): - def __init__( - self, - current_representation, - all_items_representation, - tau=1.0, - output_prefix=None, - ): - super().__init__() - self._current_representation = current_representation - self._all_items_representation = all_items_representation - self._loss_function = nn.CrossEntropyLoss() - self._tau = tau - self._output_prefix = output_prefix - - def forward(self, inputs): - current_representation = inputs[ - self._current_representation - ] # (batch_size, embedding_dim) - all_items_representation = inputs[ - self._all_items_representation - ] # (batch_size, num_negatives + 1, embedding_dim) - - batch_size = current_representation.shape[0] - - logits = torch.einsum( - 'bnd,bd->bn', - all_items_representation, - current_representation, - ) # (batch_size, num_negatives + 1) - labels = logits.new_zeros(batch_size) # (batch_size) - - loss = self._loss_function(logits, labels) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class DuorecSSLLoss(TorchLoss, config_name='duorec_ssl'): - def __init__( - self, - original_embedding_prefix, - dropout_embedding_prefix, - similar_embedding_prefix, - normalize_embeddings=False, - tau=1.0, - output_prefix=None, - ): - super().__init__() - self._original_embedding_prefix = original_embedding_prefix - self._dropout_embedding_prefix = dropout_embedding_prefix - self._similar_embedding_prefix = similar_embedding_prefix - self._normalize_embeddings = normalize_embeddings - self._output_prefix = output_prefix - self._tau = tau - self._loss_function = nn.CrossEntropyLoss(reduction='mean') - - def _compute_partial_loss(self, fst_embeddings, snd_embeddings): - batch_size = fst_embeddings.shape[0] - - combined_embeddings = torch.cat( - (fst_embeddings, snd_embeddings), - dim=0, - ) # (2 * x, embedding_dim) - - if self._normalize_embeddings: - combined_embeddings = torch.nn.functional.normalize( - combined_embeddings, - p=2, - dim=-1, - eps=1e-6, - ) - - similarity_scores = ( - torch.mm(combined_embeddings, combined_embeddings.T) / self._tau - ) # (2 * x, 2 * x) - - positive_samples = torch.cat( - ( - torch.diag(similarity_scores, batch_size), - torch.diag(similarity_scores, -batch_size), - ), - dim=0, - ).reshape(2 * batch_size, 1) # (2 * x, 1) - - # TODO optimize - mask = torch.ones( - 2 * batch_size, - 2 * batch_size, - dtype=torch.bool, - ) # (2 * x, 2 * x) - mask = mask.fill_diagonal_(0) # Remove equal embeddings scores - for i in range(batch_size): # Remove positives - mask[i, batch_size + i] = 0 - mask[batch_size + i, i] = 0 - - negative_samples = similarity_scores[mask].reshape( - 2 * batch_size, - -1, - ) # (2 * x, 2 * x - 2) - - labels = ( - torch.zeros(2 * batch_size).to(positive_samples.device).long() - ) # (2 * x) - logits = torch.cat( - (positive_samples, negative_samples), - dim=1, - ) # (2 * x, 2 * x - 1) - - loss = self._loss_function(logits, labels) / 2 # (1) - - return loss - - def forward(self, inputs): - original_embeddings = inputs[ - self._original_embedding_prefix - ] # (x, embedding_dim) - dropout_embeddings = inputs[ - self._dropout_embedding_prefix - ] # (x, embedding_dim) - similar_embeddings = inputs[ - self._similar_embedding_prefix - ] # (x, embedding_dim) - - dropout_loss = self._compute_partial_loss( - original_embeddings, - dropout_embeddings, - ) - ssl_loss = self._compute_partial_loss( - original_embeddings, - similar_embeddings, - ) - - loss = dropout_loss + ssl_loss - - if self._output_prefix is not None: - inputs[f'{self._output_prefix}_dropout'] = ( - dropout_loss.cpu().item() - ) - inputs[f'{self._output_prefix}_ssl'] = ssl_loss.cpu().item() - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class MCLSRLoss(TorchLoss, config_name='mclsr'): - def __init__( - self, - all_scores_prefix, - mask_prefix, - normalize_embeddings=False, - tau=1.0, - output_prefix=None, - ): - super().__init__() - self._all_scores_prefix = all_scores_prefix - self._mask_prefix = mask_prefix - self._normalize_embeddings = normalize_embeddings - self._output_prefix = output_prefix - self._tau = tau - - def forward(self, inputs): - all_scores = inputs[ - self._all_scores_prefix - ] # (batch_size, batch_size, seq_len) - mask = inputs[self._mask_prefix] # (batch_size) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - positive_mask = torch.eye(batch_size, device=mask.device).bool() - - positive_scores = all_scores[positive_mask] # (batch_size, seq_len) - negative_scores = torch.reshape( - all_scores[~positive_mask], - shape=(batch_size, batch_size - 1, seq_len), - ) # (batch_size, batch_size - 1, seq_len) - assert torch.allclose(all_scores[0, 1], negative_scores[0, 0]) - assert torch.allclose(all_scores[-1, -2], negative_scores[-1, -1]) - assert torch.allclose(all_scores[0, 0], positive_scores[0]) - assert torch.allclose(all_scores[-1, -1], positive_scores[-1]) - - # Maybe try mean over sequence TODO - loss = torch.sum( - torch.log( - torch.sigmoid(positive_scores.unsqueeze(1) - negative_scores), - ), - ) # (1) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss diff --git a/src/irec/models/__init__.py b/src/irec/models/__init__.py index 351a048..761e479 100644 --- a/src/irec/models/__init__.py +++ b/src/irec/models/__init__.py @@ -1,13 +1,9 @@ from .base import BaseModel, SequentialTorchModel from .mclsr import MCLSRModel from .sasrec import SasRecModel, SasRecInBatchModel -from .sasrec_ce import SasRecCeModel __all__ = [ 'BaseModel', 'MCLSRModel', 'SasRecModel', - 'SasRecInBatchModel', - 'SasRecCeModel', - 'SasRecRealModel', ] diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index ef841d1..1c7c7a2 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -363,8 +363,14 @@ def scatter_mean(src, index, dim=0, dim_size=None): unique_common_graph_items_proj = self._item_projection(unique_common_graph_items) unique_item_graph_items_proj = self._item_projection(unique_item_graph_items) - negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) - negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) + + # negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) + # negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) + + raw_negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] + num_negatives = raw_negative_ids.shape[0] // batch_size + negative_ids = raw_negative_ids.view(batch_size, num_negatives) # (Batch, NumNegs) + negative_embeddings = self._item_embeddings(negative_ids) # (Batch, NumNegs, Dim) # import code; code.interact(local=locals()) @@ -375,7 +381,7 @@ def scatter_mean(src, index, dim=0, dim_size=None): 'negative_representation': negative_embeddings, - + # --- ID PASS-THROUGH FOR LOGQ & MASKING --- # We pass raw item and user indices to enable advanced loss operations: # 1. False Negative Masking: Allows SamplesSoftmaxLoss to identify and @@ -390,8 +396,9 @@ def scatter_mean(src, index, dim=0, dim_size=None): # by the supervisor to handle highly active users. 'user_ids': user_ids, - # for L_IL (formula 8) + + # for L_IL (formula 8) 'sequential_representation': sequential_representation_proj, 'graph_representation': graph_representation_proj, diff --git a/src/irec/models/sasrec_ce.py b/src/irec/models/sasrec_ce.py deleted file mode 100644 index c07b9ff..0000000 --- a/src/irec/models/sasrec_ce.py +++ /dev/null @@ -1,108 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - - -class SasRecCeModel(SequentialTorchModel, config_name='sasrec_ce'): - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - embeddings = self._output_projection( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.nn.functional.gelu( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.einsum( - 'bsd,nd->bsn', - embeddings, - self._item_embeddings.weight, - ) # (batch_size, seq_len, num_items + 2) - - if self.training: # training mode - return {'logits': embeddings[mask]} - else: # eval mode - candidate_scores = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices From f28f9415018fc9d3bdfbe776772618dd0d8a4f38 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Wed, 1 Apr 2026 16:05:36 +0300 Subject: [PATCH 16/27] delete logq --- configs/train/mclsr_logq_1903.json | 160 ---------------------- configs/train/sasrec_train_config.json | 179 +++++++++++++++++++++++++ src/irec/loss/base.py | 101 -------------- 3 files changed, 179 insertions(+), 261 deletions(-) delete mode 100644 configs/train/mclsr_logq_1903.json create mode 100644 configs/train/sasrec_train_config.json diff --git a/configs/train/mclsr_logq_1903.json b/configs/train/mclsr_logq_1903.json deleted file mode 100644 index e25b105..0000000 --- a/configs/train/mclsr_logq_1903.json +++ /dev/null @@ -1,160 +0,0 @@ -{ - "experiment_name": "mclsr_random_logq_v0.1_Clothing_1903", - "use_wandb": false, - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "graph", - "use_user_graph": true, - "use_item_graph": true, - "neighborhood_size": 50, - "graph_dir_path": "./data/Clothing", - "dataset": { - "type": "mclsr", - "path_to_data_dir": "./data", - "name": "Clothing", - "max_sequence_length": 20, - "samplers": { - "num_negatives_val": 1280, - "num_negatives_train": 1280, - "type": "mclsr", - "negative_sampler_type": "popularity" - } - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true, - "pin_memory": true - }, - "validation": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false, - "pin_memory": true - } - }, - "model": { - "type": "mclsr", - "sequence_prefix": "item", - "user_prefix": "user", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_graph_layers": 2, - "dropout": 0.3, - "layer_norm_eps": 1e-9, - "graph_dropout": 0.3, - "initializer_range": 0.02, - "alpha": 0.5 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "mclsr_logq_special", - "queries_prefix": "combined_representation", - "positive_prefix": "label_representation", - "negative_prefix": "negative_representation", - "positive_ids_prefix": "positive_ids", - "negative_ids_prefix": "negative_ids", - "path_to_item_counts": "./data/Clothing/item_counts.pkl", - "logq_lambda": 1.0, - "output_prefix": "downstream_loss", - "weight": 1.0 - }, - { - "type": "fps", - "fst_embeddings_prefix": "sequential_representation", - "snd_embeddings_prefix": "graph_representation", - "output_prefix": "contrastive_interest_loss", - "weight": 1.0, - "temperature": 0.5 - }, - { - "type": "fps", - "fst_embeddings_prefix": "user_graph_user_embeddings", - "snd_embeddings_prefix": "common_graph_user_embeddings", - "output_prefix": "contrastive_user_feature_loss", - "weight": 0.05, - "temperature": 0.5 - }, - { - "type": "fps", - "fst_embeddings_prefix": "item_graph_item_embeddings", - "snd_embeddings_prefix": "common_graph_item_embeddings", - "output_prefix": "contrastive_item_feature_loss", - "weight": 0.05, - "temperature": 0.5 - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "downstream_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { "type": "mclsr-ndcg", "k": 5 }, - "ndcg@10": { "type": "mclsr-ndcg", "k": 10 }, - "ndcg@20": { "type": "mclsr-ndcg", "k": 20 }, - "ndcg@50": { "type": "mclsr-ndcg", "k": 50 }, - "recall@5": { "type": "mclsr-recall", "k": 5 }, - "recall@10": { "type": "mclsr-recall", "k": 10 }, - "recall@20": { "type": "mclsr-recall", "k": 20 }, - "recall@50": { "type": "mclsr-recall", "k": 50 }, - "hit@20": { "type": "mclsr-hit", "k": 20 }, - "hit@50": { "type": "mclsr-hit", "k": 50 } - } - } - ] - } -} diff --git a/configs/train/sasrec_train_config.json b/configs/train/sasrec_train_config.json new file mode 100644 index 0000000..37b2888 --- /dev/null +++ b/configs/train/sasrec_train_config.json @@ -0,0 +1,179 @@ +{ + "experiment_name": "sasrec_clothing", + "use_wandb": true, + "best_metric": "validation/ndcg@20", + "dataset": { + "type": "sasrec_comparison", + "path_to_data_dir": "./data", + "name": "Clothing", + "max_sequence_length": 20, + "train_sampler": { + "type": "next_item_prediction", + "negative_sampler_type": "random", + "num_negatives_train": 0 + }, + "eval_sampler": { + "type": "mclsr" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "sasrec", + "sequence_prefix": "item", + "positive_prefix": "positive", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_heads": 2, + "num_layers": 2, + "dim_feedforward": 256, + "dropout": 0.3, + "activation": "gelu", + "layer_norm_eps": 1e-9, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "clip_grad_threshold": 1.0 + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "sasrec", + "positive_prefix": "positive_scores", + "negative_prefix": "negative_scores", + "output_prefix": "downstream_loss" + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "mclsr-ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "mclsr-ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "mclsr-ndcg", + "k": 20 + }, + "ndcg@50": { + "type": "mclsr-ndcg", + "k": 50 + }, + "recall@5": { + "type": "mclsr-recall", + "k": 5 + }, + "recall@10": { + "type": "mclsr-recall", + "k": 10 + }, + "recall@20": { + "type": "mclsr-recall", + "k": 20 + }, + "recall@50": { + "type": "mclsr-recall", + "k": 50 + }, + "hit@20": { + "type": "mclsr-hit", + "k": 20 + }, + "hit@50": { + "type": "mclsr-hit", + "k": 50 + } + } + }, + { + "type": "eval", + "on_step": 256, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "mclsr-ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "mclsr-ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "mclsr-ndcg", + "k": 20 + }, + "ndcg@50": { + "type": "mclsr-ndcg", + "k": 50 + }, + "recall@5": { + "type": "mclsr-recall", + "k": 5 + }, + "recall@10": { + "type": "mclsr-recall", + "k": 10 + }, + "recall@20": { + "type": "mclsr-recall", + "k": 20 + }, + "recall@50": { + "type": "mclsr-recall", + "k": 50 + }, + "hit@20": { + "type": "mclsr-hit", + "k": 20 + }, + "hit@50": { + "type": "mclsr-hit", + "k": 50 + } + } + } + ] + } +} \ No newline at end of file diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 6152e5d..6f761e3 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -308,105 +308,4 @@ def forward(self, inputs): return loss -class MCLSRLogqLoss(TorchLoss, config_name='mclsr_logq_special'): - """ - LogQ-corrected Sampled Softmax Loss for MCLSR model. - Implements sampling-bias correction: s_c(x, y) = s(x, y) - lambda * log(p_j) - - This adjustment compensates for non-uniform negative sampling (e.g., popularity-based), - preventing the model from over-penalizing popular items. - """ - def __init__( - self, - queries_prefix, - positive_prefix, - negative_prefix, - positive_ids_prefix, - negative_ids_prefix, - path_to_item_counts, - logq_lambda=1.0, - output_prefix=None, - ): - super().__init__() - self._queries_prefix = queries_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._positive_ids_prefix = positive_ids_prefix - self._negative_ids_prefix = negative_ids_prefix - self._output_prefix = output_prefix - self._logq_lambda = logq_lambda - - # Load global item frequencies to calculate sampling probabilities (p_j) - if not os.path.exists(path_to_item_counts): - raise FileNotFoundError(f"Item counts file not found at {path_to_item_counts}") - - with open(path_to_item_counts, 'rb') as f: - counts = pickle.load(f) - - counts_tensor = torch.tensor(counts, dtype=torch.float32) - - # Calculate log-probabilities. - # Clamp used for numerical stability to avoid log(0) resulting in NaN. - probs = torch.clamp(counts_tensor / counts_tensor.sum(), min=1e-10) - log_q = torch.log(probs) - - # register_buffer ensures the lookup table is moved to the correct - # device (GPU/CPU) automatically during training. - self.register_buffer('_log_q_table', log_q) - - @classmethod - def create_from_config(cls, config, **kwargs): - """Factory method to initialize loss from JSON configuration.""" - return cls( - queries_prefix=config['queries_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - positive_ids_prefix=config['positive_ids_prefix'], - negative_ids_prefix=config['negative_ids_prefix'], - path_to_item_counts=config['path_to_item_counts'], - logq_lambda=config.get('logq_lambda', 1.0), - output_prefix=config.get('output_prefix') - ) - - def forward(self, inputs): - # 1. Extract embeddings and item IDs - queries = inputs[self._queries_prefix] # (Batch, Dim) - pos_embs = inputs[self._positive_prefix] # (Batch, Dim) - neg_embs = inputs[self._negative_prefix] # (Batch, NumNegs, Dim) - - pos_ids = inputs[self._positive_ids_prefix] # (Batch) - neg_ids = inputs[self._negative_ids_prefix] # (Batch, NumNegs) - - # Device synchronization check - if self._log_q_table.device != queries.device: - self._log_q_table = self._log_q_table.to(queries.device) - - # 2. Compute raw scores (Dot Product) - # Using einsum for efficient multiplication of 2D queries and 3D negatives - pos_scores = torch.einsum('bd,bd->b', queries, pos_embs).unsqueeze(-1) # (B, 1) - neg_scores = torch.einsum('bd,bnd->bn', queries, neg_embs) # (B, N) - - # 3. False Negative Masking - # Neutralize cases where the sampled negative item is actually the target item - false_negative_mask = (pos_ids.unsqueeze(1) == neg_ids) - neg_scores = neg_scores.masked_fill(false_negative_mask, -1e12) - - # 4. Apply LogQ Correction - # Correction term: score = score - lambda * log(p_j) - log_q_pos = self._log_q_table[pos_ids].unsqueeze(-1) # (B, 1) - log_q_neg = self._log_q_table[neg_ids] # (B, N) - - pos_scores = pos_scores - (self._logq_lambda * log_q_pos) - neg_scores = neg_scores - (self._logq_lambda * log_q_neg) - - # 5. Final Softmax Reranking - # Concatenate scores and compute cross-entropy over the sampled items - all_scores = torch.cat([pos_scores, neg_scores], dim=1) # (B, 1+N) - loss = -torch.log_softmax(all_scores, dim=1)[:, 0] - - final_loss = loss.mean() - if self._output_prefix: - inputs[self._output_prefix] = final_loss.cpu().item() - - return final_loss From eaf2b88014a1d8b4408a2d242a80e1ce85248553 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Wed, 1 Apr 2026 16:11:30 +0300 Subject: [PATCH 17/27] clean --- src/irec/dataset/base.py | 10 ------- src/irec/dataset/negative_samplers/popular.py | 27 ------------------- src/irec/dataset/negative_samplers/random.py | 15 ----------- src/irec/dataset/samplers/base.py | 2 -- .../dataset/samplers/next_item_prediction.py | 2 -- src/irec/loss/base.py | 3 +-- 6 files changed, 1 insertion(+), 58 deletions(-) diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index e8f4218..25fdf11 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -383,10 +383,6 @@ def _build_or_load_similarity_graph( ): if entity_type not in ['user', 'item']: raise ValueError("entity_type must be either 'user' or 'item'") - # have to delete and replace to not delete npz each time manually - # path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type)) - - # instead better use such construction # neighborhood_size # The neighborhood_size is a filter that constrains the number of edges for each user or @@ -422,7 +418,6 @@ def _build_or_load_similarity_graph( continue visited_user_item_pairs.add((user_id, item_id)) - # TODO look here at review source_entity = user_id if is_user_graph else item_id connection_map = train_item_2_users if is_user_graph else train_user_2_items connection_point = item_id if is_user_graph else user_id @@ -431,11 +426,6 @@ def _build_or_load_similarity_graph( if source_entity == connected_entity: continue - pair_key = (source_entity, connected_entity) - # if pair_key in visited_entity_pairs: - # continue - - # visited_entity_pairs.add(pair_key) interactions_fst.append(source_entity) interactions_snd.append(connected_entity) diff --git a/src/irec/dataset/negative_samplers/popular.py b/src/irec/dataset/negative_samplers/popular.py index f004efe..9859b1a 100644 --- a/src/irec/dataset/negative_samplers/popular.py +++ b/src/irec/dataset/negative_samplers/popular.py @@ -10,12 +10,6 @@ def __init__(self, dataset, num_users, num_items): num_users=num_users, num_items=num_items, ) - - # --- OLD DETERMINISTIC LOGIC --- - # self._popular_items = self._items_by_popularity() - - # --- NEW STOCHASTIC LOGIC FOR LogQ COMPATIBILITY --- - # Pre-calculate item probabilities based on global frequency self._item_ids, self._probs = self._calculate_item_probabilities() @classmethod @@ -26,15 +20,6 @@ def create_from_config(cls, _, **kwargs): num_items=kwargs['num_items'], ) - # --- OLD METHOD: Deterministic sorting --- - # def _items_by_popularity(self): - # popularity = Counter() - # for sample in self._dataset: - # for item_id in sample['item.ids']: - # popularity[item_id] += 1 - # popular_items = sorted(popularity, key=popularity.get, reverse=True) - # return popular_items - def _calculate_item_probabilities(self): """ Calculates sampling probabilities proportional to item popularity. @@ -51,18 +36,6 @@ def _calculate_item_probabilities(self): return items, probabilities - # --- OLD METHOD: Picking Top-K items sequentially (Deterministic) --- - # def generate_negative_samples(self, sample, num_negatives): - # user_id = sample['user.ids'][0] - # popularity_idx = 0 - # negatives = [] - # while len(negatives) < num_negatives: - # negative_idx = self._popular_items[popularity_idx] - # if negative_idx not in self._seen_items[user_id]: - # negatives.append(negative_idx) - # popularity_idx += 1 - # return negatives - def generate_negative_samples(self, sample, num_negatives): """ Stochastic sampling proportional to popularity. diff --git a/src/irec/dataset/negative_samplers/random.py b/src/irec/dataset/negative_samplers/random.py index a4926b1..8c2a301 100644 --- a/src/irec/dataset/negative_samplers/random.py +++ b/src/irec/dataset/negative_samplers/random.py @@ -12,21 +12,6 @@ def create_from_config(cls, _, **kwargs): num_items=kwargs['num_items'], ) - # def generate_negative_samples(self, sample, num_negatives): - # user_id = sample['user.ids'][0] - # all_items = list(range(1, self._num_items + 1)) - # np.random.shuffle(all_items) - - # negatives = [] - # running_idx = 0 - # while len(negatives) < num_negatives and running_idx < len(all_items): - # negative_idx = all_items[running_idx] - # if negative_idx not in self._seen_items[user_id]: - # negatives.append(negative_idx) - # running_idx += 1 - - # return negatives - def generate_negative_samples(self, sample, num_negatives): """ Optimized via Rejection Sampling (O(k) complexity). diff --git a/src/irec/dataset/samplers/base.py b/src/irec/dataset/samplers/base.py index 7ceca78..ad6f3bf 100644 --- a/src/irec/dataset/samplers/base.py +++ b/src/irec/dataset/samplers/base.py @@ -33,8 +33,6 @@ def __len__(self): return len(self._dataset) def __getitem__(self, index): - # sample = copy.deepcopy(self._dataset[index]) - # yes, it's safe sample = self._dataset[index] item_sequence = sample['item.ids'][:-1] diff --git a/src/irec/dataset/samplers/next_item_prediction.py b/src/irec/dataset/samplers/next_item_prediction.py index 0db7eb7..47e69a8 100644 --- a/src/irec/dataset/samplers/next_item_prediction.py +++ b/src/irec/dataset/samplers/next_item_prediction.py @@ -39,8 +39,6 @@ def create_from_config(cls, config, **kwargs): ) def __getitem__(self, index): - # sample = copy.deepcopy(self._dataset[index]) - # yes, it's safe sample = self._dataset[index] item_sequence = sample['item.ids'][:-1] diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 6f761e3..e082b4d 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -296,12 +296,11 @@ def forward(self, inputs): assert torch.allclose(all_scores[0, 0], positive_scores[0]) assert torch.allclose(all_scores[-1, -1], positive_scores[-1]) - # Maybe try mean over sequence TODO loss = torch.sum( torch.log( torch.sigmoid(positive_scores.unsqueeze(1) - negative_scores), ), - ) # (1) + ) if self._output_prefix is not None: inputs[self._output_prefix] = loss.cpu().item() From 1749edf9b8f11d7aa3819329810fb6978ef7973c Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Wed, 1 Apr 2026 16:13:18 +0300 Subject: [PATCH 18/27] clean old code --- src/irec/dataloader/base.py | 9 --------- src/irec/dataset/base.py | 3 +-- src/irec/models/mclsr.py | 5 ----- src/irec/models/sasrec.py | 15 --------------- 4 files changed, 1 insertion(+), 31 deletions(-) diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py index 4bc0f75..20ec865 100644 --- a/src/irec/dataloader/base.py +++ b/src/irec/dataloader/base.py @@ -46,12 +46,3 @@ def create_from_config(cls, config, **kwargs): **create_config, ), ) - - # return cls( - # dataloader=DataLoader( - # kwargs['dataset'], - # collate_fn=batch_processor, - # pin_memory=True, - # **create_config, - # ), - # ) diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index 25fdf11..40d6a88 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -453,7 +453,6 @@ def _build_or_load_similarity_graph( return self._convert_sp_mat_to_sp_tensor(graph_matrix).coalesce().to(DEVICE) def _build_or_load_bipartite_graph(self, graph_dir_path, train_user_interactions, train_item_interactions): - # path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" filename = f"general_graph_{train_suffix}.npz" path_to_graph = os.path.join(graph_dir_path, filename) @@ -752,4 +751,4 @@ def create_from_config(cls, config, **kwargs): test_sampler = EvalSampler.create_from_config(config['samplers'], dataset=test_dataset, num_users=num_users, num_items=num_items, **kwargs) return cls(train_sampler, validation_sampler, test_sampler, num_users, num_items, max_seq_len) - + diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index 1c7c7a2..1708747 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -364,9 +364,6 @@ def scatter_mean(src, index, dim=0, dim_size=None): unique_item_graph_items_proj = self._item_projection(unique_item_graph_items) - # negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) - # negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) - raw_negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] num_negatives = raw_negative_ids.shape[0] // batch_size negative_ids = raw_negative_ids.view(batch_size, num_negatives) # (Batch, NumNegs) @@ -396,8 +393,6 @@ def scatter_mean(src, index, dim=0, dim_size=None): # by the supervisor to handle highly active users. 'user_ids': user_ids, - - # for L_IL (formula 8) 'sequential_representation': sequential_representation_proj, 'graph_representation': graph_representation_proj, diff --git a/src/irec/models/sasrec.py b/src/irec/models/sasrec.py index c767afd..0cefdd5 100644 --- a/src/irec/models/sasrec.py +++ b/src/irec/models/sasrec.py @@ -87,20 +87,6 @@ def forward(self, inputs): index=torch.randint(low=0, high=all_scores.shape[1], size=all_positive_sample_events.shape, device=all_positive_sample_events.device)[..., None] )[:, 0] # (all_batch_items) - # sample_ids, _ = create_masked_tensor( - # data=all_sample_events, - # lengths=all_sample_lengths - # ) # (batch_size, seq_len) - - # sample_ids = torch.repeat_interleave(sample_ids, all_sample_lengths, dim=0) # (all_batch_events, seq_len) - - # negative_scores = torch.scatter( - # input=all_scores, - # dim=1, - # index=sample_ids, - # src=torch.ones_like(sample_ids) * (-torch.inf) - # ) # (all_batch_events, num_items) - return { 'positive_scores': positive_scores, 'negative_scores': negative_scores @@ -123,7 +109,6 @@ def forward(self, inputs): class SasRecInBatchModel(SasRecModel, config_name='sasrec_in_batch'): - def __init__( self, sequence_prefix, From a70272119044539eab3ebc89660b4f7e87d2727b Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Wed, 1 Apr 2026 16:23:59 +0300 Subject: [PATCH 19/27] code format --- src/irec/callbacks/base.py | 142 +++-- src/irec/dataloader/base.py | 17 +- src/irec/dataloader/batch_processors.py | 22 +- src/irec/dataset/base.py | 499 ++++++++++-------- .../dataset/negative_samplers/__init__.py | 6 +- src/irec/dataset/negative_samplers/base.py | 4 +- src/irec/dataset/negative_samplers/popular.py | 36 +- src/irec/dataset/negative_samplers/random.py | 30 +- src/irec/dataset/samplers/base.py | 16 +- src/irec/dataset/samplers/mclsr.py | 78 +-- .../dataset/samplers/next_item_prediction.py | 60 +-- src/irec/dataset/sasrec.py | 36 +- src/irec/loss/base.py | 67 +-- src/irec/metric/base.py | 85 +-- src/irec/models/base.py | 29 +- src/irec/models/mclsr.py | 243 +++++---- src/irec/models/sasrec.py | 202 +++---- src/irec/optimizer/base.py | 32 +- src/irec/train.py | 71 +-- 19 files changed, 900 insertions(+), 775 deletions(-) diff --git a/src/irec/callbacks/base.py b/src/irec/callbacks/base.py index 81271f5..77895d6 100644 --- a/src/irec/callbacks/base.py +++ b/src/irec/callbacks/base.py @@ -30,7 +30,7 @@ def __call__(self, inputs, step_num): raise NotImplementedError -class MetricCallback(BaseCallback, config_name='metric'): +class MetricCallback(BaseCallback, config_name="metric"): def __init__( self, model, @@ -56,43 +56,39 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - metrics=config.get('metrics', None), - loss_prefix=config['loss_prefix'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], + metrics=config.get("metrics", None), + loss_prefix=config["loss_prefix"], ) def __call__(self, inputs, step_num): if step_num % self._on_step == 0: for metric_name, metric_function in self._metrics.items(): metric_value = metric_function( - ground_truth=inputs[ - self._model.schema['ground_truth_prefix'] - ], - predictions=inputs[ - self._model.schema['predictions_prefix'] - ], + ground_truth=inputs[self._model.schema["ground_truth_prefix"]], + predictions=inputs[self._model.schema["predictions_prefix"]], ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - 'train/{}'.format(metric_name), + "train/{}".format(metric_name), metric_value, step_num, ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - 'train/{}'.format(self._loss_prefix), + "train/{}".format(self._loss_prefix), inputs[self._loss_prefix], step_num, ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush() -class CheckpointCallback(BaseCallback, config_name='checkpoint'): +class CheckpointCallback(BaseCallback, config_name="checkpoint"): def __init__( self, model, @@ -115,7 +111,7 @@ def __init__( self._save_path = Path(os.path.join(save_path, model_name)) if self._save_path.exists(): logger.warning( - 'Checkpoint path `{}` is already exists!'.format( + "Checkpoint path `{}` is already exists!".format( self._save_path, ), ) @@ -125,31 +121,31 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - save_path=config['save_path'], - model_name=config['model_name'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], + save_path=config["save_path"], + model_name=config["model_name"], ) def __call__(self, inputs, step_num): if step_num % self._on_step == 0: - logger.debug('Saving model state on step {}...'.format(step_num)) + logger.debug("Saving model state on step {}...".format(step_num)) torch.save( { - 'step_num': step_num, - 'model_state_dict': self._model.state_dict(), - 'optimizer_state_dict': self._optimizer.state_dict(), + "step_num": step_num, + "model_state_dict": self._model.state_dict(), + "optimizer_state_dict": self._optimizer.state_dict(), }, os.path.join( self._save_path, - 'checkpoint_{}.pth'.format(step_num), + "checkpoint_{}.pth".format(step_num), ), ) - logger.debug('Saving done!') + logger.debug("Saving done!") class InferenceCallback(BaseCallback): @@ -183,24 +179,24 @@ def __init__( def create_from_config(cls, config, **kwargs): metrics = { metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() + for metric_name, metric_cfg in config["metrics"].items() } return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], + pred_prefix=config["pred_prefix"], + labels_prefix=config["labels_prefix"], ) def __call__(self, inputs, step_num): if step_num % self._on_step == 0: # TODO Add time monitoring - logger.debug(f'Running {self._get_name()} on step {step_num}...') + logger.debug(f"Running {self._get_name()} on step {step_num}...") running_params = {} for metric_name, metric_function in self._metrics.items(): running_params[metric_name] = [] @@ -239,16 +235,16 @@ def __call__(self, inputs, step_num): ) for label, value in running_params.items(): - inputs[f'{self._get_name()}/{label}'] = np.mean(value) + inputs[f"{self._get_name()}/{label}"] = np.mean(value) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - f'{self._get_name()}/{label}', + f"{self._get_name()}/{label}", np.mean(value), step_num, ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush() logger.debug( - f'Running {self._get_name()} on step {step_num} is done!', + f"Running {self._get_name()} on step {step_num} is done!", ) def _get_name(self): @@ -258,55 +254,55 @@ def _get_dataloader(self): raise NotImplementedError -class ValidationCallback(InferenceCallback, config_name='validation'): +class ValidationCallback(InferenceCallback, config_name="validation"): @classmethod def create_from_config(cls, config, **kwargs): metrics = { metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() + for metric_name, metric_cfg in config["metrics"].items() } return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], + pred_prefix=config["pred_prefix"], + labels_prefix=config["labels_prefix"], ) def _get_dataloader(self): return self._validation_dataloader -class EvalCallback(InferenceCallback, config_name='eval'): +class EvalCallback(InferenceCallback, config_name="eval"): @classmethod def create_from_config(cls, config, **kwargs): metrics = { metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() + for metric_name, metric_cfg in config["metrics"].items() } return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], + pred_prefix=config["pred_prefix"], + labels_prefix=config["labels_prefix"], ) def _get_dataloader(self): return self._eval_dataloader -class CompositeCallback(BaseCallback, config_name='composite'): +class CompositeCallback(BaseCallback, config_name="composite"): def __init__( self, model, @@ -328,14 +324,14 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], callbacks=[ BaseCallback.create_from_config(cfg, **kwargs) - for cfg in config['callbacks'] + for cfg in config["callbacks"] ], ) diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py index 20ec865..17c21bf 100644 --- a/src/irec/dataloader/base.py +++ b/src/irec/dataloader/base.py @@ -13,7 +13,7 @@ class BaseDataloader(metaclass=MetaParent): pass -class TorchDataloader(BaseDataloader, config_name='torch'): +class TorchDataloader(BaseDataloader, config_name="torch"): def __init__(self, dataloader): self._dataloader = dataloader @@ -27,20 +27,21 @@ def __len__(self): def create_from_config(cls, config, **kwargs): create_config = copy.deepcopy(config) batch_processor = BaseBatchProcessor.create_from_config( - create_config.pop('batch_processor') - if 'batch_processor' in create_config - else {'type': 'identity'}, + ( + create_config.pop("batch_processor") + if "batch_processor" in create_config + else {"type": "identity"} + ), ) create_config.pop( - 'type', + "type", ) # For passing as **config in torch DataLoader - - pin_memory = create_config.pop('pin_memory', True) + pin_memory = create_config.pop("pin_memory", True) return cls( dataloader=DataLoader( - kwargs['dataset'], + kwargs["dataset"], collate_fn=batch_processor, pin_memory=pin_memory, **create_config, diff --git a/src/irec/dataloader/batch_processors.py b/src/irec/dataloader/batch_processors.py index a8dbdde..c4b9d97 100644 --- a/src/irec/dataloader/batch_processors.py +++ b/src/irec/dataloader/batch_processors.py @@ -7,29 +7,29 @@ def __call__(self, batch): raise NotImplementedError -class IdentityBatchProcessor(BaseBatchProcessor, config_name='identity'): +class IdentityBatchProcessor(BaseBatchProcessor, config_name="identity"): def __call__(self, batch): return torch.tensor(batch) -class BasicBatchProcessor(BaseBatchProcessor, config_name='basic'): +class BasicBatchProcessor(BaseBatchProcessor, config_name="basic"): def __call__(self, batch): processed_batch = {} for key in batch[0].keys(): - if key.endswith('.ids'): - prefix = key.split('.')[0] - assert '{}.length'.format(prefix) in batch[0] + if key.endswith(".ids"): + prefix = key.split(".")[0] + assert "{}.length".format(prefix) in batch[0] - processed_batch[f'{prefix}.ids'] = [] - processed_batch[f'{prefix}.length'] = [] + processed_batch[f"{prefix}.ids"] = [] + processed_batch[f"{prefix}.length"] = [] for sample in batch: - processed_batch[f'{prefix}.ids'].extend( - sample[f'{prefix}.ids'], + processed_batch[f"{prefix}.ids"].extend( + sample[f"{prefix}.ids"], ) - processed_batch[f'{prefix}.length'].append( - sample[f'{prefix}.length'], + processed_batch[f"{prefix}.length"].append( + sample[f"{prefix}.length"], ) for part, values in processed_batch.items(): diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index 40d6a88..c380709 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -33,7 +33,8 @@ def num_items(self): @property def max_sequence_length(self): return self._max_sequence_length - + + class BaseSequenceDataset(BaseDataset): def __init__( self, @@ -52,7 +53,7 @@ def __init__( self._max_sequence_length = max_sequence_length @staticmethod - def _create_sequences(data, max_sample_len): # TODO + def _create_sequences(data, max_sample_len): # TODO user_sequences = [] item_sequences = [] @@ -61,10 +62,8 @@ def _create_sequences(data, max_sample_len): # TODO max_sequence_length = 0 for sample in data: - sample = sample.strip('\n').split(' ') - item_ids = [int(item_id) for item_id in sample[1:]][ - -max_sample_len: - ] + sample = sample.strip("\n").split(" ") + item_ids = [int(item_id) for item_id in sample[1:]][-max_sample_len:] user_id = int(sample[0]) max_user_id = max(max_user_id, user_id) @@ -92,12 +91,13 @@ def get_samplers(self): @property def meta(self): return { - 'num_users': self.num_users, - 'num_items': self.num_items, - 'max_sequence_length': self.max_sequence_length, + "num_users": self.num_users, + "num_items": self.num_items, + "max_sequence_length": self.max_sequence_length, } - -class SequenceDataset(BaseDataset, config_name='sequence'): + + +class SequenceDataset(BaseDataset, config_name="sequence"): def __init__( self, train_sampler, @@ -117,24 +117,24 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): data_dir_path = os.path.join( - config['path_to_data_dir'], - config['name'], + config["path_to_data_dir"], + config["name"], ) common_params_for_creation = { - 'dir_path': data_dir_path, - 'max_sequence_length': config['max_sequence_length'], - 'use_cached': config.get('use_cached', False), + "dir_path": data_dir_path, + "max_sequence_length": config["max_sequence_length"], + "use_cached": config.get("use_cached", False), } train_dataset, train_max_user_id, train_max_item_id, train_seq_len = ( - cls._create_dataset(part='train', **common_params_for_creation) + cls._create_dataset(part="train", **common_params_for_creation) ) validation_dataset, valid_max_user_id, valid_max_item_id, valid_seq_len = ( - cls._create_dataset(part='valid', **common_params_for_creation) + cls._create_dataset(part="valid", **common_params_for_creation) ) test_dataset, test_max_user_id, test_max_item_id, test_seq_len = ( - cls._create_dataset(part='test', **common_params_for_creation) + cls._create_dataset(part="test", **common_params_for_creation) ) max_user_id = max( @@ -145,11 +145,11 @@ def create_from_config(cls, config, **kwargs): ) max_seq_len = max([train_seq_len, valid_seq_len, test_seq_len]) - logger.info('Train dataset size: {}'.format(len(train_dataset))) - logger.info('Test dataset size: {}'.format(len(test_dataset))) - logger.info('Max user id: {}'.format(max_user_id)) - logger.info('Max item id: {}'.format(max_item_id)) - logger.info('Max sequence length: {}'.format(max_seq_len)) + logger.info("Train dataset size: {}".format(len(train_dataset))) + logger.info("Test dataset size: {}".format(len(test_dataset))) + logger.info("Max user id: {}".format(max_user_id)) + logger.info("Max item id: {}".format(max_item_id)) + logger.info("Max sequence length: {}".format(max_seq_len)) train_interactions = sum( list(map(lambda x: len(x), train_dataset)), @@ -161,15 +161,15 @@ def create_from_config(cls, config, **kwargs): test_dataset, ) # each new interaction as a sample logger.info( - '{} dataset sparsity: {}'.format( - config['name'], + "{} dataset sparsity: {}".format( + config["name"], (train_interactions + valid_interactions + test_interactions) / max_user_id / max_item_id, ), ) - samplers_config = config['samplers'] + samplers_config = config["samplers"] train_sampler = TrainSampler.create_from_config( samplers_config, dataset=train_dataset, @@ -206,31 +206,30 @@ def _create_dataset( max_sequence_length=None, use_cached=False, ): - cache_path = os.path.join(dir_path, '{}.pkl'.format(part)) + cache_path = os.path.join(dir_path, "{}.pkl".format(part)) if use_cached and os.path.exists(cache_path): - logger.info( - 'Loading cached dataset from {}'.format(cache_path) - ) - with open(cache_path, 'rb') as f: + logger.info("Loading cached dataset from {}".format(cache_path)) + with open(cache_path, "rb") as f: return pickle.load(f) - - return cls._build_and_cache_dataset(dir_path, part, max_sequence_length, cache_path, use_cached) + return cls._build_and_cache_dataset( + dir_path, part, max_sequence_length, cache_path, use_cached + ) @classmethod - def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_path, use_cached): + def _build_and_cache_dataset( + cls, dir_path, part, max_sequence_length, cache_path, use_cached + ): logger.info( - 'Cache is forcefully ignored.' + "Cache is forcefully ignored." if not use_cached - else 'No cached dataset has been found.' - ) - dataset_path = os.path.join(dir_path, '{}.txt'.format(part)) - logger.info( - 'Creating a dataset from {}...'.format(dataset_path) + else "No cached dataset has been found." ) + dataset_path = os.path.join(dir_path, "{}.txt".format(part)) + logger.info("Creating a dataset from {}...".format(dataset_path)) - with open(dataset_path, 'r') as f: + with open(dataset_path, "r") as f: data = f.readlines() sequence_info = cls._create_sequences(data, max_sequence_length) @@ -246,22 +245,22 @@ def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_pat for user_id, item_ids in zip(user_sequences, item_sequences): dataset.append( { - 'user.ids': [user_id], - 'user.length': 1, - 'item.ids': item_ids, - 'item.length': len(item_ids), + "user.ids": [user_id], + "user.length": 1, + "item.ids": item_ids, + "item.length": len(item_ids), }, ) - logger.info('{} dataset size: {}'.format(part, len(dataset))) + logger.info("{} dataset size: {}".format(part, len(dataset))) logger.info( - '{} dataset max sequence length: {}'.format( + "{} dataset max sequence length: {}".format( part, max_sequence_len, ), ) - with open(cache_path, 'wb') as dataset_file: + with open(cache_path, "wb") as dataset_file: pickle.dump( (dataset, max_user_id, max_item_id, max_sequence_len), dataset_file, @@ -270,7 +269,7 @@ def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_pat return dataset, max_user_id, max_item_id, max_sequence_len @staticmethod - def _create_sequences(data, max_sample_len): # TODO + def _create_sequences(data, max_sample_len): # TODO user_sequences = [] item_sequences = [] @@ -279,10 +278,8 @@ def _create_sequences(data, max_sample_len): # TODO max_sequence_length = 0 for sample in data: - sample = sample.strip('\n').split(' ') - item_ids = [int(item_id) for item_id in sample[1:]][ - -max_sample_len: - ] + sample = sample.strip("\n").split(" ") + item_ids = [int(item_id) for item_id in sample[1:]][-max_sample_len:] user_id = int(sample[0]) max_user_id = max(max_user_id, user_id) @@ -307,7 +304,8 @@ def get_samplers(self): self._test_sampler, ) -class GraphDataset(BaseDataset, config_name='graph'): + +class GraphDataset(BaseDataset, config_name="graph"): def __init__( self, dataset, @@ -315,7 +313,7 @@ def __init__( use_train_data_only=True, use_user_graph=False, use_item_graph=False, - neighborhood_size=None + neighborhood_size=None, ): self._dataset = dataset self._graph_dir_path = graph_dir_path @@ -327,11 +325,11 @@ def __init__( self._num_users = dataset.num_users self._num_items = dataset.num_items - train_sampler, validation_sampler, test_sampler = ( - dataset.get_samplers() - ) + train_sampler, validation_sampler, test_sampler = dataset.get_samplers() - interactions_data = self._collect_interactions(train_sampler, validation_sampler, test_sampler) + interactions_data = self._collect_interactions( + train_sampler, validation_sampler, test_sampler + ) train_interactions = interactions_data["train_interactions"] train_user_interactions = interactions_data["train_user_interactions"] train_item_interactions = interactions_data["train_item_interactions"] @@ -343,58 +341,59 @@ def __init__( self._train_item_interactions = np.array(train_item_interactions) self._graph = self._build_or_load_bipartite_graph( - graph_dir_path, - train_user_interactions, - train_item_interactions + graph_dir_path, train_user_interactions, train_item_interactions ) - self._user_graph = ( self._build_or_load_similarity_graph( - 'user', - self._train_user_interactions, - self._train_item_interactions, - train_item_2_users, - train_user_2_items - ) - if self._use_user_graph + "user", + self._train_user_interactions, + self._train_item_interactions, + train_item_2_users, + train_user_2_items, + ) + if self._use_user_graph else None ) self._item_graph = ( self._build_or_load_similarity_graph( - 'item', - self._train_user_interactions, - self._train_item_interactions, - train_item_2_users, - train_user_2_items - ) - if self._use_item_graph + "item", + self._train_user_interactions, + self._train_item_interactions, + train_item_2_users, + train_user_2_items, + ) + if self._use_item_graph else None ) def _build_or_load_similarity_graph( - self, - entity_type, - train_user_interactions, - train_item_interactions, - train_item_2_users, - train_user_2_items + self, + entity_type, + train_user_interactions, + train_item_interactions, + train_item_2_users, + train_user_2_items, ): - if entity_type not in ['user', 'item']: + if entity_type not in ["user", "item"]: raise ValueError("entity_type must be either 'user' or 'item'") # neighborhood_size - # The neighborhood_size is a filter that constrains the number of edges for each user or + # The neighborhood_size is a filter that constrains the number of edges for each user or # item node in the graph. - # k=50 implies that for each user, we find all possible neighbors, sort them based on + # k=50 implies that for each user, we find all possible neighbors, sort them based on # co-occurrence counts, and keep only the top 50. All other connections are removed from the graph. - k_suffix = f"k{self._neighborhood_size}" if self._neighborhood_size is not None else "full" + k_suffix = ( + f"k{self._neighborhood_size}" + if self._neighborhood_size is not None + else "full" + ) train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" filename = f"{entity_type}_graph_{k_suffix}_{train_suffix}.npz" path_to_graph = os.path.join(self._graph_dir_path, filename) - is_user_graph = (entity_type == 'user') + is_user_graph = entity_type == "user" num_entities = self._num_users if is_user_graph else self._num_items if os.path.exists(path_to_graph): @@ -405,21 +404,25 @@ def _build_or_load_similarity_graph( visited_user_item_pairs = set() # have to delete cause # 3.2 Graph Construction - # User-user/item-item graph + # User-user/item-item graph # ..the weight of each edge denotes the number of co-action behaviors between user i and user j # visited_entity_pairs = set() for user_id, item_id in tqdm( zip(train_user_interactions, train_item_interactions), - desc='Building {}-{} graph'.format(entity_type, entity_type) # TODO need? + desc="Building {}-{} graph".format( + entity_type, entity_type + ), # TODO need? ): if (user_id, item_id) in visited_user_item_pairs: continue - visited_user_item_pairs.add((user_id, item_id)) + visited_user_item_pairs.add((user_id, item_id)) source_entity = user_id if is_user_graph else item_id - connection_map = train_item_2_users if is_user_graph else train_user_2_items + connection_map = ( + train_item_2_users if is_user_graph else train_user_2_items + ) connection_point = item_id if is_user_graph else user_id for connected_entity in connection_map[connection_point]: @@ -430,29 +433,25 @@ def _build_or_load_similarity_graph( interactions_snd.append(connected_entity) connections = csr_matrix( - (np.ones(len(interactions_fst)), - ( - interactions_fst, - interactions_snd - ) - ), - shape=(num_entities + 2, num_entities + 2) + (np.ones(len(interactions_fst)), (interactions_fst, interactions_snd)), + shape=(num_entities + 2, num_entities + 2), ) if self._neighborhood_size is not None: - connections = self._filter_matrix_by_top_k(connections, self._neighborhood_size) + connections = self._filter_matrix_by_top_k( + connections, self._neighborhood_size + ) graph_matrix = self.get_sparse_graph_layer( - connections, - num_entities + 2, - num_entities + 2, - biparite=False + connections, num_entities + 2, num_entities + 2, biparite=False ) sp.save_npz(path_to_graph, graph_matrix) return self._convert_sp_mat_to_sp_tensor(graph_matrix).coalesce().to(DEVICE) - def _build_or_load_bipartite_graph(self, graph_dir_path, train_user_interactions, train_item_interactions): + def _build_or_load_bipartite_graph( + self, graph_dir_path, train_user_interactions, train_item_interactions + ): train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" filename = f"general_graph_{train_suffix}.npz" path_to_graph = os.path.join(graph_dir_path, filename) @@ -492,8 +491,8 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler) for sampler in samplers_to_process: for sample in sampler.dataset: - user_id = sample['user.ids'][0] - for item_id in sample['item.ids']: + user_id = sample["user.ids"][0] + for item_id in sample["item.ids"]: if (user_id, item_id) not in visited_user_item_pairs: train_interactions.append((user_id, item_id)) train_user_interactions.append(user_id) @@ -503,7 +502,7 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler) train_item_2_users[item_id].add(user_id) visited_user_item_pairs.add((user_id, item_id)) - + return { "train_interactions": train_interactions, "train_user_interactions": train_user_interactions, @@ -514,13 +513,13 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler) @classmethod def create_from_config(cls, config): - dataset = BaseDataset.create_from_config(config['dataset']) + dataset = BaseDataset.create_from_config(config["dataset"]) return cls( dataset=dataset, - graph_dir_path=config['graph_dir_path'], - use_user_graph=config.get('use_user_graph', False), - use_item_graph=config.get('use_item_graph', False), - neighborhood_size=config.get('neighborhood_size', None), + graph_dir_path=config["graph_dir_path"], + use_user_graph=config.get("use_user_graph", False), + use_item_graph=config.get("use_item_graph", False), + neighborhood_size=config.get("neighborhood_size", None), ) @staticmethod @@ -534,26 +533,24 @@ def get_sparse_graph_layer( adj_mat = sparse_matrix.tocsr() else: R = sparse_matrix.tocsr() - + upper_right = R lower_left = R.T - + upper_left = sp.csr_matrix((fst_dim, fst_dim)) lower_right = sp.csr_matrix((snd_dim, snd_dim)) - - adj_mat = sp.bmat([ - [upper_left, upper_right], - [lower_left, lower_right] - ]) - assert adj_mat.shape == (fst_dim + snd_dim, fst_dim + snd_dim), ( - f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" - ) - + + adj_mat = sp.bmat([[upper_left, upper_right], [lower_left, lower_right]]) + assert adj_mat.shape == ( + fst_dim + snd_dim, + fst_dim + snd_dim, + ), f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" + rowsum = np.array(adj_mat.sum(1)) d_inv = np.power(rowsum, -0.5).flatten() - d_inv[np.isinf(d_inv)] = 0. + d_inv[np.isinf(d_inv)] = 0.0 d_mat_inv = sp.diags(d_inv) - + norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) return norm_adj.tocsr() @@ -574,7 +571,7 @@ def _filter_matrix_by_top_k(matrix, k): if len(mat.rows[i]) <= k: continue data = np.array(mat.data[i]) - + top_k_indices = np.argpartition(data, -k)[-k:] mat.data[i] = [mat.data[i][j] for j in top_k_indices] mat.rows[i] = [mat.rows[i][j] for j in top_k_indices] @@ -587,50 +584,52 @@ def get_samplers(self): @property def meta(self): meta = { - 'user_graph': self._user_graph, - 'item_graph': self._item_graph, - 'graph': self._graph, + "user_graph": self._user_graph, + "item_graph": self._item_graph, + "graph": self._graph, **self._dataset.meta, } return meta -class ScientificDataset(BaseSequenceDataset, config_name='scientific'): +class ScientificDataset(BaseSequenceDataset, config_name="scientific"): @classmethod def create_from_config(cls, config, **kwargs): data_dir_path = os.path.join( - config['path_to_data_dir'], - config['name'], + config["path_to_data_dir"], + config["name"], ) - max_sequence_length = config['max_sequence_length'] + max_sequence_length = config["max_sequence_length"] - dataset_path = os.path.join(data_dir_path, '{}.txt'.format('all_data')) - with open(dataset_path, 'r') as f: + dataset_path = os.path.join(data_dir_path, "{}.txt".format("all_data")) + with open(dataset_path, "r") as f: lines = f.readlines() - datasets, max_user_id, max_item_id = cls._parse_and_split_data(lines, max_sequence_length) + datasets, max_user_id, max_item_id = cls._parse_and_split_data( + lines, max_sequence_length + ) - train_dataset = datasets['train'] - validation_dataset = datasets['validation'] - test_dataset = datasets['test'] + train_dataset = datasets["train"] + validation_dataset = datasets["validation"] + test_dataset = datasets["test"] cls._log_stats( - train_dataset, - test_dataset, - max_user_id, - max_item_id, - max_sequence_length, - config['name'] - ) + train_dataset, + test_dataset, + max_user_id, + max_item_id, + max_sequence_length, + config["name"], + ) train_sampler, validation_sampler, test_sampler = cls._create_samplers( - config['samplers'], - train_dataset, - validation_dataset, - test_dataset, - max_user_id, - max_item_id + config["samplers"], + train_dataset, + validation_dataset, + test_dataset, + max_user_id, + max_item_id, ) return cls( @@ -641,71 +640,99 @@ def create_from_config(cls, config, **kwargs): num_items=max_item_id, max_sequence_length=max_sequence_length, ) - + @staticmethod - def _create_samplers(sampler_config, train_dataset, validation_dataset, test_dataset, num_users, num_items): + def _create_samplers( + sampler_config, + train_dataset, + validation_dataset, + test_dataset, + num_users, + num_items, + ): train_sampler = TrainSampler.create_from_config( - sampler_config, dataset=train_dataset, num_users=num_users, num_items=num_items + sampler_config, + dataset=train_dataset, + num_users=num_users, + num_items=num_items, ) validation_sampler = EvalSampler.create_from_config( - sampler_config, dataset=validation_dataset, num_users=num_users, num_items=num_items + sampler_config, + dataset=validation_dataset, + num_users=num_users, + num_items=num_items, ) test_sampler = EvalSampler.create_from_config( - sampler_config, dataset=test_dataset, num_users=num_users, num_items=num_items + sampler_config, + dataset=test_dataset, + num_users=num_users, + num_items=num_items, ) return train_sampler, validation_sampler, test_sampler - + @staticmethod - def _log_stats(train_dataset, test_dataset, max_user_id, max_item_id, max_len, name): - logger.info('Train dataset size: {}'.format(len(train_dataset))) - logger.info('Test dataset size: {}'.format(len(test_dataset))) - logger.info('Max user id: {}'.format(max_user_id)) - logger.info('Max item id: {}'.format(max_item_id)) - logger.info('Max sequence length: {}'.format(max_len)) - + def _log_stats( + train_dataset, test_dataset, max_user_id, max_item_id, max_len, name + ): + logger.info("Train dataset size: {}".format(len(train_dataset))) + logger.info("Test dataset size: {}".format(len(test_dataset))) + logger.info("Max user id: {}".format(max_user_id)) + logger.info("Max item id: {}".format(max_item_id)) + logger.info("Max sequence length: {}".format(max_len)) + if max_user_id > 0 and max_item_id > 0: - sparsity = (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id - logger.info('{} dataset sparsity: {}'.format(name, sparsity)) + sparsity = ( + (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id + ) + logger.info("{} dataset sparsity: {}".format(name, sparsity)) @staticmethod def _parse_and_split_data(lines, max_sequence_length): - datasets = {'train': [], 'validation': [], 'test': []} + datasets = {"train": [], "validation": [], "test": []} - user_ids, item_sequences, max_user_id, max_item_id, _ = \ + user_ids, item_sequences, max_user_id, max_item_id, _ = ( BaseSequenceDataset._create_sequences(lines) + ) for user_id, item_ids in zip(user_ids, item_sequences): - + assert len(item_ids) >= 5 split_slices = { - 'train': slice(None, -2), - 'validation': slice(None, -1), - 'test': slice(None, None) + "train": slice(None, -2), + "validation": slice(None, -1), + "test": slice(None, None), } - + for part_name, part_slice in split_slices.items(): sliced_items = item_ids[part_slice] final_items = sliced_items[-max_sequence_length:] - - assert len(item_ids[-max_sequence_length:]) == len(set(item_ids[-max_sequence_length:]),) - datasets[part_name].append({ - 'user.ids': [user_id], 'user.length': 1, - 'item.ids': final_items, 'item.length': len(final_items), - }) + assert len(item_ids[-max_sequence_length:]) == len( + set(item_ids[-max_sequence_length:]), + ) + + datasets[part_name].append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": final_items, + "item.length": len(final_items), + } + ) return datasets, max_user_id, max_item_id -class MCLSRDataset(BaseSequenceDataset, config_name='mclsr'): + +class MCLSRDataset(BaseSequenceDataset, config_name="mclsr"): @staticmethod def _create_sequences_from_file(filepath, max_len=None): sequences = {} max_user, max_item = 0, 0 - - with open(filepath, 'r') as f: + + with open(filepath, "r") as f: for line in f: - parts = line.strip().split(' ') + parts = line.strip().split(" ") user_id = int(parts[0]) item_ids = [int(i) for i in parts[1:]] if max_len: @@ -715,40 +742,98 @@ def _create_sequences_from_file(filepath, max_len=None): if item_ids: max_item = max(max_item, max(item_ids)) return sequences, max_user, max_item - + @classmethod def _create_evaluation_sets(cls, data_dir, max_seq_len): - valid_hist, u2, i2 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_history.txt'), max_seq_len) - valid_trg, u3, i3 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_target.txt')) + valid_hist, u2, i2 = cls._create_sequences_from_file( + os.path.join(data_dir, "valid_history.txt"), max_seq_len + ) + valid_trg, u3, i3 = cls._create_sequences_from_file( + os.path.join(data_dir, "valid_target.txt") + ) - validation_dataset = [{'user.ids': [uid], 'history': valid_hist[uid], 'target': valid_trg[uid]} for uid in valid_hist if uid in valid_trg] - - test_hist, u4, i4 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_history.txt'), max_seq_len) - test_trg, u5, i5 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_target.txt')) + validation_dataset = [ + {"user.ids": [uid], "history": valid_hist[uid], "target": valid_trg[uid]} + for uid in valid_hist + if uid in valid_trg + ] - test_dataset = [{'user.ids': [uid], 'history': test_hist[uid], 'target': test_trg[uid]} for uid in test_hist if uid in test_trg] + test_hist, u4, i4 = cls._create_sequences_from_file( + os.path.join(data_dir, "test_history.txt"), max_seq_len + ) + test_trg, u5, i5 = cls._create_sequences_from_file( + os.path.join(data_dir, "test_target.txt") + ) - return validation_dataset, test_dataset, max(u2, u3, u4, u5), max(i2, i3, i4, i5) + test_dataset = [ + {"user.ids": [uid], "history": test_hist[uid], "target": test_trg[uid]} + for uid in test_hist + if uid in test_trg + ] + + return ( + validation_dataset, + test_dataset, + max(u2, u3, u4, u5), + max(i2, i3, i4, i5), + ) @classmethod def create_from_config(cls, config, **kwargs): - data_dir = os.path.join(config['path_to_data_dir'], config['name']) - max_seq_len = config.get('max_sequence_length') + data_dir = os.path.join(config["path_to_data_dir"], config["name"]) + max_seq_len = config.get("max_sequence_length") - train_sequences, u1, i1 = cls._create_sequences_from_file(os.path.join(data_dir, 'train_mclsr.txt'), max_seq_len) - train_dataset = [{'user.ids': [uid], 'user.length': 1, 'item.ids': seq, 'item.length': len(seq)} for uid, seq in train_sequences.items()] + train_sequences, u1, i1 = cls._create_sequences_from_file( + os.path.join(data_dir, "train_mclsr.txt"), max_seq_len + ) + train_dataset = [ + { + "user.ids": [uid], + "user.length": 1, + "item.ids": seq, + "item.length": len(seq), + } + for uid, seq in train_sequences.items() + ] user_to_all_seen_items = defaultdict(set) - for sample in train_dataset: user_to_all_seen_items[sample['user.ids'][0]].update(sample['item.ids']) - kwargs['user_to_all_seen_items'] = user_to_all_seen_items + for sample in train_dataset: + user_to_all_seen_items[sample["user.ids"][0]].update(sample["item.ids"]) + kwargs["user_to_all_seen_items"] = user_to_all_seen_items - validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets(data_dir, max_seq_len) + validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets( + data_dir, max_seq_len + ) num_users = max(u1, u_eval) num_items = max(i1, i_eval) - - train_sampler = TrainSampler.create_from_config(config['samplers'], dataset=train_dataset, num_users=num_users, num_items=num_items, **kwargs) - validation_sampler = EvalSampler.create_from_config(config['samplers'], dataset=validation_dataset, num_users=num_users, num_items=num_items, **kwargs) - test_sampler = EvalSampler.create_from_config(config['samplers'], dataset=test_dataset, num_users=num_users, num_items=num_items, **kwargs) - return cls(train_sampler, validation_sampler, test_sampler, num_users, num_items, max_seq_len) + train_sampler = TrainSampler.create_from_config( + config["samplers"], + dataset=train_dataset, + num_users=num_users, + num_items=num_items, + **kwargs, + ) + validation_sampler = EvalSampler.create_from_config( + config["samplers"], + dataset=validation_dataset, + num_users=num_users, + num_items=num_items, + **kwargs, + ) + test_sampler = EvalSampler.create_from_config( + config["samplers"], + dataset=test_dataset, + num_users=num_users, + num_items=num_items, + **kwargs, + ) + return cls( + train_sampler, + validation_sampler, + test_sampler, + num_users, + num_items, + max_seq_len, + ) diff --git a/src/irec/dataset/negative_samplers/__init__.py b/src/irec/dataset/negative_samplers/__init__.py index 04b82f9..654f7e1 100644 --- a/src/irec/dataset/negative_samplers/__init__.py +++ b/src/irec/dataset/negative_samplers/__init__.py @@ -3,7 +3,7 @@ from .random import RandomNegativeSampler __all__ = [ - 'BaseNegativeSampler', - 'PopularNegativeSampler', - 'RandomNegativeSampler', + "BaseNegativeSampler", + "PopularNegativeSampler", + "RandomNegativeSampler", ] diff --git a/src/irec/dataset/negative_samplers/base.py b/src/irec/dataset/negative_samplers/base.py index b4d1224..3360179 100644 --- a/src/irec/dataset/negative_samplers/base.py +++ b/src/irec/dataset/negative_samplers/base.py @@ -11,8 +11,8 @@ def __init__(self, dataset, num_users, num_items): self._seen_items = defaultdict(set) for sample in self._dataset: - user_id = sample['user.ids'][0] - items = list(sample['item.ids']) + user_id = sample["user.ids"][0] + items = list(sample["item.ids"]) self._seen_items[user_id].update(items) def generate_negative_samples(self, sample, num_negatives): diff --git a/src/irec/dataset/negative_samplers/popular.py b/src/irec/dataset/negative_samplers/popular.py index 9859b1a..f04822e 100644 --- a/src/irec/dataset/negative_samplers/popular.py +++ b/src/irec/dataset/negative_samplers/popular.py @@ -3,7 +3,7 @@ from collections import Counter -class PopularNegativeSampler(BaseNegativeSampler, config_name='popular'): +class PopularNegativeSampler(BaseNegativeSampler, config_name="popular"): def __init__(self, dataset, num_users, num_items): super().__init__( dataset=dataset, @@ -15,9 +15,9 @@ def __init__(self, dataset, num_users, num_items): @classmethod def create_from_config(cls, _, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) def _calculate_item_probabilities(self): @@ -27,44 +27,44 @@ def _calculate_item_probabilities(self): """ counts = Counter() for sample in self._dataset: - for item_id in sample['item.ids']: + for item_id in sample["item.ids"]: counts[item_id] += 1 - + items = np.array(list(counts.keys())) freqs = np.array(list(counts.values()), dtype=np.float32) probabilities = freqs / freqs.sum() - + return items, probabilities def generate_negative_samples(self, sample, num_negatives): """ Stochastic sampling proportional to popularity. - + Justification: The original implementation always picked the same Top-K popular items. - For LogQ correction (Yi et al., Google 2019), we need a stochastic - sampling process where p_j > 0 for all items in the distribution. - This allows the model to see a diverse set of negatives across epochs + For LogQ correction (Yi et al., Google 2019), we need a stochastic + sampling process where p_j > 0 for all items in the distribution. + This allows the model to see a diverse set of negatives across epochs while penalizing popular items correctly via the log(p_j) term. """ - user_id = sample['user.ids'][0] + user_id = sample["user.ids"][0] seen = self._seen_items[user_id] - + negatives = set() while len(negatives) < num_negatives: # Sample items based on the pre-calculated frequency distribution sampled_ids = np.random.choice( - self._item_ids, - size=num_negatives - len(negatives), + self._item_ids, + size=num_negatives - len(negatives), p=self._probs, - replace=True + replace=True, ) - + # Filter out items already seen by the user (False Negatives) for idx in sampled_ids: if idx not in seen: negatives.add(idx) if len(negatives) == num_negatives: break - + return list(negatives) diff --git a/src/irec/dataset/negative_samplers/random.py b/src/irec/dataset/negative_samplers/random.py index 8c2a301..d7365a8 100644 --- a/src/irec/dataset/negative_samplers/random.py +++ b/src/irec/dataset/negative_samplers/random.py @@ -3,36 +3,36 @@ import numpy as np -class RandomNegativeSampler(BaseNegativeSampler, config_name='random'): +class RandomNegativeSampler(BaseNegativeSampler, config_name="random"): @classmethod def create_from_config(cls, _, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) def generate_negative_samples(self, sample, num_negatives): """ Optimized via Rejection Sampling (O(k) complexity). - + Mathematical Proof of Equivalence: - Let V be the set of all items and H be the user's history. + Let V be the set of all items and H be the user's history. We need a uniform random sample S ⊂ (V \ H) such that |S| = k. - - 1. Shuffle Approach (Previous): Generates a random permutation of V, + + 1. Shuffle Approach (Previous): Generates a random permutation of V, then filters H. Complexity: O(|V|). - 2. Rejection Sampling (Current): Independently draws i ~ Uniform(V) - and accepts i if i ∉ H and i ∉ S. Complexity: O(k * 1/p), + 2. Rejection Sampling (Current): Independently draws i ~ Uniform(V) + and accepts i if i ∉ H and i ∉ S. Complexity: O(k * 1/p), where p = (|V| - |H|) / |V|. - - Since |H| << |V|, the probability p ≈ 1, making the expected complexity - effectively O(k). Both methods yield an identical uniform distribution + + Since |H| << |V|, the probability p ≈ 1, making the expected complexity + effectively O(k). Both methods yield an identical uniform distribution over the valid item space. """ - user_id = sample['user.ids'][0] + user_id = sample["user.ids"][0] seen = self._seen_items[user_id] - + negatives = set() while len(negatives) < num_negatives: # Drawing a random index is O(1) diff --git a/src/irec/dataset/samplers/base.py b/src/irec/dataset/samplers/base.py index ad6f3bf..7017b71 100644 --- a/src/irec/dataset/samplers/base.py +++ b/src/irec/dataset/samplers/base.py @@ -35,14 +35,14 @@ def __len__(self): def __getitem__(self, index): sample = self._dataset[index] - item_sequence = sample['item.ids'][:-1] - next_item = sample['item.ids'][-1] + item_sequence = sample["item.ids"][:-1] + next_item = sample["item.ids"][-1] return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [next_item], - 'labels.length': 1, + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "labels.ids": [next_item], + "labels.length": 1, } diff --git a/src/irec/dataset/samplers/mclsr.py b/src/irec/dataset/samplers/mclsr.py index bf0fef6..3a9276a 100644 --- a/src/irec/dataset/samplers/mclsr.py +++ b/src/irec/dataset/samplers/mclsr.py @@ -4,8 +4,16 @@ import random -class MCLSRTrainSampler(TrainSampler, config_name='mclsr'): - def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_negatives, **kwargs): +class MCLSRTrainSampler(TrainSampler, config_name="mclsr"): + def __init__( + self, + dataset, + num_users, + num_items, + user_to_all_seen_items, + num_negatives, + **kwargs, + ): super().__init__() self._dataset = dataset self._num_users = num_users @@ -16,27 +24,26 @@ def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_ne @classmethod def create_from_config(cls, config, **kwargs): - num_negatives = config['num_negatives_train'] + num_negatives = config["num_negatives_train"] print(num_negatives) return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], num_negatives=num_negatives, - user_to_all_seen_items=kwargs['user_to_all_seen_items'], + user_to_all_seen_items=kwargs["user_to_all_seen_items"], ) - def __getitem__(self, index): sample = self._dataset[index] - user_id = sample['user.ids'][0] - item_sequence = sample['item.ids'][:-1] - positive_item = sample['item.ids'][-1] + user_id = sample["user.ids"][0] + item_sequence = sample["item.ids"][:-1] + positive_item = sample["item.ids"][-1] user_seen = self._user_to_all_seen_items[user_id] - # unseen_items = list(self._all_items_set - user_seen) + # unseen_items = list(self._all_items_set - user_seen) # negatives = random.sample(unseen_items, self._num_negatives) # --- OPTIMIZATION: Rejection Sampling --- @@ -48,51 +55,50 @@ def __getitem__(self, index): while len(negatives) < self._num_negatives: # Draw a random item index from the range [1, num_items] candidate = random.randint(1, self._num_items) - + # Rejection step: Only accept if the user has never interacted with it. # This ensures we only sample from the "unseen" pool. if candidate not in user_seen: negatives.add(candidate) - + # Convert back to list to match the expected format for BatchProcessor negatives = list(negatives) # ---------------------------------------- - return { - 'user.ids': [user_id], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [positive_item], - 'labels.length': 1, - 'negatives.ids': negatives, - 'negatives.length': len(negatives), + "user.ids": [user_id], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "labels.ids": [positive_item], + "labels.length": 1, + "negatives.ids": negatives, + "negatives.length": len(negatives), } -class MCLSRPredictionEvalSampler(EvalSampler, config_name='mclsr'): +class MCLSRPredictionEvalSampler(EvalSampler, config_name="mclsr"): def __init__(self, dataset, num_users, num_items): super().__init__(dataset, num_users, num_items) @classmethod def create_from_config(cls, config, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) - + def __getitem__(self, index): sample = self._dataset[index] - history_sequence = sample['history'] - target_items = sample['target'] + history_sequence = sample["history"] + target_items = sample["target"] return { - 'user.ids': sample['user.ids'], - 'user.length': 1, - 'item.ids': history_sequence, - 'item.length': len(history_sequence), - 'labels.ids': target_items, - 'labels.length': len(target_items), + "user.ids": sample["user.ids"], + "user.length": 1, + "item.ids": history_sequence, + "item.length": len(history_sequence), + "labels.ids": target_items, + "labels.length": len(target_items), } diff --git a/src/irec/dataset/samplers/next_item_prediction.py b/src/irec/dataset/samplers/next_item_prediction.py index 47e69a8..f757b4b 100644 --- a/src/irec/dataset/samplers/next_item_prediction.py +++ b/src/irec/dataset/samplers/next_item_prediction.py @@ -6,7 +6,7 @@ class NextItemPredictionTrainSampler( TrainSampler, - config_name='next_item_prediction', + config_name="next_item_prediction", ): def __init__( self, @@ -26,61 +26,59 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): negative_sampler = BaseNegativeSampler.create_from_config( - {'type': config['negative_sampler_type']}, + {"type": config["negative_sampler_type"]}, **kwargs, ) return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], negative_sampler=negative_sampler, - num_negatives=config.get('num_negatives_train', 0), + num_negatives=config.get("num_negatives_train", 0), ) def __getitem__(self, index): sample = self._dataset[index] - item_sequence = sample['item.ids'][:-1] - next_item_sequence = sample['item.ids'][1:] + item_sequence = sample["item.ids"][:-1] + next_item_sequence = sample["item.ids"][1:] if self._num_negatives == 0: return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'positive.ids': next_item_sequence, - 'positive.length': len(next_item_sequence), + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "positive.ids": next_item_sequence, + "positive.length": len(next_item_sequence), } else: - negative_sequence = ( - self._negative_sampler.generate_negative_samples( - sample, - self._num_negatives, - ) + negative_sequence = self._negative_sampler.generate_negative_samples( + sample, + self._num_negatives, ) return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'positive.ids': next_item_sequence, - 'positive.length': len(next_item_sequence), - 'negative.ids': negative_sequence, - 'negative.length': len(negative_sequence), + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "positive.ids": next_item_sequence, + "positive.length": len(next_item_sequence), + "negative.ids": negative_sequence, + "negative.length": len(negative_sequence), } class NextItemPredictionEvalSampler( EvalSampler, - config_name='next_item_prediction', + config_name="next_item_prediction", ): @classmethod def create_from_config(cls, config, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) diff --git a/src/irec/dataset/sasrec.py b/src/irec/dataset/sasrec.py index 0cba986..0c1dd97 100644 --- a/src/irec/dataset/sasrec.py +++ b/src/irec/dataset/sasrec.py @@ -3,35 +3,41 @@ from .base import BaseSequenceDataset, SequenceDataset, MCLSRDataset from .samplers import TrainSampler, EvalSampler -class SASRecDataset(BaseSequenceDataset, config_name='sasrec_comparison'): + +class SASRecDataset(BaseSequenceDataset, config_name="sasrec_comparison"): @classmethod def create_from_config(cls, config, **kwargs): - data_dir = os.path.join(config['path_to_data_dir'], config['name']) - max_seq_len = config.get('max_sequence_length') + data_dir = os.path.join(config["path_to_data_dir"], config["name"]) + max_seq_len = config.get("max_sequence_length") train_dataset, u1, i1, _ = SequenceDataset._create_dataset( - dir_path=data_dir, - part='train_sasrec', - max_sequence_length=max_seq_len + dir_path=data_dir, part="train_sasrec", max_sequence_length=max_seq_len ) - validation_dataset, test_dataset, u_eval, i_eval = \ + validation_dataset, test_dataset, u_eval, i_eval = ( MCLSRDataset._create_evaluation_sets(data_dir, max_seq_len) + ) num_users = max(u1, u_eval) num_items = max(i1, i_eval) - + train_sampler = TrainSampler.create_from_config( - config['train_sampler'], - dataset=train_dataset, num_users=num_users, num_items=num_items + config["train_sampler"], + dataset=train_dataset, + num_users=num_users, + num_items=num_items, ) validation_sampler = EvalSampler.create_from_config( - config['eval_sampler'], - dataset=validation_dataset, num_users=num_users, num_items=num_items + config["eval_sampler"], + dataset=validation_dataset, + num_users=num_users, + num_items=num_items, ) test_sampler = EvalSampler.create_from_config( - config['eval_sampler'], - dataset=test_dataset, num_users=num_users, num_items=num_items + config["eval_sampler"], + dataset=test_dataset, + num_users=num_users, + num_items=num_items, ) return cls( @@ -40,5 +46,5 @@ def create_from_config(cls, config, **kwargs): test_sampler=test_sampler, num_users=num_users, num_items=num_items, - max_sequence_length=max_seq_len + max_sequence_length=max_seq_len, ) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index e082b4d..f8a0762 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -19,12 +19,12 @@ class TorchLoss(BaseLoss, nn.Module): pass -class IdentityLoss(BaseLoss, config_name='identity'): +class IdentityLoss(BaseLoss, config_name="identity"): def __call__(self, inputs): return inputs -class CompositeLoss(TorchLoss, config_name='composite'): +class CompositeLoss(TorchLoss, config_name="composite"): def __init__(self, losses, weights=None, output_prefix=None): super().__init__() self._losses = losses @@ -36,8 +36,8 @@ def create_from_config(cls, config, **kwargs): losses = [] weights = [] - for loss_cfg in copy.deepcopy(config)['losses']: - weight = loss_cfg.pop('weight') if 'weight' in loss_cfg else 1.0 + for loss_cfg in copy.deepcopy(config)["losses"]: + weight = loss_cfg.pop("weight") if "weight" in loss_cfg else 1.0 loss_function = BaseLoss.create_from_config(loss_cfg) weights.append(weight) @@ -46,7 +46,7 @@ def create_from_config(cls, config, **kwargs): return cls( losses=losses, weights=weights, - output_prefix=config.get('output_prefix'), + output_prefix=config.get("output_prefix"), ) def forward(self, inputs): @@ -60,7 +60,7 @@ def forward(self, inputs): return total_loss -class FpsLoss(TorchLoss, config_name='fps'): +class FpsLoss(TorchLoss, config_name="fps"): def __init__( self, fst_embeddings_prefix, @@ -75,7 +75,7 @@ def __init__( self._snd_embeddings_prefix = snd_embeddings_prefix self._tau = tau self._loss_function = nn.CrossEntropyLoss( - reduction='mean' if use_mean else 'sum', + reduction="mean" if use_mean else "sum", ) self._normalize_embeddings = normalize_embeddings self._output_prefix = output_prefix @@ -84,21 +84,17 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - fst_embeddings_prefix=config['fst_embeddings_prefix'], - snd_embeddings_prefix=config['snd_embeddings_prefix'], - tau=config.get('temperature', 1.0), - normalize_embeddings=config.get('normalize_embeddings', False), - use_mean=config.get('use_mean', True), - output_prefix=config.get('output_prefix') + fst_embeddings_prefix=config["fst_embeddings_prefix"], + snd_embeddings_prefix=config["snd_embeddings_prefix"], + tau=config.get("temperature", 1.0), + normalize_embeddings=config.get("normalize_embeddings", False), + use_mean=config.get("use_mean", True), + output_prefix=config.get("output_prefix"), ) def forward(self, inputs): - fst_embeddings = inputs[ - self._fst_embeddings_prefix - ] # (x, embedding_dim) - snd_embeddings = inputs[ - self._snd_embeddings_prefix - ] # (x, embedding_dim) + fst_embeddings = inputs[self._fst_embeddings_prefix] # (x, embedding_dim) + snd_embeddings = inputs[self._snd_embeddings_prefix] # (x, embedding_dim) batch_size = fst_embeddings.shape[0] @@ -125,7 +121,9 @@ def forward(self, inputs): torch.diag(similarity_scores, -batch_size), ), dim=0, - ).reshape(2 * batch_size, 1) # (2 * x, 1) + ).reshape( + 2 * batch_size, 1 + ) # (2 * x, 1) assert torch.allclose( torch.diag(similarity_scores, batch_size), torch.diag(similarity_scores, -batch_size), @@ -162,14 +160,9 @@ def forward(self, inputs): return loss -class SASRecLoss(TorchLoss, config_name='sasrec'): +class SASRecLoss(TorchLoss, config_name="sasrec"): - def __init__( - self, - positive_prefix, - negative_prefix, - output_prefix=None - ): + def __init__(self, positive_prefix, negative_prefix, output_prefix=None): super().__init__() self._positive_prefix = positive_prefix self._negative_prefix = negative_prefix @@ -192,7 +185,7 @@ def forward(self, inputs): return loss -class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'): +class SamplesSoftmaxLoss(TorchLoss, config_name="sampled_softmax"): def __init__( self, queries_prefix, @@ -207,9 +200,7 @@ def __init__( self._output_prefix = output_prefix def forward(self, inputs): - queries_embeddings = inputs[ - self._queries_prefix - ] # (batch_size, embedding_dim) + queries_embeddings = inputs[self._queries_prefix] # (batch_size, embedding_dim) positive_embeddings = inputs[ self._positive_prefix ] # (batch_size, embedding_dim) @@ -219,15 +210,17 @@ def forward(self, inputs): # b -- batch_size, d -- embedding_dim positive_scores = torch.einsum( - 'bd,bd->b', + "bd,bd->b", queries_embeddings, positive_embeddings, - ).unsqueeze(-1) # (batch_size, 1) + ).unsqueeze( + -1 + ) # (batch_size, 1) if negative_embeddings.dim() == 2: # (num_negatives, embedding_dim) # b -- batch_size, n -- num_negatives, d -- embedding_dim negative_scores = torch.einsum( - 'bd,nd->bn', + "bd,nd->bn", queries_embeddings, negative_embeddings, ) # (batch_size, num_negatives) @@ -237,7 +230,7 @@ def forward(self, inputs): ) # (batch_size, num_negatives, embedding_dim) # b -- batch_size, n -- num_negatives, d -- embedding_dim negative_scores = torch.einsum( - 'bd,bnd->bn', + "bd,bnd->bn", queries_embeddings, negative_embeddings, ) # (batch_size, num_negatives) @@ -259,7 +252,7 @@ def forward(self, inputs): return loss -class MCLSRLoss(TorchLoss, config_name='mclsr'): +class MCLSRLoss(TorchLoss, config_name="mclsr"): def __init__( self, all_scores_prefix, @@ -306,5 +299,3 @@ def forward(self, inputs): inputs[self._output_prefix] = loss.cpu().item() return loss - - diff --git a/src/irec/metric/base.py b/src/irec/metric/base.py index da68965..87c0190 100644 --- a/src/irec/metric/base.py +++ b/src/irec/metric/base.py @@ -12,7 +12,7 @@ def reduce(self): raise NotImplementedError -class StaticMetric(BaseMetric, config_name='dummy'): +class StaticMetric(BaseMetric, config_name="dummy"): def __init__(self, name, value): self._name = name self._value = value @@ -23,16 +23,14 @@ def __call__(self, inputs): return inputs -class CompositeMetric(BaseMetric, config_name='composite'): +class CompositeMetric(BaseMetric, config_name="composite"): def __init__(self, metrics): self._metrics = metrics @classmethod def create_from_config(cls, config): return cls( - metrics=[ - BaseMetric.create_from_config(cfg) for cfg in config['metrics'] - ], + metrics=[BaseMetric.create_from_config(cfg) for cfg in config["metrics"]], ) def __call__(self, inputs): @@ -41,7 +39,7 @@ def __call__(self, inputs): return inputs -class NDCGMetric(BaseMetric, config_name='ndcg'): +class NDCGMetric(BaseMetric, config_name="ndcg"): def __init__(self, k): self._k = k @@ -50,9 +48,8 @@ def __call__(self, inputs, pred_prefix, labels_prefix): :, : self._k, ].float() # (batch_size, top_k_indices) - labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size) + labels = inputs["{}.ids".format(labels_prefix)].float() # (batch_size) - assert labels.shape[0] == predictions.shape[0] hits = torch.eq( @@ -61,13 +58,15 @@ def __call__(self, inputs, pred_prefix, labels_prefix): ).float() # (batch_size, top_k_indices) discount_factor = 1 / torch.log2( torch.arange(1, self._k + 1, 1).float() + 1.0, - ).to(hits.device) # (k) - dcg = torch.einsum('bk,k->b', hits, discount_factor) # (batch_size) + ).to( + hits.device + ) # (k) + dcg = torch.einsum("bk,k->b", hits, discount_factor) # (batch_size) return dcg.cpu().tolist() -class RecallMetric(BaseMetric, config_name='recall'): +class RecallMetric(BaseMetric, config_name="recall"): def __init__(self, k): self._k = k @@ -76,7 +75,7 @@ def __call__(self, inputs, pred_prefix, labels_prefix): :, : self._k, ].float() # (batch_size, top_k_indices) - labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size) + labels = inputs["{}.ids".format(labels_prefix)].float() # (batch_size) assert labels.shape[0] == predictions.shape[0] @@ -89,35 +88,34 @@ def __call__(self, inputs, pred_prefix, labels_prefix): return recall.cpu().tolist() -class CoverageMetric(StatefullMetric, config_name='coverage'): +class CoverageMetric(StatefullMetric, config_name="coverage"): def __init__(self, k, num_items): self._k = k self._num_items = num_items @classmethod def create_from_config(cls, config, **kwargs): - return cls(k=config['k'], num_items=kwargs['num_items']) + return cls(k=config["k"], num_items=kwargs["num_items"]) def __call__(self, inputs, pred_prefix, labels_prefix): predictions = inputs[pred_prefix][ :, : self._k, ].float() # (batch_size, top_k_indices) - return ( - predictions.view(-1).long().cpu().detach().tolist() - ) # (batch_size * k) + return predictions.view(-1).long().cpu().detach().tolist() # (batch_size * k) def reduce(self, values): return len(set(values)) / self._num_items -class MCLSRNDCGMetric(BaseMetric, config_name='mclsr-ndcg'): + +class MCLSRNDCGMetric(BaseMetric, config_name="mclsr-ndcg"): def __init__(self, k): self._k = k def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) + predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k) + labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,) + labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,) assert predictions.shape[0] == labels_lengths.shape[0] @@ -129,30 +127,30 @@ def __call__(self, inputs, pred_prefix, labels_prefix): user_labels = labels_flat[offset : offset + num_user_labels] offset += num_user_labels - hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False - + hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False + positions = torch.arange(2, self._k + 2, device=predictions.device) weights = 1 / torch.log2(positions.float()) dcg = (hits_mask.float() * weights).sum() - + num_ideal_hits = min(self._k, num_user_labels) idcg_weights = weights[:num_ideal_hits] idcg = idcg_weights.sum() - + ndcg = dcg / idcg if idcg > 0 else torch.tensor(0.0) dcg_scores.append(ndcg.cpu().item()) - + return dcg_scores -class MCLSRRecallMetric(BaseMetric, config_name='mclsr-recall'): +class MCLSRRecallMetric(BaseMetric, config_name="mclsr-recall"): def __init__(self, k): self._k = k def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) + predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k) + labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,) + labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,) assert predictions.shape[0] == labels_lengths.shape[0] @@ -163,22 +161,25 @@ def __call__(self, inputs, pred_prefix, labels_prefix): num_user_labels = labels_lengths[i] user_labels = labels_flat[offset : offset + num_user_labels] offset += num_user_labels - + hits = torch.isin(user_predictions, user_labels).sum().float() - - recall = hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0) + + recall = ( + hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0) + ) recall_scores.append(recall.cpu().item()) - + return recall_scores -class MCLSRHitRateMetric(BaseMetric, config_name='mclsr-hit'): + +class MCLSRHitRateMetric(BaseMetric, config_name="mclsr-hit"): def __init__(self, k): self._k = k def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) + predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k) + labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,) + labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,) assert predictions.shape[0] == labels_lengths.shape[0] @@ -194,9 +195,9 @@ def __call__(self, inputs, pred_prefix, labels_prefix): user_labels = labels_flat[offset : offset + num_user_labels] offset += num_user_labels - + is_hit = torch.isin(user_predictions, user_labels).any() - + hit_scores.append(float(is_hit)) - - return hit_scores \ No newline at end of file + + return hit_scores diff --git a/src/irec/models/base.py b/src/irec/models/base.py index 2f059f7..22555f4 100644 --- a/src/irec/models/base.py +++ b/src/irec/models/base.py @@ -21,8 +21,8 @@ class TorchModel(nn.Module, BaseModel): @torch.no_grad() def _init_weights(self, initializer_range): for key, value in self.named_parameters(): - if 'weight' in key: - if 'norm' in key: + if "weight" in key: + if "norm" in key: nn.init.ones_(value.data) else: nn.init.trunc_normal_( @@ -31,10 +31,10 @@ def _init_weights(self, initializer_range): a=-2 * initializer_range, b=2 * initializer_range, ) - elif 'bias' in key: + elif "bias" in key: nn.init.zeros_(value.data) else: - raise ValueError(f'Unknown transformer weight: {key}') + raise ValueError(f"Unknown transformer weight: {key}") @staticmethod def _get_last_embedding(embeddings, mask): @@ -54,12 +54,12 @@ def _get_last_embedding(embeddings, mask): ) # (batch_size, 1, emb_dim) last_embeddings = last_embeddings[last_masks] # (batch_size, emb_dim) if not torch.allclose(embeddings[mask][-1], last_embeddings[-1]): - logger.debug(f'Embeddings: {embeddings}') + logger.debug(f"Embeddings: {embeddings}") logger.debug( - f'Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}', + f"Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}", ) - logger.debug(f'Last embedding from mask: {embeddings[mask][-1]}') - logger.debug(f'Last embedding from gather: {last_embeddings[-1]}') + logger.debug(f"Last embedding from mask: {embeddings[mask][-1]}") + logger.debug(f"Last embedding from gather: {last_embeddings[-1]}") assert False return last_embeddings @@ -74,7 +74,7 @@ def __init__( num_layers, dim_feedforward, dropout=0.0, - activation='relu', + activation="relu", layer_norm_eps=1e-5, is_causal=True, ): @@ -85,8 +85,7 @@ def __init__( self._embedding_dim = embedding_dim self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding + num_embeddings=num_items + 2, # add zero embedding + mask embedding embedding_dim=embedding_dim, ) self._position_embeddings = nn.Embedding( @@ -135,9 +134,7 @@ def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): .tile([batch_size, 1]) .long() ) # (batch_size, seq_len) - positions_mask = ( - positions < lengths[:, None] - ) # (batch_size, max_seq_len) + positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) positions = positions[positions_mask] # (all_batch_events) position_embeddings = self._position_embeddings( @@ -215,7 +212,9 @@ def _add_cls_token(items, lengths, cls_token_id=0): dim=0, index=torch.cat( [torch.LongTensor([0]).to(DEVICE), lengths + 1], - ).cumsum(dim=0)[:-1], + ).cumsum( + dim=0 + )[:-1], ) # (num_new_items) new_items[old_items_mask] = items new_length = lengths + 1 diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index 1708747..bf20e32 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -6,7 +6,7 @@ from irec.utils import create_masked_tensor -class MCLSRModel(TorchModel, config_name='mclsr'): +class MCLSRModel(TorchModel, config_name="mclsr"): def __init__( self, sequence_prefix, @@ -50,8 +50,7 @@ def __init__( self._item_graph = item_graph self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding + num_embeddings=num_items + 2, # add zero embedding + mask embedding embedding_dim=embedding_dim, ) self._position_embeddings = nn.Embedding( @@ -61,8 +60,7 @@ def __init__( ) self._user_embeddings = nn.Embedding( - num_embeddings=num_users - + 2, # add zero embedding + mask embedding + num_embeddings=num_users + 2, # add zero embedding + mask embedding embedding_dim=embedding_dim, ) @@ -155,23 +153,23 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - sequence_prefix=config['sequence_prefix'], - user_prefix=config['user_prefix'], - labels_prefix=config['labels_prefix'], - negatives_prefix=config.get('negatives_prefix', 'negatives'), - candidate_prefix=config['candidate_prefix'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_graph_layers=config['num_graph_layers'], - common_graph=kwargs['graph'], - user_graph=kwargs['user_graph'], - item_graph=kwargs['item_graph'], - dropout=config.get('dropout', 0.0), - layer_norm_eps=config.get('layer_norm_eps', 1e-5), - graph_dropout=config.get('graph_dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), + sequence_prefix=config["sequence_prefix"], + user_prefix=config["user_prefix"], + labels_prefix=config["labels_prefix"], + negatives_prefix=config.get("negatives_prefix", "negatives"), + candidate_prefix=config["candidate_prefix"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_graph_layers=config["num_graph_layers"], + common_graph=kwargs["graph"], + user_graph=kwargs["user_graph"], + item_graph=kwargs["item_graph"], + dropout=config.get("dropout", 0.0), + layer_norm_eps=config.get("layer_norm_eps", 1e-5), + graph_dropout=config.get("graph_dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), ) def _apply_graph_encoder(self, embeddings, graph, use_mean=False): @@ -200,12 +198,12 @@ def _apply_graph_encoder(self, embeddings, graph, use_mean=False): def forward(self, inputs): all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) + "{}.ids".format(self._sequence_prefix) ] # (all_batch_events) all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) + "{}.length".format(self._sequence_prefix) ] # (batch_size) - user_ids = inputs['{}.ids'.format(self._user_prefix)] # (batch_size) + user_ids = inputs["{}.ids".format(self._user_prefix)] # (batch_size) embeddings = self._item_embeddings( all_sample_events, @@ -243,11 +241,11 @@ def forward(self, inputs): lengths=all_sample_lengths, ) # (batch_size, seq_len, embedding_dim) assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - + positioned_embeddings = ( embeddings + position_embeddings ) # (batch_size, seq_len, embedding_dim) - + positioned_embeddings = self._layernorm( positioned_embeddings, ) # (batch_size, seq_len, embedding_dim) @@ -258,7 +256,7 @@ def forward(self, inputs): # formula 2 sequential_attention_matrix = self._current_interest_learning_encoder( - positioned_embeddings, # E_u,p + positioned_embeddings, # E_u,p ).squeeze() # (batch_size, seq_len) sequential_attention_matrix[~mask] = -torch.inf @@ -269,75 +267,90 @@ def forward(self, inputs): # formula 3 sequential_representation = torch.einsum( - 'bs,bsd->bd', - sequential_attention_matrix, # A^s + "bs,bsd->bd", + sequential_attention_matrix, # A^s embeddings, ) # (batch_size, embedding_dim) if self.training: # general interest # formula 4 - all_init_embeddings = torch.cat([self._user_embeddings.weight, - self._item_embeddings.weight], - dim=0) - all_graph_embeddings = self._apply_graph_encoder(embeddings=all_init_embeddings, - graph=self._graph) + all_init_embeddings = torch.cat( + [self._user_embeddings.weight, self._item_embeddings.weight], dim=0 + ) + all_graph_embeddings = self._apply_graph_encoder( + embeddings=all_init_embeddings, graph=self._graph + ) common_graph_user_embs_all, common_graph_item_embs_all = torch.split( all_graph_embeddings, [self._num_users + 2, self._num_items + 2] ) common_graph_user_embs_batch = common_graph_user_embs_all[user_ids] common_graph_item_embs_batch, _ = create_masked_tensor( - data=common_graph_item_embs_all[all_sample_events], - lengths=all_sample_lengths + data=common_graph_item_embs_all[all_sample_events], + lengths=all_sample_lengths, ) # formula 5: A_c = softmax(tanh(W_3 * h_u,uv) * (E_u,uv)^T) - graph_attention_matrix = torch.einsum('bd,bsd->bs', - self._general_interest_learning_encoder - (common_graph_user_embs_batch), - common_graph_item_embs_batch) + graph_attention_matrix = torch.einsum( + "bd,bsd->bs", + self._general_interest_learning_encoder(common_graph_user_embs_batch), + common_graph_item_embs_batch, + ) graph_attention_matrix[~mask] = -torch.inf graph_attention_matrix = torch.softmax(graph_attention_matrix, dim=1) # formula 6: I_c = A_c * E_u,uv - original_graph_representation = torch.einsum('bs,bsd->bd', - graph_attention_matrix, - common_graph_item_embs_batch) + original_graph_representation = torch.einsum( + "bs,bsd->bd", graph_attention_matrix, common_graph_item_embs_batch + ) original_sequential_representation = sequential_representation # formula 13: I_comb = alpha * I_s + (1 - alpha) * I_c # L_P (Downstream Loss) - combined_representation = (self._alpha * original_sequential_representation + - (1 - self._alpha) * original_graph_representation) - labels = inputs['{}.ids'.format(self._labels_prefix)] + combined_representation = ( + self._alpha * original_sequential_representation + + (1 - self._alpha) * original_graph_representation + ) + labels = inputs["{}.ids".format(self._labels_prefix)] labels_embeddings = self._item_embeddings(labels) - + # formula 7 # L_IL (Interest-level CL) - sequential_representation_proj = self._sequential_projector(original_sequential_representation) - graph_representation_proj = self._graph_projector(original_graph_representation) + sequential_representation_proj = self._sequential_projector( + original_sequential_representation + ) + graph_representation_proj = self._graph_projector( + original_graph_representation + ) # formula 9: H_u,uu = GraphEncoder(H_u, G_uu) # L_UC (User-level CL) - user_graph_user_embs_all = self._apply_graph_encoder(embeddings=self._user_embeddings.weight, - graph=self._user_graph) + user_graph_user_embs_all = self._apply_graph_encoder( + embeddings=self._user_embeddings.weight, graph=self._user_graph + ) user_graph_user_embs_batch = user_graph_user_embs_all[user_ids] # formula 10 # T_f,uu = MLP(H_u,uu) и T_f,uv = MLP(H_u,uv) - user_graph_user_embeddings_proj = self._user_projection(user_graph_user_embs_batch) - common_graph_user_embeddings_proj = self._user_projection(common_graph_user_embs_batch) + user_graph_user_embeddings_proj = self._user_projection( + user_graph_user_embs_batch + ) + common_graph_user_embeddings_proj = self._user_projection( + common_graph_user_embs_batch + ) # item level CL common_graph_items_flat = common_graph_item_embs_batch[mask] - - item_graph_items_all = self._apply_graph_encoder(embeddings=self._item_embeddings.weight, - graph=self._item_graph) + + item_graph_items_all = self._apply_graph_encoder( + embeddings=self._item_embeddings.weight, graph=self._item_graph + ) item_graph_items_flat = item_graph_items_all[all_sample_events] - unique_item_ids, inverse_indices = torch.unique(all_sample_events, - return_inverse=True) + unique_item_ids, inverse_indices = torch.unique( + all_sample_events, return_inverse=True + ) try: from torch_scatter import scatter_mean @@ -345,75 +358,86 @@ def forward(self, inputs): # print("Warning: torch_scatter not found. Using a slower fallback function.") def scatter_mean(src, index, dim=0, dim_size=None): out_size = dim_size if dim_size is not None else index.max() + 1 - out = torch.zeros((out_size, src.size(1)), dtype=src.dtype, device=src.device) - counts = torch.bincount(index, minlength=out_size).unsqueeze(-1).clamp(min=1) - return out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts - + out = torch.zeros( + (out_size, src.size(1)), dtype=src.dtype, device=src.device + ) + counts = ( + torch.bincount(index, minlength=out_size) + .unsqueeze(-1) + .clamp(min=1) + ) + return ( + out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) + / counts + ) + num_unique_items = unique_item_ids.shape[0] - unique_common_graph_items = scatter_mean(common_graph_items_flat, - inverse_indices, dim=0, - dim_size=num_unique_items) + unique_common_graph_items = scatter_mean( + common_graph_items_flat, + inverse_indices, + dim=0, + dim_size=num_unique_items, + ) - unique_item_graph_items = scatter_mean(item_graph_items_flat, - inverse_indices, dim=0, - dim_size=num_unique_items) + unique_item_graph_items = scatter_mean( + item_graph_items_flat, inverse_indices, dim=0, dim_size=num_unique_items + ) # projection for Item-level Feature CL - unique_common_graph_items_proj = self._item_projection(unique_common_graph_items) - unique_item_graph_items_proj = self._item_projection(unique_item_graph_items) - + unique_common_graph_items_proj = self._item_projection( + unique_common_graph_items + ) + unique_item_graph_items_proj = self._item_projection( + unique_item_graph_items + ) - raw_negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] + raw_negative_ids = inputs["{}.ids".format(self._negatives_prefix)] num_negatives = raw_negative_ids.shape[0] // batch_size - negative_ids = raw_negative_ids.view(batch_size, num_negatives) # (Batch, NumNegs) - negative_embeddings = self._item_embeddings(negative_ids) # (Batch, NumNegs, Dim) + negative_ids = raw_negative_ids.view( + batch_size, num_negatives + ) # (Batch, NumNegs) + negative_embeddings = self._item_embeddings( + negative_ids + ) # (Batch, NumNegs, Dim) # import code; code.interact(local=locals()) return { # L_P (formula 14) - 'combined_representation': combined_representation, - 'label_representation': labels_embeddings, - - 'negative_representation': negative_embeddings, - - + "combined_representation": combined_representation, + "label_representation": labels_embeddings, + "negative_representation": negative_embeddings, # --- ID PASS-THROUGH FOR LOGQ & MASKING --- # We pass raw item and user indices to enable advanced loss operations: - # 1. False Negative Masking: Allows SamplesSoftmaxLoss to identify and - # neutralize cases where a target item accidentally appears in the + # 1. False Negative Masking: Allows SamplesSoftmaxLoss to identify and + # neutralize cases where a target item accidentally appears in the # negative sampling pool. - # 2. Per-item LogQ Correction: Enables mapping item IDs to global frequency + # 2. Per-item LogQ Correction: Enables mapping item IDs to global frequency # stats (item_counts.pkl) to remove popularity bias (Sampling Bias). - 'positive_ids': labels, # Item target indices - 'negative_ids': negative_ids, # Sampled negative item indices - - # Useful for potential User-level LogQ correction as requested + "positive_ids": labels, # Item target indices + "negative_ids": negative_ids, # Sampled negative item indices + # Useful for potential User-level LogQ correction as requested # by the supervisor to handle highly active users. - 'user_ids': user_ids, - + "user_ids": user_ids, # for L_IL (formula 8) - 'sequential_representation': sequential_representation_proj, - 'graph_representation': graph_representation_proj, - + "sequential_representation": sequential_representation_proj, + "graph_representation": graph_representation_proj, # for L_UC (formula 11) - 'user_graph_user_embeddings': user_graph_user_embeddings_proj, - 'common_graph_user_embeddings': common_graph_user_embeddings_proj, - + "user_graph_user_embeddings": user_graph_user_embeddings_proj, + "common_graph_user_embeddings": common_graph_user_embeddings_proj, # for L_IC - 'item_graph_item_embeddings': unique_item_graph_items_proj, - 'common_graph_item_embeddings': unique_common_graph_items_proj, + "item_graph_item_embeddings": unique_item_graph_items_proj, + "common_graph_item_embeddings": unique_common_graph_items_proj, } else: # eval mode # formula 16: R(u,N) = Top-N((I_s)^T * h_o) - if '{}.ids'.format(self._candidate_prefix) in inputs: + if "{}.ids".format(self._candidate_prefix) in inputs: candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) + "{}.ids".format(self._candidate_prefix) ] # (all_batch_candidates) candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - + "{}.length".format(self._candidate_prefix) ] # (batch_size) candidate_embeddings = self._item_embeddings( @@ -426,23 +450,22 @@ def scatter_mean(src, index, dim=0, dim_size=None): ) # (batch_size, num_candidates, embedding_dim) candidate_scores = torch.einsum( - 'bd,bnd->bn', - sequential_representation, # I_s - candidate_embeddings, # h_o (and h_k) + "bd,bnd->bn", + sequential_representation, # I_s + candidate_embeddings, # h_o (and h_k) ) # (batch_size, num_candidates) else: candidate_embeddings = ( self._item_embeddings.weight ) # (num_items, embedding_dim) candidate_scores = torch.einsum( - 'bd,nd->bn', - sequential_representation, # I_s - candidate_embeddings, # all h_v + "bd,nd->bn", + sequential_representation, # I_s + candidate_embeddings, # all h_v ) # (batch_size, num_items) candidate_scores[:, 0] = -torch.inf candidate_scores[:, self._num_items + 1 :] = -torch.inf - values, indices = torch.topk( candidate_scores, k=50, @@ -450,4 +473,4 @@ def scatter_mean(src, index, dim=0, dim_size=None): largest=True, ) # (batch_size, 100), (batch_size, 100) - return indices \ No newline at end of file + return indices diff --git a/src/irec/models/sasrec.py b/src/irec/models/sasrec.py index 0cefdd5..52384d2 100644 --- a/src/irec/models/sasrec.py +++ b/src/irec/models/sasrec.py @@ -4,22 +4,22 @@ import torch -class SasRecModel(SequentialTorchModel, config_name='sasrec'): +class SasRecModel(SequentialTorchModel, config_name="sasrec"): def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 + self, + sequence_prefix, + positive_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, ): super().__init__( num_items=num_items, @@ -31,7 +31,7 @@ def __init__( dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, - is_causal=True + is_causal=True, ) self._sequence_prefix = sequence_prefix self._positive_prefix = positive_prefix @@ -41,88 +41,100 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), ) def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) embeddings, mask = self._apply_sequential_encoder( all_sample_events, all_sample_lengths ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) if self.training: # training mode - all_positive_sample_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) + all_positive_sample_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) - all_sample_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) + all_sample_embeddings = embeddings[ + mask + ] # (all_batch_events, embedding_dim) all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim) # a -- all_batch_events, n -- num_items, d -- embedding_dim all_scores = torch.einsum( - 'ad,nd->an', - all_sample_embeddings, - all_embeddings + "ad,nd->an", all_sample_embeddings, all_embeddings ) # (all_batch_events, num_items) positive_scores = torch.gather( - input=all_scores, - dim=1, - index=all_positive_sample_events[..., None] - )[:, 0] # (all_batch_items) + input=all_scores, dim=1, index=all_positive_sample_events[..., None] + )[ + :, 0 + ] # (all_batch_items) negative_scores = torch.gather( input=all_scores, dim=1, - index=torch.randint(low=0, high=all_scores.shape[1], size=all_positive_sample_events.shape, device=all_positive_sample_events.device)[..., None] - )[:, 0] # (all_batch_items) + index=torch.randint( + low=0, + high=all_scores.shape[1], + size=all_positive_sample_events.shape, + device=all_positive_sample_events.device, + )[..., None], + )[ + :, 0 + ] # (all_batch_items) return { - 'positive_scores': positive_scores, - 'negative_scores': negative_scores + "positive_scores": positive_scores, + "negative_scores": negative_scores, } else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) # b - batch_size, n - num_candidates, d - embedding_dim candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight + "bd,nd->bn", last_embeddings, self._item_embeddings.weight ) # (batch_size, num_items + 2) _, indices = torch.topk( - candidate_scores, - k=50, dim=-1, largest=True + candidate_scores, k=50, dim=-1, largest=True ) # (batch_size, 20) return indices -class SasRecInBatchModel(SasRecModel, config_name='sasrec_in_batch'): +class SasRecInBatchModel(SasRecModel, config_name="sasrec_in_batch"): def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 + self, + sequence_prefix, + positive_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, ): super().__init__( sequence_prefix=sequence_prefix, @@ -136,27 +148,31 @@ def __init__( dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, - initializer_range=initializer_range + initializer_range=initializer_range, ) @classmethod def create_from_config(cls, config, **kwargs): return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), ) def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) embeddings, mask = self._apply_sequential_encoder( all_sample_events, all_sample_lengths @@ -164,10 +180,14 @@ def forward(self, inputs): if self.training: # training mode # queries - in_batch_queries_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) + in_batch_queries_embeddings = embeddings[ + mask + ] # (all_batch_events, embedding_dim) # positives - in_batch_positive_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) + in_batch_positive_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) in_batch_positive_embeddings = self._item_embeddings( in_batch_positive_events ) # (all_batch_events, embedding_dim) @@ -182,35 +202,33 @@ def forward(self, inputs): ) # (batch_size, embedding_dim) return { - 'query_embeddings': in_batch_queries_embeddings, - 'positive_embeddings': in_batch_positive_embeddings, - 'negative_embeddings': in_batch_negative_embeddings, - + "query_embeddings": in_batch_queries_embeddings, + "positive_embeddings": in_batch_positive_embeddings, + "negative_embeddings": in_batch_negative_embeddings, # --- ID PASS-THROUGH FOR DOWNSTREAM LOSS OPERATIONS --- # We pass the raw item IDs to the loss function to enable: - # 1. False Negative Masking: Identifying and neutralizing cases where a positive - # item for one user accidentally appears as a negative sample for another + # 1. False Negative Masking: Identifying and neutralizing cases where a positive + # item for one user accidentally appears as a negative sample for another # user in the same batch (critical for In-Batch Negatives stability). - # 2. Per-item LogQ Correction: Mapping item IDs to their global frequencies + # 2. Per-item LogQ Correction: Mapping item IDs to their global frequencies # to subtract log(Q) based on the specific item popularity. - 'positive_ids': in_batch_positive_events, - 'negative_ids': in_batch_negative_ids + "positive_ids": in_batch_positive_events, + "negative_ids": in_batch_negative_ids, } else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) # b - batch_size, n - num_candidates, d - embedding_dim candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight + "bd,nd->bn", last_embeddings, self._item_embeddings.weight ) # (batch_size, num_items + 2) candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1:] = -torch.inf + candidate_scores[:, self._num_items + 1 :] = -torch.inf _, indices = torch.topk( - candidate_scores, - k=50, dim=-1, largest=True + candidate_scores, k=50, dim=-1, largest=True ) # (batch_size, 20) - return indices \ No newline at end of file + return indices diff --git a/src/irec/optimizer/base.py b/src/irec/optimizer/base.py index 0fc1f70..51d9464 100644 --- a/src/irec/optimizer/base.py +++ b/src/irec/optimizer/base.py @@ -5,14 +5,14 @@ import torch OPTIMIZERS = { - 'sgd': torch.optim.SGD, - 'adam': torch.optim.Adam, - 'adamw': torch.optim.AdamW, + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "adamw": torch.optim.AdamW, } SCHEDULERS = { - 'step': torch.optim.lr_scheduler.StepLR, - 'cyclic': torch.optim.lr_scheduler.CyclicLR, + "step": torch.optim.lr_scheduler.StepLR, + "cyclic": torch.optim.lr_scheduler.CyclicLR, } @@ -20,7 +20,7 @@ class BaseOptimizer(metaclass=MetaParent): pass -class BasicOptimizer(BaseOptimizer, config_name='basic'): +class BasicOptimizer(BaseOptimizer, config_name="basic"): def __init__( self, model, @@ -35,15 +35,15 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): - optimizer_cfg = copy.deepcopy(config['optimizer']) - optimizer = OPTIMIZERS[optimizer_cfg.pop('type')]( - kwargs['model'].parameters(), + optimizer_cfg = copy.deepcopy(config["optimizer"]) + optimizer = OPTIMIZERS[optimizer_cfg.pop("type")]( + kwargs["model"].parameters(), **optimizer_cfg, ) - if 'scheduler' in config: - scheduler_cfg = copy.deepcopy(config['scheduler']) - scheduler = SCHEDULERS[scheduler_cfg.pop('type')]( + if "scheduler" in config: + scheduler_cfg = copy.deepcopy(config["scheduler"]) + scheduler = SCHEDULERS[scheduler_cfg.pop("type")]( optimizer, **scheduler_cfg, ) @@ -51,10 +51,10 @@ def create_from_config(cls, config, **kwargs): scheduler = None return cls( - model=kwargs['model'], + model=kwargs["model"], optimizer=optimizer, scheduler=scheduler, - clip_grad_threshold=config.get('clip_grad_threshold', None), + clip_grad_threshold=config.get("clip_grad_threshold", None), ) def step(self, loss): @@ -72,7 +72,7 @@ def step(self, loss): self._scheduler.step() def state_dict(self): - state_dict = {'optimizer': self._optimizer.state_dict()} + state_dict = {"optimizer": self._optimizer.state_dict()} if self._scheduler is not None: - state_dict.update({'scheduler': self._scheduler.state_dict()}) + state_dict.update({"scheduler": self._scheduler.state_dict()}) return state_dict diff --git a/src/irec/train.py b/src/irec/train.py index 7844439..8251822 100644 --- a/src/irec/train.py +++ b/src/irec/train.py @@ -43,20 +43,20 @@ def train( best_epoch = 0 best_checkpoint = None - logger.debug('Start training...') + logger.debug("Start training...") while (epoch_cnt is None or epoch_num < epoch_cnt) and ( step_cnt is None or step_num < step_cnt ): if best_epoch + epochs_threshold < epoch_num: logger.debug( - 'There is no progress during {} epochs. Finish training'.format( + "There is no progress during {} epochs. Finish training".format( epochs_threshold, ), ) break - logger.debug(f'Start epoch {epoch_num}') + logger.debug(f"Start epoch {epoch_num}") for step, batch in enumerate(dataloader): batch_ = copy.deepcopy(batch) @@ -87,7 +87,7 @@ def train( best_epoch = epoch_num epoch_num += 1 - logger.debug('Training procedure has been finished!') + logger.debug("Training procedure has been finished!") return best_checkpoint @@ -95,71 +95,72 @@ def main(): fix_random_seed(seed_val) config = parse_args() - if config.get('use_wandb', False): + if config.get("use_wandb", False): wandb.init( - project='irec', - name=config['experiment_name'], + project="irec", + name=config["experiment_name"], sync_tensorboard=True, ) - tensorboard_writer = irec.utils.tensorboards.TensorboardWriter(config['experiment_name']) + tensorboard_writer = irec.utils.tensorboards.TensorboardWriter( + config["experiment_name"] + ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER = tensorboard_writer log_dir = tensorboard_writer.log_dir - config_save_path = os.path.join(log_dir, 'config.json') - with open(config_save_path, 'w') as f: + config_save_path = os.path.join(log_dir, "config.json") + with open(config_save_path, "w") as f: json.dump(config, f, indent=2) - - logger.debug('Training config: \n{}'.format(json.dumps(config, indent=2))) - logger.debug('Current DEVICE: {}'.format(DEVICE)) - logger.info(f"Experiment config saved to: {config_save_path}") + logger.debug("Training config: \n{}".format(json.dumps(config, indent=2))) + logger.debug("Current DEVICE: {}".format(DEVICE)) + logger.info(f"Experiment config saved to: {config_save_path}") - dataset = BaseDataset.create_from_config(config['dataset']) + dataset = BaseDataset.create_from_config(config["dataset"]) train_sampler, validation_sampler, test_sampler = dataset.get_samplers() train_dataloader = BaseDataloader.create_from_config( - config['dataloader']['train'], + config["dataloader"]["train"], dataset=train_sampler, **dataset.meta, ) validation_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], + config["dataloader"]["validation"], dataset=validation_sampler, **dataset.meta, ) eval_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], + config["dataloader"]["validation"], dataset=test_sampler, **dataset.meta, ) - model = BaseModel.create_from_config(config['model'], **dataset.meta).to( + model = BaseModel.create_from_config(config["model"], **dataset.meta).to( DEVICE, ) - if 'checkpoint' in config: + if "checkpoint" in config: ensure_checkpoints_dir() checkpoint_path = os.path.join( - './checkpoints', + "./checkpoints", f'{config["checkpoint"]}.pth', ) - logger.debug('Loading checkpoint from {}'.format(checkpoint_path)) + logger.debug("Loading checkpoint from {}".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) logger.debug(checkpoint.keys()) model.load_state_dict(checkpoint) - loss_function = BaseLoss.create_from_config(config['loss']) + loss_function = BaseLoss.create_from_config(config["loss"]) optimizer = BaseOptimizer.create_from_config( - config['optimizer'], + config["optimizer"], model=model, ) callback = BaseCallback.create_from_config( - config['callback'], + config["callback"], model=model, train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, @@ -170,7 +171,7 @@ def main(): # TODO add verbose option for all callbacks, multiple optimizer options (???) # TODO create pre/post callbacks - logger.debug('Everything is ready for training process!') + logger.debug("Everything is ready for training process!") # Train process _ = train( @@ -179,22 +180,22 @@ def main(): optimizer=optimizer, loss_function=loss_function, callback=callback, - epoch_cnt=config.get('train_epochs_num'), - step_cnt=config.get('train_steps_num'), - best_metric=config.get('best_metric'), + epoch_cnt=config.get("train_epochs_num"), + step_cnt=config.get("train_steps_num"), + best_metric=config.get("best_metric"), ) - logger.debug('Saving model...') + logger.debug("Saving model...") ensure_checkpoints_dir() - checkpoint_path = './checkpoints/{}_final_state.pth'.format( - config['experiment_name'], + checkpoint_path = "./checkpoints/{}_final_state.pth".format( + config["experiment_name"], ) torch.save(model.state_dict(), checkpoint_path) - logger.debug('Saved model as {}'.format(checkpoint_path)) + logger.debug("Saved model as {}".format(checkpoint_path)) - if config.get('use_wandb', False): + if config.get("use_wandb", False): wandb.finish() -if __name__ == '__main__': +if __name__ == "__main__": main() From 43c5d3c7ee026e36c83e2c1c854bd0d48d7a1c08 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 17:35:50 +0300 Subject: [PATCH 20/27] perf: optimize _filter_matrix_by_top_k using CSR in-place operations --- src/irec/dataset/base.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index c380709..b87c56d 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -564,19 +564,23 @@ def _convert_sp_mat_to_sp_tensor(X): return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape)) @staticmethod - def _filter_matrix_by_top_k(matrix, k): - mat = matrix.tolil() + def filter_matrix_by_top_k(matrix, k): + mat = matrix.tocsr() for i in range(mat.shape[0]): - if len(mat.rows[i]) <= k: - continue - data = np.array(mat.data[i]) + start = mat.indptr[i] + end = mat.indptr[i + 1] - top_k_indices = np.argpartition(data, -k)[-k:] - mat.data[i] = [mat.data[i][j] for j in top_k_indices] - mat.rows[i] = [mat.rows[i][j] for j in top_k_indices] + if end - start > k: + row_view = mat.data[start:end] - return mat.tocsr() + threshold = np.partition(row_view, -k)[-k] + + row_view[row_view < threshold] = 0 + + mat.eliminate_zeros() + + return mat def get_samplers(self): return self._dataset.get_samplers() From aff9bd5050ba47740d28558d66e4751508752289 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 19:01:09 +0300 Subject: [PATCH 21/27] perf: dataloader pin memory optimization --- src/irec/dataloader/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py index 17c21bf..3cccb86 100644 --- a/src/irec/dataloader/base.py +++ b/src/irec/dataloader/base.py @@ -9,6 +9,21 @@ logger = logging.getLogger(__name__) +class BaseDataloader(metaclass=MetaParent): + pass + + +import copy + +from irec.utils import MetaParent +from .batch_processors import BaseBatchProcessor + +import logging +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + class BaseDataloader(metaclass=MetaParent): pass From 1d99479de6b6a60e44a4b357f5150b82a0f2dbac Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 19:05:42 +0300 Subject: [PATCH 22/27] perf: mclsr cpu gpu (also forgot transfer exp-4 branch) --- src/irec/models/mclsr.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index bf20e32..2de98f2 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -5,6 +5,8 @@ from irec.utils import create_masked_tensor +torch.backends.cudnn.benchmark = True + class MCLSRModel(TorchModel, config_name="mclsr"): def __init__( @@ -352,10 +354,34 @@ def forward(self, inputs): all_sample_events, return_inverse=True ) + # OLD BAD + # try: + # from torch_scatter import scatter_mean + # except ImportError: + # # print("Warning: torch_scatter not found. Using a slower fallback function.") + # def scatter_mean(src, index, dim=0, dim_size=None): + # out_size = dim_size if dim_size is not None else index.max() + 1 + # out = torch.zeros((out_size, src.size(1)), dtype=src.dtype, device=src.device) + # counts = torch.bincount(index, minlength=out_size).unsqueeze(-1).clamp(min=1) + # return out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts + + # --- OPTIMIZED AGGREGATION: scatter_mean --- + # We use scatter_mean to aggregate features of unique items within a batch + # for Item-level Feature Contrastive Learning. + # + # Performance Note: + # We prioritize the 'torch-scatter' library because it provides highly optimized + # C++/CUDA kernels that perform in-place aggregation. + # + # The 'except ImportError' fallback is provided for environment compatibility, + # but it is NOT recommended for large-scale datasets like Amazon Books. + # The fallback implementation uses 'expand_as', which creates massive temporary + # tensors in GPU memory, potentially leading to Out-Of-Memory (OOM) errors + # when processing millions of interaction events. try: from torch_scatter import scatter_mean except ImportError: - # print("Warning: torch_scatter not found. Using a slower fallback function.") + def scatter_mean(src, index, dim=0, dim_size=None): out_size = dim_size if dim_size is not None else index.max() + 1 out = torch.zeros( @@ -366,6 +392,7 @@ def scatter_mean(src, index, dim=0, dim_size=None): .unsqueeze(-1) .clamp(min=1) ) + # WARNING: .expand_as() below is a memory bottleneck for large tensors return ( out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts @@ -392,6 +419,9 @@ def scatter_mean(src, index, dim=0, dim_size=None): unique_item_graph_items ) + # negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) + # negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) + raw_negative_ids = inputs["{}.ids".format(self._negatives_prefix)] num_negatives = raw_negative_ids.shape[0] // batch_size negative_ids = raw_negative_ids.view( From 1ebbf34d094def970c25b38948c963bb3ca3d89a Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 19:08:06 +0300 Subject: [PATCH 23/27] perf dataloader --- src/irec/dataloader/base.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py index 3cccb86..5e09f46 100644 --- a/src/irec/dataloader/base.py +++ b/src/irec/dataloader/base.py @@ -9,21 +9,6 @@ logger = logging.getLogger(__name__) -class BaseDataloader(metaclass=MetaParent): - pass - - -import copy - -from irec.utils import MetaParent -from .batch_processors import BaseBatchProcessor - -import logging -from torch.utils.data import DataLoader - -logger = logging.getLogger(__name__) - - class BaseDataloader(metaclass=MetaParent): pass @@ -62,3 +47,12 @@ def create_from_config(cls, config, **kwargs): **create_config, ), ) + + # return cls( + # dataloader=DataLoader( + # kwargs['dataset'], + # collate_fn=batch_processor, + # pin_memory=True, + # **create_config, + # ), + # ) From ece48b86f438919a43a8389de4a9974742eb6b9c Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 19:09:33 +0300 Subject: [PATCH 24/27] perf dataset --- src/irec/dataset/base.py | 120 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 111 insertions(+), 9 deletions(-) diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index b87c56d..5936a31 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -378,6 +378,10 @@ def _build_or_load_similarity_graph( ): if entity_type not in ["user", "item"]: raise ValueError("entity_type must be either 'user' or 'item'") + # have to delete and replace to not delete npz each time manually + # path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type)) + + # instead better use such construction # neighborhood_size # The neighborhood_size is a filter that constrains the number of edges for each user or @@ -419,6 +423,7 @@ def _build_or_load_similarity_graph( continue visited_user_item_pairs.add((user_id, item_id)) + # TODO look here at review source_entity = user_id if is_user_graph else item_id connection_map = ( train_item_2_users if is_user_graph else train_user_2_items @@ -429,6 +434,11 @@ def _build_or_load_similarity_graph( if source_entity == connected_entity: continue + pair_key = (source_entity, connected_entity) + # if pair_key in visited_entity_pairs: + # continue + + # visited_entity_pairs.add(pair_key) interactions_fst.append(source_entity) interactions_snd.append(connected_entity) @@ -452,6 +462,7 @@ def _build_or_load_similarity_graph( def _build_or_load_bipartite_graph( self, graph_dir_path, train_user_interactions, train_item_interactions ): + # path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" filename = f"general_graph_{train_suffix}.npz" path_to_graph = os.path.join(graph_dir_path, filename) @@ -522,6 +533,40 @@ def create_from_config(cls, config): neighborhood_size=config.get("neighborhood_size", None), ) + # @staticmethod + # def get_sparse_graph_layer( + # sparse_matrix, + # fst_dim, + # snd_dim, + # biparite=False, + # ): + # if not biparite: + # adj_mat = sparse_matrix.tocsr() + # else: + # R = sparse_matrix.tocsr() + + # upper_right = R + # lower_left = R.T + + # upper_left = sp.csr_matrix((fst_dim, fst_dim)) + # lower_right = sp.csr_matrix((snd_dim, snd_dim)) + + # adj_mat = sp.bmat([ + # [upper_left, upper_right], + # [lower_left, lower_right] + # ]) + # assert adj_mat.shape == (fst_dim + snd_dim, fst_dim + snd_dim), ( + # f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" + # ) + + # rowsum = np.array(adj_mat.sum(1)) + # d_inv = np.power(rowsum, -0.5).flatten() + # d_inv[np.isinf(d_inv)] = 0. + # d_mat_inv = sp.diags(d_inv) + + # norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) + # return norm_adj.tocsr() + @staticmethod def get_sparse_graph_layer( sparse_matrix, @@ -546,12 +591,39 @@ def get_sparse_graph_layer( fst_dim + snd_dim, ), f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" - rowsum = np.array(adj_mat.sum(1)) - d_inv = np.power(rowsum, -0.5).flatten() + # --- OLD IMPLEMENTATION (Slow & Memory Intensive for Large Corpus) --- + # rowsum = np.array(adj_mat.sum(1)) + # d_inv = np.power(rowsum, -0.5).flatten() + # d_inv[np.isinf(d_inv)] = 0. + # d_mat_inv = sp.diags(d_inv) + # norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) + # return norm_adj.tocsr() + + # --- NEW OPTIMIZED IMPLEMENTATION --- + """ + Optimization Strategy: Vectorized Symmetric Normalization (D^-0.5 * A * D^-0.5). + + Justification for Amazon Books Scale: + 1. Memory Efficiency: Creating an explicit diagonal matrix 'd_mat_inv' + (size N x N) via sp.diags is redundant. For 800k+ nodes, this consumes + significant RAM and creates heavy intermediate objects. + 2. Computational Speed: Traditional matrix-matrix multiplication (.dot) + in Scipy sparse has higher overhead compared to row/column-wise scaling. + 3. Implementation: We perform element-wise multiplication of the sparse + matrix by 1D degree vectors. Multiplying a sparse matrix by a column + vector scales rows, and by a row vector scales columns. + + This results in an identical Laplacian matrix but is calculated in O(E) + time with minimal memory footprint, where E is the number of edges. + """ + rowsum = np.array(adj_mat.sum(1)).flatten() + d_inv = np.power(rowsum, -0.5) d_inv[np.isinf(d_inv)] = 0.0 - d_mat_inv = sp.diags(d_inv) - norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) + # Scaling rows: multiply by column vector [N, 1] + # Scaling columns: multiply by row vector [1, N] + norm_adj = adj_mat.multiply(d_inv[:, np.newaxis]).multiply(d_inv) + return norm_adj.tocsr() @staticmethod @@ -564,22 +636,52 @@ def _convert_sp_mat_to_sp_tensor(X): return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape)) @staticmethod - def filter_matrix_by_top_k(matrix, k): + def _filter_matrix_by_top_k(matrix, k): + # --- OLD IMPLEMENTATION (Extremely slow conversion to LIL for large datasets) --- + # mat = matrix.tolil() + # for i in range(mat.shape[0]): + # if len(mat.rows[i]) <= k: + # continue + # data = np.array(mat.data[i]) + # top_k_indices = np.argpartition(data, -k)[-k:] + # mat.data[i] = [mat.data[i][j] for j in top_k_indices] + # mat.rows[i] = [mat.rows[i][j] for j in top_k_indices] + # return mat.tocsr() + + # --- NEW OPTIMIZED IMPLEMENTATION --- + """ + Optimization Strategy: Direct CSR Array Manipulation with NumPy Partitioning. + + Justification for Amazon Books Scale: + 1. Avoids LIL conversion: Converting a 450k x 300k matrix to LIL format + (List of Lists) is extremely memory-intensive and slow. + 2. In-place filtering: By accessing the CSR 'data' and 'indptr' arrays directly, + we perform the Top-K filtering with zero additional memory allocation + for the matrix structure itself. + 3. Algorithmic Speed: np.partition finds the threshold value in O(n) average + time. Slicing the underlying NumPy arrays is performed at near-C speed, + making this orders of magnitude faster than Python-level list operations. + """ mat = matrix.tocsr() for i in range(mat.shape[0]): start = mat.indptr[i] end = mat.indptr[i + 1] + # Only process rows that actually exceed the neighborhood size if end - start > k: - row_view = mat.data[start:end] + row_slice = mat.data[start:end] - threshold = np.partition(row_view, -k)[-k] + # Find the threshold value (the k-th largest element) + # np.partition is faster than a full sort: it puts the top-k values at the end + threshold = np.partition(row_slice, -k)[-k] - row_view[row_view < threshold] = 0 + # Effectively prune edges by zeroing out everything below the threshold + # This keeps exactly k (or slightly more if there are ties) elements + row_slice[row_slice < threshold] = 0 + # Post-processing: remove the explicitly zeroed elements from the sparse structure mat.eliminate_zeros() - return mat def get_samplers(self): From 4e008f26d5fe98c679e6df1295ff788d9d002e35 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 19:11:23 +0300 Subject: [PATCH 25/27] perf optimized version config --- .../train/mclsr_baseline_clothing_sanity.json | 233 ++++++++++++++++++ 1 file changed, 233 insertions(+) create mode 100644 configs/train/mclsr_baseline_clothing_sanity.json diff --git a/configs/train/mclsr_baseline_clothing_sanity.json b/configs/train/mclsr_baseline_clothing_sanity.json new file mode 100644 index 0000000..0039582 --- /dev/null +++ b/configs/train/mclsr_baseline_clothing_sanity.json @@ -0,0 +1,233 @@ +{ + "experiment_name": "mclsr_Clothing", + "use_wandb": true, + "best_metric": "validation/ndcg@20", + "dataset": { + "type": "graph", + "use_user_graph": true, + "use_item_graph": true, + "neighborhood_size": 50, + "graph_dir_path": "./data/Clothing", + "dataset": { + "type": "mclsr", + "path_to_data_dir": "./data", + "name": "Clothing", + "max_sequence_length": 20, + "samplers": { + "num_negatives_val": 1280, + "num_negatives_train": 1280, + "type": "mclsr", + "negative_sampler_type": "random" + } + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true, + "num_workers": 4, + "pin_memory" : true + }, + "validation": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false, + "num_workers": 4, + "pin_memory" : true + } + }, + "model": { + "type": "mclsr", + "sequence_prefix": "item", + "user_prefix": "user", + "labels_prefix": "labels", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_graph_layers": 2, + "dropout": 0.3, + "layer_norm_eps": 1e-9, + "graph_dropout": 0.3, + "initializer_range": 0.02, + "alpha": 0.5 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + } + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "sampled_softmax", + "queries_prefix": "combined_representation", + "positive_prefix": "label_representation", + "negative_prefix": "negative_representation", + "output_prefix": "downstream_loss", + "weight": 1.0 + }, + { + "type": "fps", + "fst_embeddings_prefix": "sequential_representation", + "snd_embeddings_prefix": "graph_representation", + "output_prefix": "contrastive_interest_loss", + "weight": 1.0, + "temperature": 0.5 + }, + { + "type": "fps", + "fst_embeddings_prefix": "user_graph_user_embeddings", + "snd_embeddings_prefix": "common_graph_user_embeddings", + "output_prefix": "contrastive_user_feature_loss", + "weight": 0.05, + "temperature": 0.5 + }, + { + "type": "fps", + "fst_embeddings_prefix": "item_graph_item_embeddings", + "snd_embeddings_prefix": "common_graph_item_embeddings", + "output_prefix": "contrastive_item_feature_loss", + "weight": 0.05, + "temperature": 0.5 + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "downstream_loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "contrastive_interest_loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "contrastive_user_feature_loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "contrastive_item_feature_loss" + }, + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "mclsr-ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "mclsr-ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "mclsr-ndcg", + "k": 20 + }, + "ndcg@50": { + "type": "mclsr-ndcg", + "k": 50 + }, + "recall@5": { + "type": "mclsr-recall", + "k": 5 + }, + "recall@10": { + "type": "mclsr-recall", + "k": 10 + }, + "recall@20": { + "type": "mclsr-recall", + "k": 20 + }, + "recall@50": { + "type": "mclsr-recall", + "k": 50 + }, + "hit@20": { + "type": "mclsr-hit", + "k": 20 + }, + "hit@50": { + "type": "mclsr-hit", + "k": 50 + } + } + }, + { + "type": "eval", + "on_step": 256, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "mclsr-ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "mclsr-ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "mclsr-ndcg", + "k": 20 + }, + "ndcg@50": { + "type": "mclsr-ndcg", + "k": 50 + }, + "recall@5": { + "type": "mclsr-recall", + "k": 5 + }, + "recall@10": { + "type": "mclsr-recall", + "k": 10 + }, + "recall@20": { + "type": "mclsr-recall", + "k": 20 + }, + "recall@50": { + "type": "mclsr-recall", + "k": 50 + }, + "hit@20": { + "type": "mclsr-hit", + "k": 20 + }, + "hit@50": { + "type": "mclsr-hit", + "k": 50 + } + } + } + ] + } +} \ No newline at end of file From 5d3cb881ba0f4e0b6d8e91bce7ffb37dda7a7d43 Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sat, 4 Apr 2026 19:13:11 +0300 Subject: [PATCH 26/27] perf batch batch processors --- src/irec/dataloader/batch_processors.py | 50 +++++++++++++++++++------ 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/src/irec/dataloader/batch_processors.py b/src/irec/dataloader/batch_processors.py index c4b9d97..f2f3927 100644 --- a/src/irec/dataloader/batch_processors.py +++ b/src/irec/dataloader/batch_processors.py @@ -1,4 +1,5 @@ import torch +import itertools from irec.utils import MetaParent @@ -19,20 +20,47 @@ def __call__(self, batch): for key in batch[0].keys(): if key.endswith(".ids"): prefix = key.split(".")[0] - assert "{}.length".format(prefix) in batch[0] + length_key = f"{prefix}.length" + assert length_key in batch[0] - processed_batch[f"{prefix}.ids"] = [] - processed_batch[f"{prefix}.length"] = [] + # --- OLD SLOW IMPLEMENTATION (Python loop with manual .extend) --- + # processed_batch[f'{prefix}.ids'] = [] + # processed_batch[f'{prefix}.length'] = [] + # for sample in batch: + # processed_batch[f'{prefix}.ids'].extend( + # sample[f'{prefix}.ids'], + # ) + # processed_batch[f'{prefix}.length'].append( + # sample[f'{prefix}.length'], + # ) - for sample in batch: - processed_batch[f"{prefix}.ids"].extend( - sample[f"{prefix}.ids"], - ) - processed_batch[f"{prefix}.length"].append( - sample[f"{prefix}.length"], - ) + # --- NEW OPTIMIZED IMPLEMENTATION (Books-Scale Ready) --- + """ + Optimization Strategy: C-level Flattening via itertools. + + Justification for Amazon Books scale: + 1. Avoiding Reallocations: Python's list.extend() repeatedly triggers + memory reallocation as the list grows. For large batches on a + 9-million-interaction dataset, this creates significant overhead. + 2. itertools.chain.from_iterable: This is implemented in C. It creates + a flat iterator over the sequence slices without creating + intermediate Python list objects, which is much faster. + 3. List Comprehension: Collecting lengths via a comprehension is + consistently faster than manual .append() calls in a for-loop. + """ + # Efficiently flatten all sequence IDs into one long list + ids_iter = itertools.chain.from_iterable(s[key] for s in batch) + processed_batch[key] = torch.tensor(list(ids_iter), dtype=torch.long) + # Efficiently collect all lengths into a tensor + lengths_list = [s[length_key] for s in batch] + processed_batch[length_key] = torch.tensor( + lengths_list, dtype=torch.long + ) + + # Final conversion for any keys that might have missed the .ids check for part, values in processed_batch.items(): - processed_batch[part] = torch.tensor(values, dtype=torch.long) + if not isinstance(processed_batch[part], torch.Tensor): + processed_batch[part] = torch.tensor(values, dtype=torch.long) return processed_batch From e135ce47904a456d8df38b7837fc1be6e6a031ca Mon Sep 17 00:00:00 2001 From: Aksinya-Bykova <367121@niuitmo.ru> Date: Sun, 5 Apr 2026 20:00:00 +0300 Subject: [PATCH 27/27] fix: train and mclsr imports --- src/irec/models/mclsr.py | 5 +---- src/irec/train.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index 2de98f2..be3b47d 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -1,12 +1,9 @@ -from .base import TorchModel - import torch import torch.nn as nn +from .base import TorchModel from irec.utils import create_masked_tensor -torch.backends.cudnn.benchmark = True - class MCLSRModel(TorchModel, config_name="mclsr"): def __init__( diff --git a/src/irec/train.py b/src/irec/train.py index 8251822..187120d 100644 --- a/src/irec/train.py +++ b/src/irec/train.py @@ -1,3 +1,12 @@ +import os +import json +import copy +import torch +import wandb + +torch.backends.cudnn.benchmark = True +torch.set_float32_matmul_precision("high") + import irec.utils from irec.utils import ( parse_args, @@ -14,14 +23,10 @@ from irec.models import BaseModel from irec.optimizer import BaseOptimizer -import copy -import json -import os -import torch -import wandb +seed_val = 42 +fix_random_seed(seed_val) logger = create_logger(name=__name__) -seed_val = 42 def train(