diff --git a/README.md b/README.md index 680c46e..f0053bc 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ + +# Steps on local machine

cool irec logo @@ -40,26 +42,6 @@ uv sync --frozen ``` -### Using pip - -1. Create and activate a virtual environment: - ```bash - python3 -m venv .venv - source ./.venv/bin/activate - ``` - -2. Install dependencies: - - **For development:** - ```bash - pip install -e ".[dev]" - ``` - - **For production:** - ```bash - pip install -e . - ``` - ## Preparing datasets All pre-processed datasets used in our experiments are available for download from our cloud storage. This is the fastest way to get started. @@ -91,4 +73,4 @@ The script has 1 input argument: `params` which is the path to the json file wit -`callbacks` Different additional traning --`use_wandb` Enable Weights & Biases logging for experiment tracking \ No newline at end of file +-`use_wandb` Enable Weights & Biases logging for experiment tracking diff --git a/configs/inference/sasrec_inference_config.json b/configs/inference/sasrec_inference_config.json deleted file mode 100644 index c91e21b..0000000 --- a/configs/inference/sasrec_inference_config.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "pred_prefix": "logits", - "label_prefix": "labels", - "experiment_name": "sasrec_beauty_grid__0-5__", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.5, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/train/mclsr_baseline_clothing_sanity.json b/configs/train/mclsr_baseline_clothing_sanity.json new file mode 100644 index 0000000..0039582 --- /dev/null +++ b/configs/train/mclsr_baseline_clothing_sanity.json @@ -0,0 +1,233 @@ +{ + "experiment_name": "mclsr_Clothing", + "use_wandb": true, + "best_metric": "validation/ndcg@20", + "dataset": { + "type": "graph", + "use_user_graph": true, + "use_item_graph": true, + "neighborhood_size": 50, + "graph_dir_path": "./data/Clothing", + "dataset": { + "type": "mclsr", + "path_to_data_dir": "./data", + "name": "Clothing", + "max_sequence_length": 20, + "samplers": { + "num_negatives_val": 1280, + "num_negatives_train": 1280, + "type": "mclsr", + "negative_sampler_type": "random" + } + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true, + "num_workers": 4, + "pin_memory" : true + }, + "validation": { + "type": "torch", + "batch_size": 128, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false, + "num_workers": 4, + "pin_memory" : true + } + }, + "model": { + "type": "mclsr", + "sequence_prefix": "item", + "user_prefix": "user", + "labels_prefix": "labels", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_graph_layers": 2, + "dropout": 0.3, + "layer_norm_eps": 1e-9, + "graph_dropout": 0.3, + "initializer_range": 0.02, + "alpha": 0.5 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + } + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "sampled_softmax", + "queries_prefix": "combined_representation", + "positive_prefix": "label_representation", + "negative_prefix": "negative_representation", + "output_prefix": "downstream_loss", + "weight": 1.0 + }, + { + "type": "fps", + "fst_embeddings_prefix": "sequential_representation", + "snd_embeddings_prefix": "graph_representation", + "output_prefix": "contrastive_interest_loss", + "weight": 1.0, + "temperature": 0.5 + }, + { + "type": "fps", + "fst_embeddings_prefix": "user_graph_user_embeddings", + "snd_embeddings_prefix": "common_graph_user_embeddings", + "output_prefix": "contrastive_user_feature_loss", + "weight": 0.05, + "temperature": 0.5 + }, + { + "type": "fps", + "fst_embeddings_prefix": "item_graph_item_embeddings", + "snd_embeddings_prefix": "common_graph_item_embeddings", + "output_prefix": "contrastive_item_feature_loss", + "weight": 0.05, + "temperature": 0.5 + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "downstream_loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "contrastive_interest_loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "contrastive_user_feature_loss" + }, + { + "type": "metric", + "on_step": 1, + "loss_prefix": "contrastive_item_feature_loss" + }, + { + "type": "validation", + "on_step": 64, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "mclsr-ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "mclsr-ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "mclsr-ndcg", + "k": 20 + }, + "ndcg@50": { + "type": "mclsr-ndcg", + "k": 50 + }, + "recall@5": { + "type": "mclsr-recall", + "k": 5 + }, + "recall@10": { + "type": "mclsr-recall", + "k": 10 + }, + "recall@20": { + "type": "mclsr-recall", + "k": 20 + }, + "recall@50": { + "type": "mclsr-recall", + "k": 50 + }, + "hit@20": { + "type": "mclsr-hit", + "k": 20 + }, + "hit@50": { + "type": "mclsr-hit", + "k": 50 + } + } + }, + { + "type": "eval", + "on_step": 256, + "pred_prefix": "predictions", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "mclsr-ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "mclsr-ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "mclsr-ndcg", + "k": 20 + }, + "ndcg@50": { + "type": "mclsr-ndcg", + "k": 50 + }, + "recall@5": { + "type": "mclsr-recall", + "k": 5 + }, + "recall@10": { + "type": "mclsr-recall", + "k": 10 + }, + "recall@20": { + "type": "mclsr-recall", + "k": 20 + }, + "recall@50": { + "type": "mclsr-recall", + "k": 50 + }, + "hit@20": { + "type": "mclsr-hit", + "k": 20 + }, + "hit@50": { + "type": "mclsr-hit", + "k": 50 + } + } + } + ] + } +} \ No newline at end of file diff --git a/configs/train/sasrec_ce_train_config.json b/configs/train/sasrec_ce_train_config.json deleted file mode 100644 index ce17ae3..0000000 --- a/configs/train/sasrec_ce_train_config.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "experiment_name": "sasrec_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "use_ce": true, - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "representation_prefix": "current_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_in_batch_train_config.json b/configs/train/sasrec_in_batch_train_config.json deleted file mode 100644 index 9950f31..0000000 --- a/configs/train/sasrec_in_batch_train_config.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "experiment_name": "sasrec_in_batch_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_in_batch", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "use_ce": true, - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sampled_softmax", - "queries_prefix": "query_embeddings", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_real_train_config.json b/configs/train/sasrec_real_train_config.json deleted file mode 100644 index b04a5a0..0000000 --- a/configs/train/sasrec_real_train_config.json +++ /dev/null @@ -1,217 +0,0 @@ -{ - "experiment_name": "sasrec_real_beauty", - "best_metric": "validation/ndcg@100", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 100, - "samplers": { - "num_negatives_train": 1, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 2048, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_real", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec_real", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "ndcg", - "k": 50 - }, - "ndcg@100": { - "type": "ndcg", - "k": 100 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "recall@50": { - "type": "recall", - "k": 50 - }, - "recall@100": { - "type": "recall", - "k": 100 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - }, - "coverage@50": { - "type": "coverage", - "k": 50 - }, - "coverage@100": { - "type": "coverage", - "k": 100 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "ndcg", - "k": 50 - }, - "ndcg@100": { - "type": "ndcg", - "k": 100 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "recall@50": { - "type": "recall", - - "k": 50 - }, - "recall@100": { - "type": "recall", - "k": 100 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - }, - "coverage@50": { - "type": "coverage", - "k": 50 - }, - "coverage@100": { - "type": "coverage", - "k": 100 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_train_grid_config.json b/configs/train/sasrec_train_grid_config.json deleted file mode 100644 index acf86b6..0000000 --- a/configs/train/sasrec_train_grid_config.json +++ /dev/null @@ -1,185 +0,0 @@ -{ - "start_from": 0, - "experiment_name": "sasrec_beauty_grid", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5b0ca5a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +numpy~=1.26 +torch~=2.4 +tqdm~=4.66 +scipy~=1.14 +pandas~=2.2 +polars~=1.27 +matplotlib~=3.9 +tensorboard~=2.19 +wandb~=0.19 +jupyterlab~=4.4 +ipykernel~=6.29 +notebook~=7.4 diff --git a/src/irec/callbacks/base.py b/src/irec/callbacks/base.py index 81271f5..77895d6 100644 --- a/src/irec/callbacks/base.py +++ b/src/irec/callbacks/base.py @@ -30,7 +30,7 @@ def __call__(self, inputs, step_num): raise NotImplementedError -class MetricCallback(BaseCallback, config_name='metric'): +class MetricCallback(BaseCallback, config_name="metric"): def __init__( self, model, @@ -56,43 +56,39 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - metrics=config.get('metrics', None), - loss_prefix=config['loss_prefix'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], + metrics=config.get("metrics", None), + loss_prefix=config["loss_prefix"], ) def __call__(self, inputs, step_num): if step_num % self._on_step == 0: for metric_name, metric_function in self._metrics.items(): metric_value = metric_function( - ground_truth=inputs[ - self._model.schema['ground_truth_prefix'] - ], - predictions=inputs[ - self._model.schema['predictions_prefix'] - ], + ground_truth=inputs[self._model.schema["ground_truth_prefix"]], + predictions=inputs[self._model.schema["predictions_prefix"]], ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - 'train/{}'.format(metric_name), + "train/{}".format(metric_name), metric_value, step_num, ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - 'train/{}'.format(self._loss_prefix), + "train/{}".format(self._loss_prefix), inputs[self._loss_prefix], step_num, ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush() -class CheckpointCallback(BaseCallback, config_name='checkpoint'): +class CheckpointCallback(BaseCallback, config_name="checkpoint"): def __init__( self, model, @@ -115,7 +111,7 @@ def __init__( self._save_path = Path(os.path.join(save_path, model_name)) if self._save_path.exists(): logger.warning( - 'Checkpoint path `{}` is already exists!'.format( + "Checkpoint path `{}` is already exists!".format( self._save_path, ), ) @@ -125,31 +121,31 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - save_path=config['save_path'], - model_name=config['model_name'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], + save_path=config["save_path"], + model_name=config["model_name"], ) def __call__(self, inputs, step_num): if step_num % self._on_step == 0: - logger.debug('Saving model state on step {}...'.format(step_num)) + logger.debug("Saving model state on step {}...".format(step_num)) torch.save( { - 'step_num': step_num, - 'model_state_dict': self._model.state_dict(), - 'optimizer_state_dict': self._optimizer.state_dict(), + "step_num": step_num, + "model_state_dict": self._model.state_dict(), + "optimizer_state_dict": self._optimizer.state_dict(), }, os.path.join( self._save_path, - 'checkpoint_{}.pth'.format(step_num), + "checkpoint_{}.pth".format(step_num), ), ) - logger.debug('Saving done!') + logger.debug("Saving done!") class InferenceCallback(BaseCallback): @@ -183,24 +179,24 @@ def __init__( def create_from_config(cls, config, **kwargs): metrics = { metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() + for metric_name, metric_cfg in config["metrics"].items() } return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], + pred_prefix=config["pred_prefix"], + labels_prefix=config["labels_prefix"], ) def __call__(self, inputs, step_num): if step_num % self._on_step == 0: # TODO Add time monitoring - logger.debug(f'Running {self._get_name()} on step {step_num}...') + logger.debug(f"Running {self._get_name()} on step {step_num}...") running_params = {} for metric_name, metric_function in self._metrics.items(): running_params[metric_name] = [] @@ -239,16 +235,16 @@ def __call__(self, inputs, step_num): ) for label, value in running_params.items(): - inputs[f'{self._get_name()}/{label}'] = np.mean(value) + inputs[f"{self._get_name()}/{label}"] = np.mean(value) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - f'{self._get_name()}/{label}', + f"{self._get_name()}/{label}", np.mean(value), step_num, ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush() logger.debug( - f'Running {self._get_name()} on step {step_num} is done!', + f"Running {self._get_name()} on step {step_num} is done!", ) def _get_name(self): @@ -258,55 +254,55 @@ def _get_dataloader(self): raise NotImplementedError -class ValidationCallback(InferenceCallback, config_name='validation'): +class ValidationCallback(InferenceCallback, config_name="validation"): @classmethod def create_from_config(cls, config, **kwargs): metrics = { metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() + for metric_name, metric_cfg in config["metrics"].items() } return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], + pred_prefix=config["pred_prefix"], + labels_prefix=config["labels_prefix"], ) def _get_dataloader(self): return self._validation_dataloader -class EvalCallback(InferenceCallback, config_name='eval'): +class EvalCallback(InferenceCallback, config_name="eval"): @classmethod def create_from_config(cls, config, **kwargs): metrics = { metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() + for metric_name, metric_cfg in config["metrics"].items() } return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], + on_step=config["on_step"], metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], + pred_prefix=config["pred_prefix"], + labels_prefix=config["labels_prefix"], ) def _get_dataloader(self): return self._eval_dataloader -class CompositeCallback(BaseCallback, config_name='composite'): +class CompositeCallback(BaseCallback, config_name="composite"): def __init__( self, model, @@ -328,14 +324,14 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], + model=kwargs["model"], + train_dataloader=kwargs["train_dataloader"], + validation_dataloader=kwargs["validation_dataloader"], + eval_dataloader=kwargs["eval_dataloader"], + optimizer=kwargs["optimizer"], callbacks=[ BaseCallback.create_from_config(cfg, **kwargs) - for cfg in config['callbacks'] + for cfg in config["callbacks"] ], ) diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py index 06fdefc..5e09f46 100644 --- a/src/irec/dataloader/base.py +++ b/src/irec/dataloader/base.py @@ -13,7 +13,7 @@ class BaseDataloader(metaclass=MetaParent): pass -class TorchDataloader(BaseDataloader, config_name='torch'): +class TorchDataloader(BaseDataloader, config_name="torch"): def __init__(self, dataloader): self._dataloader = dataloader @@ -27,17 +27,32 @@ def __len__(self): def create_from_config(cls, config, **kwargs): create_config = copy.deepcopy(config) batch_processor = BaseBatchProcessor.create_from_config( - create_config.pop('batch_processor') - if 'batch_processor' in create_config - else {'type': 'identity'}, + ( + create_config.pop("batch_processor") + if "batch_processor" in create_config + else {"type": "identity"} + ), ) create_config.pop( - 'type', + "type", ) # For passing as **config in torch DataLoader + + pin_memory = create_config.pop("pin_memory", True) + return cls( dataloader=DataLoader( - kwargs['dataset'], + kwargs["dataset"], collate_fn=batch_processor, + pin_memory=pin_memory, **create_config, ), ) + + # return cls( + # dataloader=DataLoader( + # kwargs['dataset'], + # collate_fn=batch_processor, + # pin_memory=True, + # **create_config, + # ), + # ) diff --git a/src/irec/dataloader/batch_processors.py b/src/irec/dataloader/batch_processors.py index a8dbdde..f2f3927 100644 --- a/src/irec/dataloader/batch_processors.py +++ b/src/irec/dataloader/batch_processors.py @@ -1,4 +1,5 @@ import torch +import itertools from irec.utils import MetaParent @@ -7,32 +8,59 @@ def __call__(self, batch): raise NotImplementedError -class IdentityBatchProcessor(BaseBatchProcessor, config_name='identity'): +class IdentityBatchProcessor(BaseBatchProcessor, config_name="identity"): def __call__(self, batch): return torch.tensor(batch) -class BasicBatchProcessor(BaseBatchProcessor, config_name='basic'): +class BasicBatchProcessor(BaseBatchProcessor, config_name="basic"): def __call__(self, batch): processed_batch = {} for key in batch[0].keys(): - if key.endswith('.ids'): - prefix = key.split('.')[0] - assert '{}.length'.format(prefix) in batch[0] + if key.endswith(".ids"): + prefix = key.split(".")[0] + length_key = f"{prefix}.length" + assert length_key in batch[0] - processed_batch[f'{prefix}.ids'] = [] - processed_batch[f'{prefix}.length'] = [] + # --- OLD SLOW IMPLEMENTATION (Python loop with manual .extend) --- + # processed_batch[f'{prefix}.ids'] = [] + # processed_batch[f'{prefix}.length'] = [] + # for sample in batch: + # processed_batch[f'{prefix}.ids'].extend( + # sample[f'{prefix}.ids'], + # ) + # processed_batch[f'{prefix}.length'].append( + # sample[f'{prefix}.length'], + # ) - for sample in batch: - processed_batch[f'{prefix}.ids'].extend( - sample[f'{prefix}.ids'], - ) - processed_batch[f'{prefix}.length'].append( - sample[f'{prefix}.length'], - ) + # --- NEW OPTIMIZED IMPLEMENTATION (Books-Scale Ready) --- + """ + Optimization Strategy: C-level Flattening via itertools. + + Justification for Amazon Books scale: + 1. Avoiding Reallocations: Python's list.extend() repeatedly triggers + memory reallocation as the list grows. For large batches on a + 9-million-interaction dataset, this creates significant overhead. + 2. itertools.chain.from_iterable: This is implemented in C. It creates + a flat iterator over the sequence slices without creating + intermediate Python list objects, which is much faster. + 3. List Comprehension: Collecting lengths via a comprehension is + consistently faster than manual .append() calls in a for-loop. + """ + # Efficiently flatten all sequence IDs into one long list + ids_iter = itertools.chain.from_iterable(s[key] for s in batch) + processed_batch[key] = torch.tensor(list(ids_iter), dtype=torch.long) + # Efficiently collect all lengths into a tensor + lengths_list = [s[length_key] for s in batch] + processed_batch[length_key] = torch.tensor( + lengths_list, dtype=torch.long + ) + + # Final conversion for any keys that might have missed the .ids check for part, values in processed_batch.items(): - processed_batch[part] = torch.tensor(values, dtype=torch.long) + if not isinstance(processed_batch[part], torch.Tensor): + processed_batch[part] = torch.tensor(values, dtype=torch.long) return processed_batch diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py index b07d00c..5936a31 100644 --- a/src/irec/dataset/base.py +++ b/src/irec/dataset/base.py @@ -33,7 +33,8 @@ def num_items(self): @property def max_sequence_length(self): return self._max_sequence_length - + + class BaseSequenceDataset(BaseDataset): def __init__( self, @@ -52,7 +53,7 @@ def __init__( self._max_sequence_length = max_sequence_length @staticmethod - def _create_sequences(data, max_sample_len): # TODO + def _create_sequences(data, max_sample_len): # TODO user_sequences = [] item_sequences = [] @@ -61,10 +62,8 @@ def _create_sequences(data, max_sample_len): # TODO max_sequence_length = 0 for sample in data: - sample = sample.strip('\n').split(' ') - item_ids = [int(item_id) for item_id in sample[1:]][ - -max_sample_len: - ] + sample = sample.strip("\n").split(" ") + item_ids = [int(item_id) for item_id in sample[1:]][-max_sample_len:] user_id = int(sample[0]) max_user_id = max(max_user_id, user_id) @@ -92,12 +91,13 @@ def get_samplers(self): @property def meta(self): return { - 'num_users': self.num_users, - 'num_items': self.num_items, - 'max_sequence_length': self.max_sequence_length, + "num_users": self.num_users, + "num_items": self.num_items, + "max_sequence_length": self.max_sequence_length, } - -class SequenceDataset(BaseDataset, config_name='sequence'): + + +class SequenceDataset(BaseDataset, config_name="sequence"): def __init__( self, train_sampler, @@ -117,24 +117,24 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): data_dir_path = os.path.join( - config['path_to_data_dir'], - config['name'], + config["path_to_data_dir"], + config["name"], ) common_params_for_creation = { - 'dir_path': data_dir_path, - 'max_sequence_length': config['max_sequence_length'], - 'use_cached': config.get('use_cached', False), + "dir_path": data_dir_path, + "max_sequence_length": config["max_sequence_length"], + "use_cached": config.get("use_cached", False), } train_dataset, train_max_user_id, train_max_item_id, train_seq_len = ( - cls._create_dataset(part='train', **common_params_for_creation) + cls._create_dataset(part="train", **common_params_for_creation) ) validation_dataset, valid_max_user_id, valid_max_item_id, valid_seq_len = ( - cls._create_dataset(part='valid', **common_params_for_creation) + cls._create_dataset(part="valid", **common_params_for_creation) ) test_dataset, test_max_user_id, test_max_item_id, test_seq_len = ( - cls._create_dataset(part='test', **common_params_for_creation) + cls._create_dataset(part="test", **common_params_for_creation) ) max_user_id = max( @@ -145,11 +145,11 @@ def create_from_config(cls, config, **kwargs): ) max_seq_len = max([train_seq_len, valid_seq_len, test_seq_len]) - logger.info('Train dataset size: {}'.format(len(train_dataset))) - logger.info('Test dataset size: {}'.format(len(test_dataset))) - logger.info('Max user id: {}'.format(max_user_id)) - logger.info('Max item id: {}'.format(max_item_id)) - logger.info('Max sequence length: {}'.format(max_seq_len)) + logger.info("Train dataset size: {}".format(len(train_dataset))) + logger.info("Test dataset size: {}".format(len(test_dataset))) + logger.info("Max user id: {}".format(max_user_id)) + logger.info("Max item id: {}".format(max_item_id)) + logger.info("Max sequence length: {}".format(max_seq_len)) train_interactions = sum( list(map(lambda x: len(x), train_dataset)), @@ -161,15 +161,15 @@ def create_from_config(cls, config, **kwargs): test_dataset, ) # each new interaction as a sample logger.info( - '{} dataset sparsity: {}'.format( - config['name'], + "{} dataset sparsity: {}".format( + config["name"], (train_interactions + valid_interactions + test_interactions) / max_user_id / max_item_id, ), ) - samplers_config = config['samplers'] + samplers_config = config["samplers"] train_sampler = TrainSampler.create_from_config( samplers_config, dataset=train_dataset, @@ -206,31 +206,30 @@ def _create_dataset( max_sequence_length=None, use_cached=False, ): - cache_path = os.path.join(dir_path, '{}.pkl'.format(part)) + cache_path = os.path.join(dir_path, "{}.pkl".format(part)) if use_cached and os.path.exists(cache_path): - logger.info( - 'Loading cached dataset from {}'.format(cache_path) - ) - with open(cache_path, 'rb') as f: + logger.info("Loading cached dataset from {}".format(cache_path)) + with open(cache_path, "rb") as f: return pickle.load(f) - - return cls._build_and_cache_dataset(dir_path, part, max_sequence_length, cache_path, use_cached) + return cls._build_and_cache_dataset( + dir_path, part, max_sequence_length, cache_path, use_cached + ) @classmethod - def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_path, use_cached): + def _build_and_cache_dataset( + cls, dir_path, part, max_sequence_length, cache_path, use_cached + ): logger.info( - 'Cache is forcefully ignored.' + "Cache is forcefully ignored." if not use_cached - else 'No cached dataset has been found.' - ) - dataset_path = os.path.join(dir_path, '{}.txt'.format(part)) - logger.info( - 'Creating a dataset from {}...'.format(dataset_path) + else "No cached dataset has been found." ) + dataset_path = os.path.join(dir_path, "{}.txt".format(part)) + logger.info("Creating a dataset from {}...".format(dataset_path)) - with open(dataset_path, 'r') as f: + with open(dataset_path, "r") as f: data = f.readlines() sequence_info = cls._create_sequences(data, max_sequence_length) @@ -246,22 +245,22 @@ def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_pat for user_id, item_ids in zip(user_sequences, item_sequences): dataset.append( { - 'user.ids': [user_id], - 'user.length': 1, - 'item.ids': item_ids, - 'item.length': len(item_ids), + "user.ids": [user_id], + "user.length": 1, + "item.ids": item_ids, + "item.length": len(item_ids), }, ) - logger.info('{} dataset size: {}'.format(part, len(dataset))) + logger.info("{} dataset size: {}".format(part, len(dataset))) logger.info( - '{} dataset max sequence length: {}'.format( + "{} dataset max sequence length: {}".format( part, max_sequence_len, ), ) - with open(cache_path, 'wb') as dataset_file: + with open(cache_path, "wb") as dataset_file: pickle.dump( (dataset, max_user_id, max_item_id, max_sequence_len), dataset_file, @@ -270,7 +269,7 @@ def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_pat return dataset, max_user_id, max_item_id, max_sequence_len @staticmethod - def _create_sequences(data, max_sample_len): # TODO + def _create_sequences(data, max_sample_len): # TODO user_sequences = [] item_sequences = [] @@ -279,10 +278,8 @@ def _create_sequences(data, max_sample_len): # TODO max_sequence_length = 0 for sample in data: - sample = sample.strip('\n').split(' ') - item_ids = [int(item_id) for item_id in sample[1:]][ - -max_sample_len: - ] + sample = sample.strip("\n").split(" ") + item_ids = [int(item_id) for item_id in sample[1:]][-max_sample_len:] user_id = int(sample[0]) max_user_id = max(max_user_id, user_id) @@ -307,7 +304,8 @@ def get_samplers(self): self._test_sampler, ) -class GraphDataset(BaseDataset, config_name='graph'): + +class GraphDataset(BaseDataset, config_name="graph"): def __init__( self, dataset, @@ -315,7 +313,7 @@ def __init__( use_train_data_only=True, use_user_graph=False, use_item_graph=False, - neighborhood_size=None + neighborhood_size=None, ): self._dataset = dataset self._graph_dir_path = graph_dir_path @@ -327,11 +325,11 @@ def __init__( self._num_users = dataset.num_users self._num_items = dataset.num_items - train_sampler, validation_sampler, test_sampler = ( - dataset.get_samplers() - ) + train_sampler, validation_sampler, test_sampler = dataset.get_samplers() - interactions_data = self._collect_interactions(train_sampler, validation_sampler, test_sampler) + interactions_data = self._collect_interactions( + train_sampler, validation_sampler, test_sampler + ) train_interactions = interactions_data["train_interactions"] train_user_interactions = interactions_data["train_user_interactions"] train_item_interactions = interactions_data["train_item_interactions"] @@ -343,49 +341,63 @@ def __init__( self._train_item_interactions = np.array(train_item_interactions) self._graph = self._build_or_load_bipartite_graph( - graph_dir_path, - train_user_interactions, - train_item_interactions + graph_dir_path, train_user_interactions, train_item_interactions ) - self._user_graph = ( self._build_or_load_similarity_graph( - 'user', - self._train_user_interactions, - self._train_item_interactions, - train_item_2_users, - train_user_2_items - ) - if self._use_user_graph + "user", + self._train_user_interactions, + self._train_item_interactions, + train_item_2_users, + train_user_2_items, + ) + if self._use_user_graph else None ) self._item_graph = ( self._build_or_load_similarity_graph( - 'item', - self._train_user_interactions, - self._train_item_interactions, - train_item_2_users, - train_user_2_items - ) - if self._use_item_graph + "item", + self._train_user_interactions, + self._train_item_interactions, + train_item_2_users, + train_user_2_items, + ) + if self._use_item_graph else None ) def _build_or_load_similarity_graph( - self, - entity_type, - train_user_interactions, - train_item_interactions, - train_item_2_users, - train_user_2_items + self, + entity_type, + train_user_interactions, + train_item_interactions, + train_item_2_users, + train_user_2_items, ): - if entity_type not in ['user', 'item']: + if entity_type not in ["user", "item"]: raise ValueError("entity_type must be either 'user' or 'item'") + # have to delete and replace to not delete npz each time manually + # path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type)) + + # instead better use such construction + + # neighborhood_size + # The neighborhood_size is a filter that constrains the number of edges for each user or + # item node in the graph. + # k=50 implies that for each user, we find all possible neighbors, sort them based on + # co-occurrence counts, and keep only the top 50. All other connections are removed from the graph. + k_suffix = ( + f"k{self._neighborhood_size}" + if self._neighborhood_size is not None + else "full" + ) + train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" + filename = f"{entity_type}_graph_{k_suffix}_{train_suffix}.npz" + path_to_graph = os.path.join(self._graph_dir_path, filename) - path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type)) - is_user_graph = (entity_type == 'user') + is_user_graph = entity_type == "user" num_entities = self._num_users if is_user_graph else self._num_items if os.path.exists(path_to_graph): @@ -394,19 +406,28 @@ def _build_or_load_similarity_graph( interactions_fst = [] interactions_snd = [] visited_user_item_pairs = set() - visited_entity_pairs = set() + # have to delete cause + # 3.2 Graph Construction + # User-user/item-item graph + # ..the weight of each edge denotes the number of co-action behaviors between user i and user j + + # visited_entity_pairs = set() for user_id, item_id in tqdm( zip(train_user_interactions, train_item_interactions), - desc='Building {}-{} graph'.format(entity_type, entity_type) # TODO need? + desc="Building {}-{} graph".format( + entity_type, entity_type + ), # TODO need? ): if (user_id, item_id) in visited_user_item_pairs: continue - visited_user_item_pairs.add((user_id, item_id)) + visited_user_item_pairs.add((user_id, item_id)) # TODO look here at review source_entity = user_id if is_user_graph else item_id - connection_map = train_item_2_users if is_user_graph else train_user_2_items + connection_map = ( + train_item_2_users if is_user_graph else train_user_2_items + ) connection_point = item_id if is_user_graph else user_id for connected_entity in connection_map[connection_point]: @@ -414,38 +435,38 @@ def _build_or_load_similarity_graph( continue pair_key = (source_entity, connected_entity) - if pair_key in visited_entity_pairs: - continue - - visited_entity_pairs.add(pair_key) + # if pair_key in visited_entity_pairs: + # continue + + # visited_entity_pairs.add(pair_key) interactions_fst.append(source_entity) interactions_snd.append(connected_entity) connections = csr_matrix( - (np.ones(len(interactions_fst)), - ( - interactions_fst, - interactions_snd - ) - ), - shape=(num_entities + 2, num_entities + 2) + (np.ones(len(interactions_fst)), (interactions_fst, interactions_snd)), + shape=(num_entities + 2, num_entities + 2), ) if self._neighborhood_size is not None: - connections = self._filter_matrix_by_top_k(connections, self._neighborhood_size) + connections = self._filter_matrix_by_top_k( + connections, self._neighborhood_size + ) graph_matrix = self.get_sparse_graph_layer( - connections, - num_entities + 2, - num_entities + 2, - biparite=False + connections, num_entities + 2, num_entities + 2, biparite=False ) sp.save_npz(path_to_graph, graph_matrix) return self._convert_sp_mat_to_sp_tensor(graph_matrix).coalesce().to(DEVICE) - def _build_or_load_bipartite_graph(self, graph_dir_path, train_user_interactions, train_item_interactions): - path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') + def _build_or_load_bipartite_graph( + self, graph_dir_path, train_user_interactions, train_item_interactions + ): + # path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') + train_suffix = "trainOnly" if self._use_train_data_only else "withValTest" + filename = f"general_graph_{train_suffix}.npz" + path_to_graph = os.path.join(graph_dir_path, filename) + if os.path.exists(path_to_graph): graph_matrix = sp.load_npz(path_to_graph) else: @@ -481,8 +502,8 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler) for sampler in samplers_to_process: for sample in sampler.dataset: - user_id = sample['user.ids'][0] - for item_id in sample['item.ids']: + user_id = sample["user.ids"][0] + for item_id in sample["item.ids"]: if (user_id, item_id) not in visited_user_item_pairs: train_interactions.append((user_id, item_id)) train_user_interactions.append(user_id) @@ -492,7 +513,7 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler) train_item_2_users[item_id].add(user_id) visited_user_item_pairs.add((user_id, item_id)) - + return { "train_interactions": train_interactions, "train_user_interactions": train_user_interactions, @@ -503,15 +524,49 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler) @classmethod def create_from_config(cls, config): - dataset = BaseDataset.create_from_config(config['dataset']) + dataset = BaseDataset.create_from_config(config["dataset"]) return cls( dataset=dataset, - graph_dir_path=config['graph_dir_path'], - use_user_graph=config.get('use_user_graph', False), - use_item_graph=config.get('use_item_graph', False), - neighborhood_size=config.get('neighborhood_size', None), + graph_dir_path=config["graph_dir_path"], + use_user_graph=config.get("use_user_graph", False), + use_item_graph=config.get("use_item_graph", False), + neighborhood_size=config.get("neighborhood_size", None), ) + # @staticmethod + # def get_sparse_graph_layer( + # sparse_matrix, + # fst_dim, + # snd_dim, + # biparite=False, + # ): + # if not biparite: + # adj_mat = sparse_matrix.tocsr() + # else: + # R = sparse_matrix.tocsr() + + # upper_right = R + # lower_left = R.T + + # upper_left = sp.csr_matrix((fst_dim, fst_dim)) + # lower_right = sp.csr_matrix((snd_dim, snd_dim)) + + # adj_mat = sp.bmat([ + # [upper_left, upper_right], + # [lower_left, lower_right] + # ]) + # assert adj_mat.shape == (fst_dim + snd_dim, fst_dim + snd_dim), ( + # f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" + # ) + + # rowsum = np.array(adj_mat.sum(1)) + # d_inv = np.power(rowsum, -0.5).flatten() + # d_inv[np.isinf(d_inv)] = 0. + # d_mat_inv = sp.diags(d_inv) + + # norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) + # return norm_adj.tocsr() + @staticmethod def get_sparse_graph_layer( sparse_matrix, @@ -523,27 +578,52 @@ def get_sparse_graph_layer( adj_mat = sparse_matrix.tocsr() else: R = sparse_matrix.tocsr() - + upper_right = R lower_left = R.T - + upper_left = sp.csr_matrix((fst_dim, fst_dim)) lower_right = sp.csr_matrix((snd_dim, snd_dim)) - - adj_mat = sp.bmat([ - [upper_left, upper_right], - [lower_left, lower_right] - ]) - assert adj_mat.shape == (fst_dim + snd_dim, fst_dim + snd_dim), ( - f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" - ) + + adj_mat = sp.bmat([[upper_left, upper_right], [lower_left, lower_right]]) + assert adj_mat.shape == ( + fst_dim + snd_dim, + fst_dim + snd_dim, + ), f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" + + # --- OLD IMPLEMENTATION (Slow & Memory Intensive for Large Corpus) --- + # rowsum = np.array(adj_mat.sum(1)) + # d_inv = np.power(rowsum, -0.5).flatten() + # d_inv[np.isinf(d_inv)] = 0. + # d_mat_inv = sp.diags(d_inv) + # norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) + # return norm_adj.tocsr() + + # --- NEW OPTIMIZED IMPLEMENTATION --- + """ + Optimization Strategy: Vectorized Symmetric Normalization (D^-0.5 * A * D^-0.5). - rowsum = np.array(adj_mat.sum(1)) - d_inv = np.power(rowsum, -0.5).flatten() - d_inv[np.isinf(d_inv)] = 0. - d_mat_inv = sp.diags(d_inv) + Justification for Amazon Books Scale: + 1. Memory Efficiency: Creating an explicit diagonal matrix 'd_mat_inv' + (size N x N) via sp.diags is redundant. For 800k+ nodes, this consumes + significant RAM and creates heavy intermediate objects. + 2. Computational Speed: Traditional matrix-matrix multiplication (.dot) + in Scipy sparse has higher overhead compared to row/column-wise scaling. + 3. Implementation: We perform element-wise multiplication of the sparse + matrix by 1D degree vectors. Multiplying a sparse matrix by a column + vector scales rows, and by a row vector scales columns. - norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) + This results in an identical Laplacian matrix but is calculated in O(E) + time with minimal memory footprint, where E is the number of edges. + """ + rowsum = np.array(adj_mat.sum(1)).flatten() + d_inv = np.power(rowsum, -0.5) + d_inv[np.isinf(d_inv)] = 0.0 + + # Scaling rows: multiply by column vector [N, 1] + # Scaling columns: multiply by row vector [1, N] + norm_adj = adj_mat.multiply(d_inv[:, np.newaxis]).multiply(d_inv) + return norm_adj.tocsr() @staticmethod @@ -557,18 +637,52 @@ def _convert_sp_mat_to_sp_tensor(X): @staticmethod def _filter_matrix_by_top_k(matrix, k): - mat = matrix.tolil() + # --- OLD IMPLEMENTATION (Extremely slow conversion to LIL for large datasets) --- + # mat = matrix.tolil() + # for i in range(mat.shape[0]): + # if len(mat.rows[i]) <= k: + # continue + # data = np.array(mat.data[i]) + # top_k_indices = np.argpartition(data, -k)[-k:] + # mat.data[i] = [mat.data[i][j] for j in top_k_indices] + # mat.rows[i] = [mat.rows[i][j] for j in top_k_indices] + # return mat.tocsr() + + # --- NEW OPTIMIZED IMPLEMENTATION --- + """ + Optimization Strategy: Direct CSR Array Manipulation with NumPy Partitioning. + + Justification for Amazon Books Scale: + 1. Avoids LIL conversion: Converting a 450k x 300k matrix to LIL format + (List of Lists) is extremely memory-intensive and slow. + 2. In-place filtering: By accessing the CSR 'data' and 'indptr' arrays directly, + we perform the Top-K filtering with zero additional memory allocation + for the matrix structure itself. + 3. Algorithmic Speed: np.partition finds the threshold value in O(n) average + time. Slicing the underlying NumPy arrays is performed at near-C speed, + making this orders of magnitude faster than Python-level list operations. + """ + mat = matrix.tocsr() for i in range(mat.shape[0]): - if len(mat.rows[i]) <= k: - continue - data = np.array(mat.data[i]) - - top_k_indices = np.argpartition(data, -k)[-k:] - mat.data[i] = [mat.data[i][j] for j in top_k_indices] - mat.rows[i] = [mat.rows[i][j] for j in top_k_indices] + start = mat.indptr[i] + end = mat.indptr[i + 1] + + # Only process rows that actually exceed the neighborhood size + if end - start > k: + row_slice = mat.data[start:end] + + # Find the threshold value (the k-th largest element) + # np.partition is faster than a full sort: it puts the top-k values at the end + threshold = np.partition(row_slice, -k)[-k] + + # Effectively prune edges by zeroing out everything below the threshold + # This keeps exactly k (or slightly more if there are ties) elements + row_slice[row_slice < threshold] = 0 - return mat.tocsr() + # Post-processing: remove the explicitly zeroed elements from the sparse structure + mat.eliminate_zeros() + return mat def get_samplers(self): return self._dataset.get_samplers() @@ -576,50 +690,52 @@ def get_samplers(self): @property def meta(self): meta = { - 'user_graph': self._user_graph, - 'item_graph': self._item_graph, - 'graph': self._graph, + "user_graph": self._user_graph, + "item_graph": self._item_graph, + "graph": self._graph, **self._dataset.meta, } return meta -class ScientificDataset(BaseSequenceDataset, config_name='scientific'): +class ScientificDataset(BaseSequenceDataset, config_name="scientific"): @classmethod def create_from_config(cls, config, **kwargs): data_dir_path = os.path.join( - config['path_to_data_dir'], - config['name'], + config["path_to_data_dir"], + config["name"], ) - max_sequence_length = config['max_sequence_length'] + max_sequence_length = config["max_sequence_length"] - dataset_path = os.path.join(data_dir_path, '{}.txt'.format('all_data')) - with open(dataset_path, 'r') as f: + dataset_path = os.path.join(data_dir_path, "{}.txt".format("all_data")) + with open(dataset_path, "r") as f: lines = f.readlines() - datasets, max_user_id, max_item_id = cls._parse_and_split_data(lines, max_sequence_length) + datasets, max_user_id, max_item_id = cls._parse_and_split_data( + lines, max_sequence_length + ) - train_dataset = datasets['train'] - validation_dataset = datasets['validation'] - test_dataset = datasets['test'] + train_dataset = datasets["train"] + validation_dataset = datasets["validation"] + test_dataset = datasets["test"] cls._log_stats( - train_dataset, - test_dataset, - max_user_id, - max_item_id, - max_sequence_length, - config['name'] - ) + train_dataset, + test_dataset, + max_user_id, + max_item_id, + max_sequence_length, + config["name"], + ) train_sampler, validation_sampler, test_sampler = cls._create_samplers( - config['samplers'], - train_dataset, - validation_dataset, - test_dataset, - max_user_id, - max_item_id + config["samplers"], + train_dataset, + validation_dataset, + test_dataset, + max_user_id, + max_item_id, ) return cls( @@ -630,71 +746,99 @@ def create_from_config(cls, config, **kwargs): num_items=max_item_id, max_sequence_length=max_sequence_length, ) - + @staticmethod - def _create_samplers(sampler_config, train_dataset, validation_dataset, test_dataset, num_users, num_items): + def _create_samplers( + sampler_config, + train_dataset, + validation_dataset, + test_dataset, + num_users, + num_items, + ): train_sampler = TrainSampler.create_from_config( - sampler_config, dataset=train_dataset, num_users=num_users, num_items=num_items + sampler_config, + dataset=train_dataset, + num_users=num_users, + num_items=num_items, ) validation_sampler = EvalSampler.create_from_config( - sampler_config, dataset=validation_dataset, num_users=num_users, num_items=num_items + sampler_config, + dataset=validation_dataset, + num_users=num_users, + num_items=num_items, ) test_sampler = EvalSampler.create_from_config( - sampler_config, dataset=test_dataset, num_users=num_users, num_items=num_items + sampler_config, + dataset=test_dataset, + num_users=num_users, + num_items=num_items, ) return train_sampler, validation_sampler, test_sampler - + @staticmethod - def _log_stats(train_dataset, test_dataset, max_user_id, max_item_id, max_len, name): - logger.info('Train dataset size: {}'.format(len(train_dataset))) - logger.info('Test dataset size: {}'.format(len(test_dataset))) - logger.info('Max user id: {}'.format(max_user_id)) - logger.info('Max item id: {}'.format(max_item_id)) - logger.info('Max sequence length: {}'.format(max_len)) - + def _log_stats( + train_dataset, test_dataset, max_user_id, max_item_id, max_len, name + ): + logger.info("Train dataset size: {}".format(len(train_dataset))) + logger.info("Test dataset size: {}".format(len(test_dataset))) + logger.info("Max user id: {}".format(max_user_id)) + logger.info("Max item id: {}".format(max_item_id)) + logger.info("Max sequence length: {}".format(max_len)) + if max_user_id > 0 and max_item_id > 0: - sparsity = (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id - logger.info('{} dataset sparsity: {}'.format(name, sparsity)) + sparsity = ( + (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id + ) + logger.info("{} dataset sparsity: {}".format(name, sparsity)) @staticmethod def _parse_and_split_data(lines, max_sequence_length): - datasets = {'train': [], 'validation': [], 'test': []} + datasets = {"train": [], "validation": [], "test": []} - user_ids, item_sequences, max_user_id, max_item_id, _ = \ + user_ids, item_sequences, max_user_id, max_item_id, _ = ( BaseSequenceDataset._create_sequences(lines) + ) for user_id, item_ids in zip(user_ids, item_sequences): - + assert len(item_ids) >= 5 split_slices = { - 'train': slice(None, -2), - 'validation': slice(None, -1), - 'test': slice(None, None) + "train": slice(None, -2), + "validation": slice(None, -1), + "test": slice(None, None), } - + for part_name, part_slice in split_slices.items(): sliced_items = item_ids[part_slice] final_items = sliced_items[-max_sequence_length:] - - assert len(item_ids[-max_sequence_length:]) == len(set(item_ids[-max_sequence_length:]),) - datasets[part_name].append({ - 'user.ids': [user_id], 'user.length': 1, - 'item.ids': final_items, 'item.length': len(final_items), - }) + assert len(item_ids[-max_sequence_length:]) == len( + set(item_ids[-max_sequence_length:]), + ) + + datasets[part_name].append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": final_items, + "item.length": len(final_items), + } + ) return datasets, max_user_id, max_item_id -class MCLSRDataset(BaseSequenceDataset, config_name='mclsr'): + +class MCLSRDataset(BaseSequenceDataset, config_name="mclsr"): @staticmethod def _create_sequences_from_file(filepath, max_len=None): sequences = {} max_user, max_item = 0, 0 - - with open(filepath, 'r') as f: + + with open(filepath, "r") as f: for line in f: - parts = line.strip().split(' ') + parts = line.strip().split(" ") user_id = int(parts[0]) item_ids = [int(i) for i in parts[1:]] if max_len: @@ -704,40 +848,98 @@ def _create_sequences_from_file(filepath, max_len=None): if item_ids: max_item = max(max_item, max(item_ids)) return sequences, max_user, max_item - + @classmethod def _create_evaluation_sets(cls, data_dir, max_seq_len): - valid_hist, u2, i2 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_history.txt'), max_seq_len) - valid_trg, u3, i3 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_target.txt')) + valid_hist, u2, i2 = cls._create_sequences_from_file( + os.path.join(data_dir, "valid_history.txt"), max_seq_len + ) + valid_trg, u3, i3 = cls._create_sequences_from_file( + os.path.join(data_dir, "valid_target.txt") + ) - validation_dataset = [{'user.ids': [uid], 'history': valid_hist[uid], 'target': valid_trg[uid]} for uid in valid_hist if uid in valid_trg] - - test_hist, u4, i4 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_history.txt'), max_seq_len) - test_trg, u5, i5 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_target.txt')) + validation_dataset = [ + {"user.ids": [uid], "history": valid_hist[uid], "target": valid_trg[uid]} + for uid in valid_hist + if uid in valid_trg + ] - test_dataset = [{'user.ids': [uid], 'history': test_hist[uid], 'target': test_trg[uid]} for uid in test_hist if uid in test_trg] + test_hist, u4, i4 = cls._create_sequences_from_file( + os.path.join(data_dir, "test_history.txt"), max_seq_len + ) + test_trg, u5, i5 = cls._create_sequences_from_file( + os.path.join(data_dir, "test_target.txt") + ) + + test_dataset = [ + {"user.ids": [uid], "history": test_hist[uid], "target": test_trg[uid]} + for uid in test_hist + if uid in test_trg + ] - return validation_dataset, test_dataset, max(u2, u3, u4, u5), max(i2, i3, i4, i5) + return ( + validation_dataset, + test_dataset, + max(u2, u3, u4, u5), + max(i2, i3, i4, i5), + ) @classmethod def create_from_config(cls, config, **kwargs): - data_dir = os.path.join(config['path_to_data_dir'], config['name']) - max_seq_len = config.get('max_sequence_length') + data_dir = os.path.join(config["path_to_data_dir"], config["name"]) + max_seq_len = config.get("max_sequence_length") - train_sequences, u1, i1 = cls._create_sequences_from_file(os.path.join(data_dir, 'train_mclsr.txt'), max_seq_len) - train_dataset = [{'user.ids': [uid], 'user.length': 1, 'item.ids': seq, 'item.length': len(seq)} for uid, seq in train_sequences.items()] + train_sequences, u1, i1 = cls._create_sequences_from_file( + os.path.join(data_dir, "train_mclsr.txt"), max_seq_len + ) + train_dataset = [ + { + "user.ids": [uid], + "user.length": 1, + "item.ids": seq, + "item.length": len(seq), + } + for uid, seq in train_sequences.items() + ] user_to_all_seen_items = defaultdict(set) - for sample in train_dataset: user_to_all_seen_items[sample['user.ids'][0]].update(sample['item.ids']) - kwargs['user_to_all_seen_items'] = user_to_all_seen_items + for sample in train_dataset: + user_to_all_seen_items[sample["user.ids"][0]].update(sample["item.ids"]) + kwargs["user_to_all_seen_items"] = user_to_all_seen_items - validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets(data_dir, max_seq_len) + validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets( + data_dir, max_seq_len + ) num_users = max(u1, u_eval) num_items = max(i1, i_eval) - - train_sampler = TrainSampler.create_from_config(config['samplers'], dataset=train_dataset, num_users=num_users, num_items=num_items, **kwargs) - validation_sampler = EvalSampler.create_from_config(config['samplers'], dataset=validation_dataset, num_users=num_users, num_items=num_items, **kwargs) - test_sampler = EvalSampler.create_from_config(config['samplers'], dataset=test_dataset, num_users=num_users, num_items=num_items, **kwargs) - return cls(train_sampler, validation_sampler, test_sampler, num_users, num_items, max_seq_len) - + train_sampler = TrainSampler.create_from_config( + config["samplers"], + dataset=train_dataset, + num_users=num_users, + num_items=num_items, + **kwargs, + ) + validation_sampler = EvalSampler.create_from_config( + config["samplers"], + dataset=validation_dataset, + num_users=num_users, + num_items=num_items, + **kwargs, + ) + test_sampler = EvalSampler.create_from_config( + config["samplers"], + dataset=test_dataset, + num_users=num_users, + num_items=num_items, + **kwargs, + ) + + return cls( + train_sampler, + validation_sampler, + test_sampler, + num_users, + num_items, + max_seq_len, + ) diff --git a/src/irec/dataset/negative_samplers/__init__.py b/src/irec/dataset/negative_samplers/__init__.py index 04b82f9..654f7e1 100644 --- a/src/irec/dataset/negative_samplers/__init__.py +++ b/src/irec/dataset/negative_samplers/__init__.py @@ -3,7 +3,7 @@ from .random import RandomNegativeSampler __all__ = [ - 'BaseNegativeSampler', - 'PopularNegativeSampler', - 'RandomNegativeSampler', + "BaseNegativeSampler", + "PopularNegativeSampler", + "RandomNegativeSampler", ] diff --git a/src/irec/dataset/negative_samplers/base.py b/src/irec/dataset/negative_samplers/base.py index b4d1224..3360179 100644 --- a/src/irec/dataset/negative_samplers/base.py +++ b/src/irec/dataset/negative_samplers/base.py @@ -11,8 +11,8 @@ def __init__(self, dataset, num_users, num_items): self._seen_items = defaultdict(set) for sample in self._dataset: - user_id = sample['user.ids'][0] - items = list(sample['item.ids']) + user_id = sample["user.ids"][0] + items = list(sample["item.ids"]) self._seen_items[user_id].update(items) def generate_negative_samples(self, sample, num_negatives): diff --git a/src/irec/dataset/negative_samplers/popular.py b/src/irec/dataset/negative_samplers/popular.py index ca91bf6..f04822e 100644 --- a/src/irec/dataset/negative_samplers/popular.py +++ b/src/irec/dataset/negative_samplers/popular.py @@ -1,44 +1,70 @@ +import numpy as np from irec.dataset.negative_samplers.base import BaseNegativeSampler - from collections import Counter -class PopularNegativeSampler(BaseNegativeSampler, config_name='popular'): +class PopularNegativeSampler(BaseNegativeSampler, config_name="popular"): def __init__(self, dataset, num_users, num_items): super().__init__( dataset=dataset, num_users=num_users, num_items=num_items, ) - - self._popular_items = self._items_by_popularity() + self._item_ids, self._probs = self._calculate_item_probabilities() @classmethod def create_from_config(cls, _, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) - def _items_by_popularity(self): - popularity = Counter() - + def _calculate_item_probabilities(self): + """ + Calculates sampling probabilities proportional to item popularity. + This distribution is required to provide non-zero p_j values for LogQ correction. + """ + counts = Counter() for sample in self._dataset: - for item_id in sample['item.ids']: - popularity[item_id] += 1 + for item_id in sample["item.ids"]: + counts[item_id] += 1 + + items = np.array(list(counts.keys())) + freqs = np.array(list(counts.values()), dtype=np.float32) + probabilities = freqs / freqs.sum() - popular_items = sorted(popularity, key=popularity.get, reverse=True) - return popular_items + return items, probabilities def generate_negative_samples(self, sample, num_negatives): - user_id = sample['user.ids'][0] - popularity_idx = 0 - negatives = [] + """ + Stochastic sampling proportional to popularity. + + Justification: + The original implementation always picked the same Top-K popular items. + For LogQ correction (Yi et al., Google 2019), we need a stochastic + sampling process where p_j > 0 for all items in the distribution. + This allows the model to see a diverse set of negatives across epochs + while penalizing popular items correctly via the log(p_j) term. + """ + user_id = sample["user.ids"][0] + seen = self._seen_items[user_id] + + negatives = set() while len(negatives) < num_negatives: - negative_idx = self._popular_items[popularity_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - popularity_idx += 1 + # Sample items based on the pre-calculated frequency distribution + sampled_ids = np.random.choice( + self._item_ids, + size=num_negatives - len(negatives), + p=self._probs, + replace=True, + ) + + # Filter out items already seen by the user (False Negatives) + for idx in sampled_ids: + if idx not in seen: + negatives.add(idx) + if len(negatives) == num_negatives: + break - return negatives + return list(negatives) diff --git a/src/irec/dataset/negative_samplers/random.py b/src/irec/dataset/negative_samplers/random.py index 79e245a..d7365a8 100644 --- a/src/irec/dataset/negative_samplers/random.py +++ b/src/irec/dataset/negative_samplers/random.py @@ -3,26 +3,41 @@ import numpy as np -class RandomNegativeSampler(BaseNegativeSampler, config_name='random'): +class RandomNegativeSampler(BaseNegativeSampler, config_name="random"): @classmethod def create_from_config(cls, _, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) def generate_negative_samples(self, sample, num_negatives): - user_id = sample['user.ids'][0] - all_items = list(range(1, self._num_items + 1)) - np.random.shuffle(all_items) - - negatives = [] - running_idx = 0 - while len(negatives) < num_negatives and running_idx < len(all_items): - negative_idx = all_items[running_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - running_idx += 1 - - return negatives + """ + Optimized via Rejection Sampling (O(k) complexity). + + Mathematical Proof of Equivalence: + Let V be the set of all items and H be the user's history. + We need a uniform random sample S ⊂ (V \ H) such that |S| = k. + + 1. Shuffle Approach (Previous): Generates a random permutation of V, + then filters H. Complexity: O(|V|). + 2. Rejection Sampling (Current): Independently draws i ~ Uniform(V) + and accepts i if i ∉ H and i ∉ S. Complexity: O(k * 1/p), + where p = (|V| - |H|) / |V|. + + Since |H| << |V|, the probability p ≈ 1, making the expected complexity + effectively O(k). Both methods yield an identical uniform distribution + over the valid item space. + """ + user_id = sample["user.ids"][0] + seen = self._seen_items[user_id] + + negatives = set() + while len(negatives) < num_negatives: + # Drawing a random index is O(1) + negative_idx = np.random.randint(1, self._num_items + 1) + if negative_idx not in seen: + negatives.add(negative_idx) + + return list(negatives) diff --git a/src/irec/dataset/samplers/base.py b/src/irec/dataset/samplers/base.py index 158bede..7017b71 100644 --- a/src/irec/dataset/samplers/base.py +++ b/src/irec/dataset/samplers/base.py @@ -33,16 +33,16 @@ def __len__(self): return len(self._dataset) def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) + sample = self._dataset[index] - item_sequence = sample['item.ids'][:-1] - next_item = sample['item.ids'][-1] + item_sequence = sample["item.ids"][:-1] + next_item = sample["item.ids"][-1] return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [next_item], - 'labels.length': 1, + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "labels.ids": [next_item], + "labels.length": 1, } diff --git a/src/irec/dataset/samplers/mclsr.py b/src/irec/dataset/samplers/mclsr.py index b957cde..3a9276a 100644 --- a/src/irec/dataset/samplers/mclsr.py +++ b/src/irec/dataset/samplers/mclsr.py @@ -4,8 +4,16 @@ import random -class MCLSRTrainSampler(TrainSampler, config_name='mclsr'): - def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_negatives, **kwargs): +class MCLSRTrainSampler(TrainSampler, config_name="mclsr"): + def __init__( + self, + dataset, + num_users, + num_items, + user_to_all_seen_items, + num_negatives, + **kwargs, + ): super().__init__() self._dataset = dataset self._num_users = num_users @@ -16,65 +24,81 @@ def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_ne @classmethod def create_from_config(cls, config, **kwargs): - num_negatives = config['num_negatives_train'] + num_negatives = config["num_negatives_train"] print(num_negatives) return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], num_negatives=num_negatives, - user_to_all_seen_items=kwargs['user_to_all_seen_items'], + user_to_all_seen_items=kwargs["user_to_all_seen_items"], ) - def __getitem__(self, index): sample = self._dataset[index] - user_id = sample['user.ids'][0] - item_sequence = sample['item.ids'][:-1] - positive_item = sample['item.ids'][-1] + user_id = sample["user.ids"][0] + item_sequence = sample["item.ids"][:-1] + positive_item = sample["item.ids"][-1] user_seen = self._user_to_all_seen_items[user_id] - unseen_items = list(self._all_items_set - user_seen) - - negatives = random.sample(unseen_items, self._num_negatives) - + # unseen_items = list(self._all_items_set - user_seen) + # negatives = random.sample(unseen_items, self._num_negatives) + + # --- OPTIMIZATION: Rejection Sampling --- + # Mathematically equivalent to: random.sample(list(all_items - user_seen), k) + # Logic: Instead of allocating a huge list of unseen items (O(N) memory), + # we draw random indices until we have the required number of unique negatives. + # Since |user_seen| << |num_items|, the collision probability is near zero. + negatives = set() + while len(negatives) < self._num_negatives: + # Draw a random item index from the range [1, num_items] + candidate = random.randint(1, self._num_items) + + # Rejection step: Only accept if the user has never interacted with it. + # This ensures we only sample from the "unseen" pool. + if candidate not in user_seen: + negatives.add(candidate) + + # Convert back to list to match the expected format for BatchProcessor + negatives = list(negatives) + # ---------------------------------------- return { - 'user.ids': [user_id], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [positive_item], - 'labels.length': 1, - 'negatives.ids': negatives, - 'negatives.length': len(negatives), + "user.ids": [user_id], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "labels.ids": [positive_item], + "labels.length": 1, + "negatives.ids": negatives, + "negatives.length": len(negatives), } -class MCLSRPredictionEvalSampler(EvalSampler, config_name='mclsr'): +class MCLSRPredictionEvalSampler(EvalSampler, config_name="mclsr"): def __init__(self, dataset, num_users, num_items): super().__init__(dataset, num_users, num_items) @classmethod def create_from_config(cls, config, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) - + def __getitem__(self, index): sample = self._dataset[index] - history_sequence = sample['history'] - target_items = sample['target'] + history_sequence = sample["history"] + target_items = sample["target"] return { - 'user.ids': sample['user.ids'], - 'user.length': 1, - 'item.ids': history_sequence, - 'item.length': len(history_sequence), - 'labels.ids': target_items, - 'labels.length': len(target_items), + "user.ids": sample["user.ids"], + "user.length": 1, + "item.ids": history_sequence, + "item.length": len(history_sequence), + "labels.ids": target_items, + "labels.length": len(target_items), } diff --git a/src/irec/dataset/samplers/next_item_prediction.py b/src/irec/dataset/samplers/next_item_prediction.py index e263062..f757b4b 100644 --- a/src/irec/dataset/samplers/next_item_prediction.py +++ b/src/irec/dataset/samplers/next_item_prediction.py @@ -6,7 +6,7 @@ class NextItemPredictionTrainSampler( TrainSampler, - config_name='next_item_prediction', + config_name="next_item_prediction", ): def __init__( self, @@ -26,61 +26,59 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): negative_sampler = BaseNegativeSampler.create_from_config( - {'type': config['negative_sampler_type']}, + {"type": config["negative_sampler_type"]}, **kwargs, ) return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], negative_sampler=negative_sampler, - num_negatives=config.get('num_negatives_train', 0), + num_negatives=config.get("num_negatives_train", 0), ) def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) + sample = self._dataset[index] - item_sequence = sample['item.ids'][:-1] - next_item_sequence = sample['item.ids'][1:] + item_sequence = sample["item.ids"][:-1] + next_item_sequence = sample["item.ids"][1:] if self._num_negatives == 0: return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'positive.ids': next_item_sequence, - 'positive.length': len(next_item_sequence), + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "positive.ids": next_item_sequence, + "positive.length": len(next_item_sequence), } else: - negative_sequence = ( - self._negative_sampler.generate_negative_samples( - sample, - self._num_negatives, - ) + negative_sequence = self._negative_sampler.generate_negative_samples( + sample, + self._num_negatives, ) return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'positive.ids': next_item_sequence, - 'positive.length': len(next_item_sequence), - 'negative.ids': negative_sequence, - 'negative.length': len(negative_sequence), + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "positive.ids": next_item_sequence, + "positive.length": len(next_item_sequence), + "negative.ids": negative_sequence, + "negative.length": len(negative_sequence), } class NextItemPredictionEvalSampler( EvalSampler, - config_name='next_item_prediction', + config_name="next_item_prediction", ): @classmethod def create_from_config(cls, config, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) diff --git a/src/irec/dataset/sasrec.py b/src/irec/dataset/sasrec.py index 0cba986..0c1dd97 100644 --- a/src/irec/dataset/sasrec.py +++ b/src/irec/dataset/sasrec.py @@ -3,35 +3,41 @@ from .base import BaseSequenceDataset, SequenceDataset, MCLSRDataset from .samplers import TrainSampler, EvalSampler -class SASRecDataset(BaseSequenceDataset, config_name='sasrec_comparison'): + +class SASRecDataset(BaseSequenceDataset, config_name="sasrec_comparison"): @classmethod def create_from_config(cls, config, **kwargs): - data_dir = os.path.join(config['path_to_data_dir'], config['name']) - max_seq_len = config.get('max_sequence_length') + data_dir = os.path.join(config["path_to_data_dir"], config["name"]) + max_seq_len = config.get("max_sequence_length") train_dataset, u1, i1, _ = SequenceDataset._create_dataset( - dir_path=data_dir, - part='train_sasrec', - max_sequence_length=max_seq_len + dir_path=data_dir, part="train_sasrec", max_sequence_length=max_seq_len ) - validation_dataset, test_dataset, u_eval, i_eval = \ + validation_dataset, test_dataset, u_eval, i_eval = ( MCLSRDataset._create_evaluation_sets(data_dir, max_seq_len) + ) num_users = max(u1, u_eval) num_items = max(i1, i_eval) - + train_sampler = TrainSampler.create_from_config( - config['train_sampler'], - dataset=train_dataset, num_users=num_users, num_items=num_items + config["train_sampler"], + dataset=train_dataset, + num_users=num_users, + num_items=num_items, ) validation_sampler = EvalSampler.create_from_config( - config['eval_sampler'], - dataset=validation_dataset, num_users=num_users, num_items=num_items + config["eval_sampler"], + dataset=validation_dataset, + num_users=num_users, + num_items=num_items, ) test_sampler = EvalSampler.create_from_config( - config['eval_sampler'], - dataset=test_dataset, num_users=num_users, num_items=num_items + config["eval_sampler"], + dataset=test_dataset, + num_users=num_users, + num_items=num_items, ) return cls( @@ -40,5 +46,5 @@ def create_from_config(cls, config, **kwargs): test_sampler=test_sampler, num_users=num_users, num_items=num_items, - max_sequence_length=max_seq_len + max_sequence_length=max_seq_len, ) diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py index 10f689e..f8a0762 100644 --- a/src/irec/loss/base.py +++ b/src/irec/loss/base.py @@ -1,12 +1,14 @@ -import copy - from irec.utils import ( MetaParent, maybe_to_list, ) +import copy import torch import torch.nn as nn +import pickle +import os +import logging class BaseLoss(metaclass=MetaParent): @@ -17,12 +19,12 @@ class TorchLoss(BaseLoss, nn.Module): pass -class IdentityLoss(BaseLoss, config_name='identity'): +class IdentityLoss(BaseLoss, config_name="identity"): def __call__(self, inputs): return inputs -class CompositeLoss(TorchLoss, config_name='composite'): +class CompositeLoss(TorchLoss, config_name="composite"): def __init__(self, losses, weights=None, output_prefix=None): super().__init__() self._losses = losses @@ -34,8 +36,8 @@ def create_from_config(cls, config, **kwargs): losses = [] weights = [] - for loss_cfg in copy.deepcopy(config)['losses']: - weight = loss_cfg.pop('weight') if 'weight' in loss_cfg else 1.0 + for loss_cfg in copy.deepcopy(config)["losses"]: + weight = loss_cfg.pop("weight") if "weight" in loss_cfg else 1.0 loss_function = BaseLoss.create_from_config(loss_cfg) weights.append(weight) @@ -44,7 +46,7 @@ def create_from_config(cls, config, **kwargs): return cls( losses=losses, weights=weights, - output_prefix=config.get('output_prefix'), + output_prefix=config.get("output_prefix"), ) def forward(self, inputs): @@ -58,7 +60,7 @@ def forward(self, inputs): return total_loss -class FpsLoss(TorchLoss, config_name='fps'): +class FpsLoss(TorchLoss, config_name="fps"): def __init__( self, fst_embeddings_prefix, @@ -73,7 +75,7 @@ def __init__( self._snd_embeddings_prefix = snd_embeddings_prefix self._tau = tau self._loss_function = nn.CrossEntropyLoss( - reduction='mean' if use_mean else 'sum', + reduction="mean" if use_mean else "sum", ) self._normalize_embeddings = normalize_embeddings self._output_prefix = output_prefix @@ -82,21 +84,17 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - fst_embeddings_prefix=config['fst_embeddings_prefix'], - snd_embeddings_prefix=config['snd_embeddings_prefix'], - tau=config.get('temperature', 1.0), - normalize_embeddings=config.get('normalize_embeddings', False), - use_mean=config.get('use_mean', True), - output_prefix=config.get('output_prefix') + fst_embeddings_prefix=config["fst_embeddings_prefix"], + snd_embeddings_prefix=config["snd_embeddings_prefix"], + tau=config.get("temperature", 1.0), + normalize_embeddings=config.get("normalize_embeddings", False), + use_mean=config.get("use_mean", True), + output_prefix=config.get("output_prefix"), ) def forward(self, inputs): - fst_embeddings = inputs[ - self._fst_embeddings_prefix - ] # (x, embedding_dim) - snd_embeddings = inputs[ - self._snd_embeddings_prefix - ] # (x, embedding_dim) + fst_embeddings = inputs[self._fst_embeddings_prefix] # (x, embedding_dim) + snd_embeddings = inputs[self._snd_embeddings_prefix] # (x, embedding_dim) batch_size = fst_embeddings.shape[0] @@ -123,7 +121,9 @@ def forward(self, inputs): torch.diag(similarity_scores, -batch_size), ), dim=0, - ).reshape(2 * batch_size, 1) # (2 * x, 1) + ).reshape( + 2 * batch_size, 1 + ) # (2 * x, 1) assert torch.allclose( torch.diag(similarity_scores, batch_size), torch.diag(similarity_scores, -batch_size), @@ -160,14 +160,9 @@ def forward(self, inputs): return loss -class SASRecLoss(TorchLoss, config_name='sasrec'): +class SASRecLoss(TorchLoss, config_name="sasrec"): - def __init__( - self, - positive_prefix, - negative_prefix, - output_prefix=None - ): + def __init__(self, positive_prefix, negative_prefix, output_prefix=None): super().__init__() self._positive_prefix = positive_prefix self._negative_prefix = negative_prefix @@ -190,7 +185,7 @@ def forward(self, inputs): return loss -class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'): +class SamplesSoftmaxLoss(TorchLoss, config_name="sampled_softmax"): def __init__( self, queries_prefix, @@ -205,9 +200,7 @@ def __init__( self._output_prefix = output_prefix def forward(self, inputs): - queries_embeddings = inputs[ - self._queries_prefix - ] # (batch_size, embedding_dim) + queries_embeddings = inputs[self._queries_prefix] # (batch_size, embedding_dim) positive_embeddings = inputs[ self._positive_prefix ] # (batch_size, embedding_dim) @@ -217,15 +210,17 @@ def forward(self, inputs): # b -- batch_size, d -- embedding_dim positive_scores = torch.einsum( - 'bd,bd->b', + "bd,bd->b", queries_embeddings, positive_embeddings, - ).unsqueeze(-1) # (batch_size, 1) + ).unsqueeze( + -1 + ) # (batch_size, 1) if negative_embeddings.dim() == 2: # (num_negatives, embedding_dim) # b -- batch_size, n -- num_negatives, d -- embedding_dim negative_scores = torch.einsum( - 'bd,nd->bn', + "bd,nd->bn", queries_embeddings, negative_embeddings, ) # (batch_size, num_negatives) @@ -235,7 +230,7 @@ def forward(self, inputs): ) # (batch_size, num_negatives, embedding_dim) # b -- batch_size, n -- num_negatives, d -- embedding_dim negative_scores = torch.einsum( - 'bd,bnd->bn', + "bd,bnd->bn", queries_embeddings, negative_embeddings, ) # (batch_size, num_negatives) @@ -257,7 +252,7 @@ def forward(self, inputs): return loss -class MCLSRLoss(TorchLoss, config_name='mclsr'): +class MCLSRLoss(TorchLoss, config_name="mclsr"): def __init__( self, all_scores_prefix, @@ -294,12 +289,11 @@ def forward(self, inputs): assert torch.allclose(all_scores[0, 0], positive_scores[0]) assert torch.allclose(all_scores[-1, -1], positive_scores[-1]) - # Maybe try mean over sequence TODO loss = torch.sum( torch.log( torch.sigmoid(positive_scores.unsqueeze(1) - negative_scores), ), - ) # (1) + ) if self._output_prefix is not None: inputs[self._output_prefix] = loss.cpu().item() diff --git a/src/irec/metric/base.py b/src/irec/metric/base.py index da68965..87c0190 100644 --- a/src/irec/metric/base.py +++ b/src/irec/metric/base.py @@ -12,7 +12,7 @@ def reduce(self): raise NotImplementedError -class StaticMetric(BaseMetric, config_name='dummy'): +class StaticMetric(BaseMetric, config_name="dummy"): def __init__(self, name, value): self._name = name self._value = value @@ -23,16 +23,14 @@ def __call__(self, inputs): return inputs -class CompositeMetric(BaseMetric, config_name='composite'): +class CompositeMetric(BaseMetric, config_name="composite"): def __init__(self, metrics): self._metrics = metrics @classmethod def create_from_config(cls, config): return cls( - metrics=[ - BaseMetric.create_from_config(cfg) for cfg in config['metrics'] - ], + metrics=[BaseMetric.create_from_config(cfg) for cfg in config["metrics"]], ) def __call__(self, inputs): @@ -41,7 +39,7 @@ def __call__(self, inputs): return inputs -class NDCGMetric(BaseMetric, config_name='ndcg'): +class NDCGMetric(BaseMetric, config_name="ndcg"): def __init__(self, k): self._k = k @@ -50,9 +48,8 @@ def __call__(self, inputs, pred_prefix, labels_prefix): :, : self._k, ].float() # (batch_size, top_k_indices) - labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size) + labels = inputs["{}.ids".format(labels_prefix)].float() # (batch_size) - assert labels.shape[0] == predictions.shape[0] hits = torch.eq( @@ -61,13 +58,15 @@ def __call__(self, inputs, pred_prefix, labels_prefix): ).float() # (batch_size, top_k_indices) discount_factor = 1 / torch.log2( torch.arange(1, self._k + 1, 1).float() + 1.0, - ).to(hits.device) # (k) - dcg = torch.einsum('bk,k->b', hits, discount_factor) # (batch_size) + ).to( + hits.device + ) # (k) + dcg = torch.einsum("bk,k->b", hits, discount_factor) # (batch_size) return dcg.cpu().tolist() -class RecallMetric(BaseMetric, config_name='recall'): +class RecallMetric(BaseMetric, config_name="recall"): def __init__(self, k): self._k = k @@ -76,7 +75,7 @@ def __call__(self, inputs, pred_prefix, labels_prefix): :, : self._k, ].float() # (batch_size, top_k_indices) - labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size) + labels = inputs["{}.ids".format(labels_prefix)].float() # (batch_size) assert labels.shape[0] == predictions.shape[0] @@ -89,35 +88,34 @@ def __call__(self, inputs, pred_prefix, labels_prefix): return recall.cpu().tolist() -class CoverageMetric(StatefullMetric, config_name='coverage'): +class CoverageMetric(StatefullMetric, config_name="coverage"): def __init__(self, k, num_items): self._k = k self._num_items = num_items @classmethod def create_from_config(cls, config, **kwargs): - return cls(k=config['k'], num_items=kwargs['num_items']) + return cls(k=config["k"], num_items=kwargs["num_items"]) def __call__(self, inputs, pred_prefix, labels_prefix): predictions = inputs[pred_prefix][ :, : self._k, ].float() # (batch_size, top_k_indices) - return ( - predictions.view(-1).long().cpu().detach().tolist() - ) # (batch_size * k) + return predictions.view(-1).long().cpu().detach().tolist() # (batch_size * k) def reduce(self, values): return len(set(values)) / self._num_items -class MCLSRNDCGMetric(BaseMetric, config_name='mclsr-ndcg'): + +class MCLSRNDCGMetric(BaseMetric, config_name="mclsr-ndcg"): def __init__(self, k): self._k = k def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) + predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k) + labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,) + labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,) assert predictions.shape[0] == labels_lengths.shape[0] @@ -129,30 +127,30 @@ def __call__(self, inputs, pred_prefix, labels_prefix): user_labels = labels_flat[offset : offset + num_user_labels] offset += num_user_labels - hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False - + hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False + positions = torch.arange(2, self._k + 2, device=predictions.device) weights = 1 / torch.log2(positions.float()) dcg = (hits_mask.float() * weights).sum() - + num_ideal_hits = min(self._k, num_user_labels) idcg_weights = weights[:num_ideal_hits] idcg = idcg_weights.sum() - + ndcg = dcg / idcg if idcg > 0 else torch.tensor(0.0) dcg_scores.append(ndcg.cpu().item()) - + return dcg_scores -class MCLSRRecallMetric(BaseMetric, config_name='mclsr-recall'): +class MCLSRRecallMetric(BaseMetric, config_name="mclsr-recall"): def __init__(self, k): self._k = k def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) + predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k) + labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,) + labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,) assert predictions.shape[0] == labels_lengths.shape[0] @@ -163,22 +161,25 @@ def __call__(self, inputs, pred_prefix, labels_prefix): num_user_labels = labels_lengths[i] user_labels = labels_flat[offset : offset + num_user_labels] offset += num_user_labels - + hits = torch.isin(user_predictions, user_labels).sum().float() - - recall = hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0) + + recall = ( + hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0) + ) recall_scores.append(recall.cpu().item()) - + return recall_scores -class MCLSRHitRateMetric(BaseMetric, config_name='mclsr-hit'): + +class MCLSRHitRateMetric(BaseMetric, config_name="mclsr-hit"): def __init__(self, k): self._k = k def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) + predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k) + labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,) + labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,) assert predictions.shape[0] == labels_lengths.shape[0] @@ -194,9 +195,9 @@ def __call__(self, inputs, pred_prefix, labels_prefix): user_labels = labels_flat[offset : offset + num_user_labels] offset += num_user_labels - + is_hit = torch.isin(user_predictions, user_labels).any() - + hit_scores.append(float(is_hit)) - - return hit_scores \ No newline at end of file + + return hit_scores diff --git a/src/irec/models/__init__.py b/src/irec/models/__init__.py index 351a048..761e479 100644 --- a/src/irec/models/__init__.py +++ b/src/irec/models/__init__.py @@ -1,13 +1,9 @@ from .base import BaseModel, SequentialTorchModel from .mclsr import MCLSRModel from .sasrec import SasRecModel, SasRecInBatchModel -from .sasrec_ce import SasRecCeModel __all__ = [ 'BaseModel', 'MCLSRModel', 'SasRecModel', - 'SasRecInBatchModel', - 'SasRecCeModel', - 'SasRecRealModel', ] diff --git a/src/irec/models/base.py b/src/irec/models/base.py index 2f059f7..22555f4 100644 --- a/src/irec/models/base.py +++ b/src/irec/models/base.py @@ -21,8 +21,8 @@ class TorchModel(nn.Module, BaseModel): @torch.no_grad() def _init_weights(self, initializer_range): for key, value in self.named_parameters(): - if 'weight' in key: - if 'norm' in key: + if "weight" in key: + if "norm" in key: nn.init.ones_(value.data) else: nn.init.trunc_normal_( @@ -31,10 +31,10 @@ def _init_weights(self, initializer_range): a=-2 * initializer_range, b=2 * initializer_range, ) - elif 'bias' in key: + elif "bias" in key: nn.init.zeros_(value.data) else: - raise ValueError(f'Unknown transformer weight: {key}') + raise ValueError(f"Unknown transformer weight: {key}") @staticmethod def _get_last_embedding(embeddings, mask): @@ -54,12 +54,12 @@ def _get_last_embedding(embeddings, mask): ) # (batch_size, 1, emb_dim) last_embeddings = last_embeddings[last_masks] # (batch_size, emb_dim) if not torch.allclose(embeddings[mask][-1], last_embeddings[-1]): - logger.debug(f'Embeddings: {embeddings}') + logger.debug(f"Embeddings: {embeddings}") logger.debug( - f'Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}', + f"Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}", ) - logger.debug(f'Last embedding from mask: {embeddings[mask][-1]}') - logger.debug(f'Last embedding from gather: {last_embeddings[-1]}') + logger.debug(f"Last embedding from mask: {embeddings[mask][-1]}") + logger.debug(f"Last embedding from gather: {last_embeddings[-1]}") assert False return last_embeddings @@ -74,7 +74,7 @@ def __init__( num_layers, dim_feedforward, dropout=0.0, - activation='relu', + activation="relu", layer_norm_eps=1e-5, is_causal=True, ): @@ -85,8 +85,7 @@ def __init__( self._embedding_dim = embedding_dim self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding + num_embeddings=num_items + 2, # add zero embedding + mask embedding embedding_dim=embedding_dim, ) self._position_embeddings = nn.Embedding( @@ -135,9 +134,7 @@ def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): .tile([batch_size, 1]) .long() ) # (batch_size, seq_len) - positions_mask = ( - positions < lengths[:, None] - ) # (batch_size, max_seq_len) + positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) positions = positions[positions_mask] # (all_batch_events) position_embeddings = self._position_embeddings( @@ -215,7 +212,9 @@ def _add_cls_token(items, lengths, cls_token_id=0): dim=0, index=torch.cat( [torch.LongTensor([0]).to(DEVICE), lengths + 1], - ).cumsum(dim=0)[:-1], + ).cumsum( + dim=0 + )[:-1], ) # (num_new_items) new_items[old_items_mask] = items new_length = lengths + 1 diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py index fc8c6ef..2de98f2 100644 --- a/src/irec/models/mclsr.py +++ b/src/irec/models/mclsr.py @@ -5,8 +5,10 @@ from irec.utils import create_masked_tensor +torch.backends.cudnn.benchmark = True -class MCLSRModel(TorchModel, config_name='mclsr'): + +class MCLSRModel(TorchModel, config_name="mclsr"): def __init__( self, sequence_prefix, @@ -50,8 +52,7 @@ def __init__( self._item_graph = item_graph self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding + num_embeddings=num_items + 2, # add zero embedding + mask embedding embedding_dim=embedding_dim, ) self._position_embeddings = nn.Embedding( @@ -61,8 +62,7 @@ def __init__( ) self._user_embeddings = nn.Embedding( - num_embeddings=num_users - + 2, # add zero embedding + mask embedding + num_embeddings=num_users + 2, # add zero embedding + mask embedding embedding_dim=embedding_dim, ) @@ -155,23 +155,23 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - sequence_prefix=config['sequence_prefix'], - user_prefix=config['user_prefix'], - labels_prefix=config['labels_prefix'], - negatives_prefix=config.get('negatives_prefix', 'negatives'), - candidate_prefix=config['candidate_prefix'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_graph_layers=config['num_graph_layers'], - common_graph=kwargs['graph'], - user_graph=kwargs['user_graph'], - item_graph=kwargs['item_graph'], - dropout=config.get('dropout', 0.0), - layer_norm_eps=config.get('layer_norm_eps', 1e-5), - graph_dropout=config.get('graph_dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), + sequence_prefix=config["sequence_prefix"], + user_prefix=config["user_prefix"], + labels_prefix=config["labels_prefix"], + negatives_prefix=config.get("negatives_prefix", "negatives"), + candidate_prefix=config["candidate_prefix"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_graph_layers=config["num_graph_layers"], + common_graph=kwargs["graph"], + user_graph=kwargs["user_graph"], + item_graph=kwargs["item_graph"], + dropout=config.get("dropout", 0.0), + layer_norm_eps=config.get("layer_norm_eps", 1e-5), + graph_dropout=config.get("graph_dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), ) def _apply_graph_encoder(self, embeddings, graph, use_mean=False): @@ -200,12 +200,12 @@ def _apply_graph_encoder(self, embeddings, graph, use_mean=False): def forward(self, inputs): all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) + "{}.ids".format(self._sequence_prefix) ] # (all_batch_events) all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) + "{}.length".format(self._sequence_prefix) ] # (batch_size) - user_ids = inputs['{}.ids'.format(self._user_prefix)] # (batch_size) + user_ids = inputs["{}.ids".format(self._user_prefix)] # (batch_size) embeddings = self._item_embeddings( all_sample_events, @@ -243,11 +243,11 @@ def forward(self, inputs): lengths=all_sample_lengths, ) # (batch_size, seq_len, embedding_dim) assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - + positioned_embeddings = ( embeddings + position_embeddings ) # (batch_size, seq_len, embedding_dim) - + positioned_embeddings = self._layernorm( positioned_embeddings, ) # (batch_size, seq_len, embedding_dim) @@ -258,7 +258,7 @@ def forward(self, inputs): # formula 2 sequential_attention_matrix = self._current_interest_learning_encoder( - positioned_embeddings, # E_u,p + positioned_embeddings, # E_u,p ).squeeze() # (batch_size, seq_len) sequential_attention_matrix[~mask] = -torch.inf @@ -269,134 +269,205 @@ def forward(self, inputs): # formula 3 sequential_representation = torch.einsum( - 'bs,bsd->bd', - sequential_attention_matrix, # A^s + "bs,bsd->bd", + sequential_attention_matrix, # A^s embeddings, ) # (batch_size, embedding_dim) if self.training: # general interest # formula 4 - all_init_embeddings = torch.cat([self._user_embeddings.weight, - self._item_embeddings.weight], - dim=0) - all_graph_embeddings = self._apply_graph_encoder(embeddings=all_init_embeddings, - graph=self._graph) + all_init_embeddings = torch.cat( + [self._user_embeddings.weight, self._item_embeddings.weight], dim=0 + ) + all_graph_embeddings = self._apply_graph_encoder( + embeddings=all_init_embeddings, graph=self._graph + ) common_graph_user_embs_all, common_graph_item_embs_all = torch.split( all_graph_embeddings, [self._num_users + 2, self._num_items + 2] ) common_graph_user_embs_batch = common_graph_user_embs_all[user_ids] common_graph_item_embs_batch, _ = create_masked_tensor( - data=common_graph_item_embs_all[all_sample_events], - lengths=all_sample_lengths + data=common_graph_item_embs_all[all_sample_events], + lengths=all_sample_lengths, ) # formula 5: A_c = softmax(tanh(W_3 * h_u,uv) * (E_u,uv)^T) - graph_attention_matrix = torch.einsum('bd,bsd->bs', - self._general_interest_learning_encoder - (common_graph_user_embs_batch), - common_graph_item_embs_batch) + graph_attention_matrix = torch.einsum( + "bd,bsd->bs", + self._general_interest_learning_encoder(common_graph_user_embs_batch), + common_graph_item_embs_batch, + ) graph_attention_matrix[~mask] = -torch.inf graph_attention_matrix = torch.softmax(graph_attention_matrix, dim=1) # formula 6: I_c = A_c * E_u,uv - original_graph_representation = torch.einsum('bs,bsd->bd', - graph_attention_matrix, - common_graph_item_embs_batch) + original_graph_representation = torch.einsum( + "bs,bsd->bd", graph_attention_matrix, common_graph_item_embs_batch + ) original_sequential_representation = sequential_representation # formula 13: I_comb = alpha * I_s + (1 - alpha) * I_c # L_P (Downstream Loss) - combined_representation = (self._alpha * original_sequential_representation + - (1 - self._alpha) * original_graph_representation) - labels = inputs['{}.ids'.format(self._labels_prefix)] + combined_representation = ( + self._alpha * original_sequential_representation + + (1 - self._alpha) * original_graph_representation + ) + labels = inputs["{}.ids".format(self._labels_prefix)] labels_embeddings = self._item_embeddings(labels) - + # formula 7 # L_IL (Interest-level CL) - sequential_representation_proj = self._sequential_projector(original_sequential_representation) - graph_representation_proj = self._graph_projector(original_graph_representation) + sequential_representation_proj = self._sequential_projector( + original_sequential_representation + ) + graph_representation_proj = self._graph_projector( + original_graph_representation + ) # formula 9: H_u,uu = GraphEncoder(H_u, G_uu) # L_UC (User-level CL) - user_graph_user_embs_all = self._apply_graph_encoder(embeddings=self._user_embeddings.weight, - graph=self._user_graph) + user_graph_user_embs_all = self._apply_graph_encoder( + embeddings=self._user_embeddings.weight, graph=self._user_graph + ) user_graph_user_embs_batch = user_graph_user_embs_all[user_ids] # formula 10 # T_f,uu = MLP(H_u,uu) и T_f,uv = MLP(H_u,uv) - user_graph_user_embeddings_proj = self._user_projection(user_graph_user_embs_batch) - common_graph_user_embeddings_proj = self._user_projection(common_graph_user_embs_batch) + user_graph_user_embeddings_proj = self._user_projection( + user_graph_user_embs_batch + ) + common_graph_user_embeddings_proj = self._user_projection( + common_graph_user_embs_batch + ) # item level CL common_graph_items_flat = common_graph_item_embs_batch[mask] - - item_graph_items_all = self._apply_graph_encoder(embeddings=self._item_embeddings.weight, - graph=self._item_graph) + + item_graph_items_all = self._apply_graph_encoder( + embeddings=self._item_embeddings.weight, graph=self._item_graph + ) item_graph_items_flat = item_graph_items_all[all_sample_events] - unique_item_ids, inverse_indices = torch.unique(all_sample_events, - return_inverse=True) + unique_item_ids, inverse_indices = torch.unique( + all_sample_events, return_inverse=True + ) + # OLD BAD + # try: + # from torch_scatter import scatter_mean + # except ImportError: + # # print("Warning: torch_scatter not found. Using a slower fallback function.") + # def scatter_mean(src, index, dim=0, dim_size=None): + # out_size = dim_size if dim_size is not None else index.max() + 1 + # out = torch.zeros((out_size, src.size(1)), dtype=src.dtype, device=src.device) + # counts = torch.bincount(index, minlength=out_size).unsqueeze(-1).clamp(min=1) + # return out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts + + # --- OPTIMIZED AGGREGATION: scatter_mean --- + # We use scatter_mean to aggregate features of unique items within a batch + # for Item-level Feature Contrastive Learning. + # + # Performance Note: + # We prioritize the 'torch-scatter' library because it provides highly optimized + # C++/CUDA kernels that perform in-place aggregation. + # + # The 'except ImportError' fallback is provided for environment compatibility, + # but it is NOT recommended for large-scale datasets like Amazon Books. + # The fallback implementation uses 'expand_as', which creates massive temporary + # tensors in GPU memory, potentially leading to Out-Of-Memory (OOM) errors + # when processing millions of interaction events. try: from torch_scatter import scatter_mean except ImportError: - # print("Warning: torch_scatter not found. Using a slower fallback function.") + def scatter_mean(src, index, dim=0, dim_size=None): out_size = dim_size if dim_size is not None else index.max() + 1 - out = torch.zeros((out_size, src.size(1)), dtype=src.dtype, device=src.device) - counts = torch.bincount(index, minlength=out_size).unsqueeze(-1).clamp(min=1) - return out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts - + out = torch.zeros( + (out_size, src.size(1)), dtype=src.dtype, device=src.device + ) + counts = ( + torch.bincount(index, minlength=out_size) + .unsqueeze(-1) + .clamp(min=1) + ) + # WARNING: .expand_as() below is a memory bottleneck for large tensors + return ( + out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) + / counts + ) + num_unique_items = unique_item_ids.shape[0] - unique_common_graph_items = scatter_mean(common_graph_items_flat, - inverse_indices, dim=0, - dim_size=num_unique_items) + unique_common_graph_items = scatter_mean( + common_graph_items_flat, + inverse_indices, + dim=0, + dim_size=num_unique_items, + ) - unique_item_graph_items = scatter_mean(item_graph_items_flat, - inverse_indices, dim=0, - dim_size=num_unique_items) + unique_item_graph_items = scatter_mean( + item_graph_items_flat, inverse_indices, dim=0, dim_size=num_unique_items + ) # projection for Item-level Feature CL - unique_common_graph_items_proj = self._item_projection(unique_common_graph_items) - unique_item_graph_items_proj = self._item_projection(unique_item_graph_items) + unique_common_graph_items_proj = self._item_projection( + unique_common_graph_items + ) + unique_item_graph_items_proj = self._item_projection( + unique_item_graph_items + ) - negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) - negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) + # negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) + # negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) + + raw_negative_ids = inputs["{}.ids".format(self._negatives_prefix)] + num_negatives = raw_negative_ids.shape[0] // batch_size + negative_ids = raw_negative_ids.view( + batch_size, num_negatives + ) # (Batch, NumNegs) + negative_embeddings = self._item_embeddings( + negative_ids + ) # (Batch, NumNegs, Dim) # import code; code.interact(local=locals()) return { # L_P (formula 14) - 'combined_representation': combined_representation, - 'label_representation': labels_embeddings, - - 'negative_representation': negative_embeddings, - + "combined_representation": combined_representation, + "label_representation": labels_embeddings, + "negative_representation": negative_embeddings, + # --- ID PASS-THROUGH FOR LOGQ & MASKING --- + # We pass raw item and user indices to enable advanced loss operations: + # 1. False Negative Masking: Allows SamplesSoftmaxLoss to identify and + # neutralize cases where a target item accidentally appears in the + # negative sampling pool. + # 2. Per-item LogQ Correction: Enables mapping item IDs to global frequency + # stats (item_counts.pkl) to remove popularity bias (Sampling Bias). + "positive_ids": labels, # Item target indices + "negative_ids": negative_ids, # Sampled negative item indices + # Useful for potential User-level LogQ correction as requested + # by the supervisor to handle highly active users. + "user_ids": user_ids, # for L_IL (formula 8) - - 'sequential_representation': sequential_representation_proj, - 'graph_representation': graph_representation_proj, - + "sequential_representation": sequential_representation_proj, + "graph_representation": graph_representation_proj, # for L_UC (formula 11) - 'user_graph_user_embeddings': user_graph_user_embeddings_proj, - 'common_graph_user_embeddings': common_graph_user_embeddings_proj, - + "user_graph_user_embeddings": user_graph_user_embeddings_proj, + "common_graph_user_embeddings": common_graph_user_embeddings_proj, # for L_IC - 'item_graph_item_embeddings': unique_item_graph_items_proj, - 'common_graph_item_embeddings': unique_common_graph_items_proj, + "item_graph_item_embeddings": unique_item_graph_items_proj, + "common_graph_item_embeddings": unique_common_graph_items_proj, } else: # eval mode # formula 16: R(u,N) = Top-N((I_s)^T * h_o) - if '{}.ids'.format(self._candidate_prefix) in inputs: + if "{}.ids".format(self._candidate_prefix) in inputs: candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) + "{}.ids".format(self._candidate_prefix) ] # (all_batch_candidates) candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - + "{}.length".format(self._candidate_prefix) ] # (batch_size) candidate_embeddings = self._item_embeddings( @@ -409,23 +480,22 @@ def scatter_mean(src, index, dim=0, dim_size=None): ) # (batch_size, num_candidates, embedding_dim) candidate_scores = torch.einsum( - 'bd,bnd->bn', - sequential_representation, # I_s - candidate_embeddings, # h_o (and h_k) + "bd,bnd->bn", + sequential_representation, # I_s + candidate_embeddings, # h_o (and h_k) ) # (batch_size, num_candidates) else: candidate_embeddings = ( self._item_embeddings.weight ) # (num_items, embedding_dim) candidate_scores = torch.einsum( - 'bd,nd->bn', - sequential_representation, # I_s - candidate_embeddings, # all h_v + "bd,nd->bn", + sequential_representation, # I_s + candidate_embeddings, # all h_v ) # (batch_size, num_items) candidate_scores[:, 0] = -torch.inf candidate_scores[:, self._num_items + 1 :] = -torch.inf - values, indices = torch.topk( candidate_scores, k=50, @@ -433,4 +503,4 @@ def scatter_mean(src, index, dim=0, dim_size=None): largest=True, ) # (batch_size, 100), (batch_size, 100) - return indices \ No newline at end of file + return indices diff --git a/src/irec/models/sasrec.py b/src/irec/models/sasrec.py index e97019a..52384d2 100644 --- a/src/irec/models/sasrec.py +++ b/src/irec/models/sasrec.py @@ -4,22 +4,22 @@ import torch -class SasRecModel(SequentialTorchModel, config_name='sasrec'): +class SasRecModel(SequentialTorchModel, config_name="sasrec"): def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 + self, + sequence_prefix, + positive_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, ): super().__init__( num_items=num_items, @@ -31,7 +31,7 @@ def __init__( dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, - is_causal=True + is_causal=True, ) self._sequence_prefix = sequence_prefix self._positive_prefix = positive_prefix @@ -41,103 +41,100 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), ) def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) embeddings, mask = self._apply_sequential_encoder( all_sample_events, all_sample_lengths ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) if self.training: # training mode - all_positive_sample_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) + all_positive_sample_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) - all_sample_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) + all_sample_embeddings = embeddings[ + mask + ] # (all_batch_events, embedding_dim) all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim) # a -- all_batch_events, n -- num_items, d -- embedding_dim all_scores = torch.einsum( - 'ad,nd->an', - all_sample_embeddings, - all_embeddings + "ad,nd->an", all_sample_embeddings, all_embeddings ) # (all_batch_events, num_items) positive_scores = torch.gather( - input=all_scores, - dim=1, - index=all_positive_sample_events[..., None] - )[:, 0] # (all_batch_items) + input=all_scores, dim=1, index=all_positive_sample_events[..., None] + )[ + :, 0 + ] # (all_batch_items) negative_scores = torch.gather( input=all_scores, dim=1, - index=torch.randint(low=0, high=all_scores.shape[1], size=all_positive_sample_events.shape, device=all_positive_sample_events.device)[..., None] - )[:, 0] # (all_batch_items) - - # sample_ids, _ = create_masked_tensor( - # data=all_sample_events, - # lengths=all_sample_lengths - # ) # (batch_size, seq_len) - - # sample_ids = torch.repeat_interleave(sample_ids, all_sample_lengths, dim=0) # (all_batch_events, seq_len) - - # negative_scores = torch.scatter( - # input=all_scores, - # dim=1, - # index=sample_ids, - # src=torch.ones_like(sample_ids) * (-torch.inf) - # ) # (all_batch_events, num_items) + index=torch.randint( + low=0, + high=all_scores.shape[1], + size=all_positive_sample_events.shape, + device=all_positive_sample_events.device, + )[..., None], + )[ + :, 0 + ] # (all_batch_items) return { - 'positive_scores': positive_scores, - 'negative_scores': negative_scores + "positive_scores": positive_scores, + "negative_scores": negative_scores, } else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) # b - batch_size, n - num_candidates, d - embedding_dim candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight + "bd,nd->bn", last_embeddings, self._item_embeddings.weight ) # (batch_size, num_items + 2) _, indices = torch.topk( - candidate_scores, - k=50, dim=-1, largest=True + candidate_scores, k=50, dim=-1, largest=True ) # (batch_size, 20) return indices -class SasRecInBatchModel(SasRecModel, config_name='sasrec_in_batch'): - +class SasRecInBatchModel(SasRecModel, config_name="sasrec_in_batch"): def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 + self, + sequence_prefix, + positive_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, ): super().__init__( sequence_prefix=sequence_prefix, @@ -151,27 +148,31 @@ def __init__( dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, - initializer_range=initializer_range + initializer_range=initializer_range, ) @classmethod def create_from_config(cls, config, **kwargs): return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), ) def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) embeddings, mask = self._apply_sequential_encoder( all_sample_events, all_sample_lengths @@ -179,10 +180,14 @@ def forward(self, inputs): if self.training: # training mode # queries - in_batch_queries_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) + in_batch_queries_embeddings = embeddings[ + mask + ] # (all_batch_events, embedding_dim) # positives - in_batch_positive_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) + in_batch_positive_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) in_batch_positive_embeddings = self._item_embeddings( in_batch_positive_events ) # (all_batch_events, embedding_dim) @@ -197,25 +202,33 @@ def forward(self, inputs): ) # (batch_size, embedding_dim) return { - 'query_embeddings': in_batch_queries_embeddings, - 'positive_embeddings': in_batch_positive_embeddings, - 'negative_embeddings': in_batch_negative_embeddings + "query_embeddings": in_batch_queries_embeddings, + "positive_embeddings": in_batch_positive_embeddings, + "negative_embeddings": in_batch_negative_embeddings, + # --- ID PASS-THROUGH FOR DOWNSTREAM LOSS OPERATIONS --- + # We pass the raw item IDs to the loss function to enable: + # 1. False Negative Masking: Identifying and neutralizing cases where a positive + # item for one user accidentally appears as a negative sample for another + # user in the same batch (critical for In-Batch Negatives stability). + # 2. Per-item LogQ Correction: Mapping item IDs to their global frequencies + # to subtract log(Q) based on the specific item popularity. + "positive_ids": in_batch_positive_events, + "negative_ids": in_batch_negative_ids, } else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) # b - batch_size, n - num_candidates, d - embedding_dim candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight + "bd,nd->bn", last_embeddings, self._item_embeddings.weight ) # (batch_size, num_items + 2) candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1:] = -torch.inf + candidate_scores[:, self._num_items + 1 :] = -torch.inf _, indices = torch.topk( - candidate_scores, - k=50, dim=-1, largest=True + candidate_scores, k=50, dim=-1, largest=True ) # (batch_size, 20) - return indices \ No newline at end of file + return indices diff --git a/src/irec/models/sasrec_ce.py b/src/irec/models/sasrec_ce.py deleted file mode 100644 index c07b9ff..0000000 --- a/src/irec/models/sasrec_ce.py +++ /dev/null @@ -1,108 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - - -class SasRecCeModel(SequentialTorchModel, config_name='sasrec_ce'): - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - embeddings = self._output_projection( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.nn.functional.gelu( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.einsum( - 'bsd,nd->bsn', - embeddings, - self._item_embeddings.weight, - ) # (batch_size, seq_len, num_items + 2) - - if self.training: # training mode - return {'logits': embeddings[mask]} - else: # eval mode - candidate_scores = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/optimizer/base.py b/src/irec/optimizer/base.py index 0fc1f70..51d9464 100644 --- a/src/irec/optimizer/base.py +++ b/src/irec/optimizer/base.py @@ -5,14 +5,14 @@ import torch OPTIMIZERS = { - 'sgd': torch.optim.SGD, - 'adam': torch.optim.Adam, - 'adamw': torch.optim.AdamW, + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "adamw": torch.optim.AdamW, } SCHEDULERS = { - 'step': torch.optim.lr_scheduler.StepLR, - 'cyclic': torch.optim.lr_scheduler.CyclicLR, + "step": torch.optim.lr_scheduler.StepLR, + "cyclic": torch.optim.lr_scheduler.CyclicLR, } @@ -20,7 +20,7 @@ class BaseOptimizer(metaclass=MetaParent): pass -class BasicOptimizer(BaseOptimizer, config_name='basic'): +class BasicOptimizer(BaseOptimizer, config_name="basic"): def __init__( self, model, @@ -35,15 +35,15 @@ def __init__( @classmethod def create_from_config(cls, config, **kwargs): - optimizer_cfg = copy.deepcopy(config['optimizer']) - optimizer = OPTIMIZERS[optimizer_cfg.pop('type')]( - kwargs['model'].parameters(), + optimizer_cfg = copy.deepcopy(config["optimizer"]) + optimizer = OPTIMIZERS[optimizer_cfg.pop("type")]( + kwargs["model"].parameters(), **optimizer_cfg, ) - if 'scheduler' in config: - scheduler_cfg = copy.deepcopy(config['scheduler']) - scheduler = SCHEDULERS[scheduler_cfg.pop('type')]( + if "scheduler" in config: + scheduler_cfg = copy.deepcopy(config["scheduler"]) + scheduler = SCHEDULERS[scheduler_cfg.pop("type")]( optimizer, **scheduler_cfg, ) @@ -51,10 +51,10 @@ def create_from_config(cls, config, **kwargs): scheduler = None return cls( - model=kwargs['model'], + model=kwargs["model"], optimizer=optimizer, scheduler=scheduler, - clip_grad_threshold=config.get('clip_grad_threshold', None), + clip_grad_threshold=config.get("clip_grad_threshold", None), ) def step(self, loss): @@ -72,7 +72,7 @@ def step(self, loss): self._scheduler.step() def state_dict(self): - state_dict = {'optimizer': self._optimizer.state_dict()} + state_dict = {"optimizer": self._optimizer.state_dict()} if self._scheduler is not None: - state_dict.update({'scheduler': self._scheduler.state_dict()}) + state_dict.update({"scheduler": self._scheduler.state_dict()}) return state_dict diff --git a/src/irec/train.py b/src/irec/train.py index 7844439..8251822 100644 --- a/src/irec/train.py +++ b/src/irec/train.py @@ -43,20 +43,20 @@ def train( best_epoch = 0 best_checkpoint = None - logger.debug('Start training...') + logger.debug("Start training...") while (epoch_cnt is None or epoch_num < epoch_cnt) and ( step_cnt is None or step_num < step_cnt ): if best_epoch + epochs_threshold < epoch_num: logger.debug( - 'There is no progress during {} epochs. Finish training'.format( + "There is no progress during {} epochs. Finish training".format( epochs_threshold, ), ) break - logger.debug(f'Start epoch {epoch_num}') + logger.debug(f"Start epoch {epoch_num}") for step, batch in enumerate(dataloader): batch_ = copy.deepcopy(batch) @@ -87,7 +87,7 @@ def train( best_epoch = epoch_num epoch_num += 1 - logger.debug('Training procedure has been finished!') + logger.debug("Training procedure has been finished!") return best_checkpoint @@ -95,71 +95,72 @@ def main(): fix_random_seed(seed_val) config = parse_args() - if config.get('use_wandb', False): + if config.get("use_wandb", False): wandb.init( - project='irec', - name=config['experiment_name'], + project="irec", + name=config["experiment_name"], sync_tensorboard=True, ) - tensorboard_writer = irec.utils.tensorboards.TensorboardWriter(config['experiment_name']) + tensorboard_writer = irec.utils.tensorboards.TensorboardWriter( + config["experiment_name"] + ) irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER = tensorboard_writer log_dir = tensorboard_writer.log_dir - config_save_path = os.path.join(log_dir, 'config.json') - with open(config_save_path, 'w') as f: + config_save_path = os.path.join(log_dir, "config.json") + with open(config_save_path, "w") as f: json.dump(config, f, indent=2) - - logger.debug('Training config: \n{}'.format(json.dumps(config, indent=2))) - logger.debug('Current DEVICE: {}'.format(DEVICE)) - logger.info(f"Experiment config saved to: {config_save_path}") + logger.debug("Training config: \n{}".format(json.dumps(config, indent=2))) + logger.debug("Current DEVICE: {}".format(DEVICE)) + logger.info(f"Experiment config saved to: {config_save_path}") - dataset = BaseDataset.create_from_config(config['dataset']) + dataset = BaseDataset.create_from_config(config["dataset"]) train_sampler, validation_sampler, test_sampler = dataset.get_samplers() train_dataloader = BaseDataloader.create_from_config( - config['dataloader']['train'], + config["dataloader"]["train"], dataset=train_sampler, **dataset.meta, ) validation_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], + config["dataloader"]["validation"], dataset=validation_sampler, **dataset.meta, ) eval_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], + config["dataloader"]["validation"], dataset=test_sampler, **dataset.meta, ) - model = BaseModel.create_from_config(config['model'], **dataset.meta).to( + model = BaseModel.create_from_config(config["model"], **dataset.meta).to( DEVICE, ) - if 'checkpoint' in config: + if "checkpoint" in config: ensure_checkpoints_dir() checkpoint_path = os.path.join( - './checkpoints', + "./checkpoints", f'{config["checkpoint"]}.pth', ) - logger.debug('Loading checkpoint from {}'.format(checkpoint_path)) + logger.debug("Loading checkpoint from {}".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) logger.debug(checkpoint.keys()) model.load_state_dict(checkpoint) - loss_function = BaseLoss.create_from_config(config['loss']) + loss_function = BaseLoss.create_from_config(config["loss"]) optimizer = BaseOptimizer.create_from_config( - config['optimizer'], + config["optimizer"], model=model, ) callback = BaseCallback.create_from_config( - config['callback'], + config["callback"], model=model, train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, @@ -170,7 +171,7 @@ def main(): # TODO add verbose option for all callbacks, multiple optimizer options (???) # TODO create pre/post callbacks - logger.debug('Everything is ready for training process!') + logger.debug("Everything is ready for training process!") # Train process _ = train( @@ -179,22 +180,22 @@ def main(): optimizer=optimizer, loss_function=loss_function, callback=callback, - epoch_cnt=config.get('train_epochs_num'), - step_cnt=config.get('train_steps_num'), - best_metric=config.get('best_metric'), + epoch_cnt=config.get("train_epochs_num"), + step_cnt=config.get("train_steps_num"), + best_metric=config.get("best_metric"), ) - logger.debug('Saving model...') + logger.debug("Saving model...") ensure_checkpoints_dir() - checkpoint_path = './checkpoints/{}_final_state.pth'.format( - config['experiment_name'], + checkpoint_path = "./checkpoints/{}_final_state.pth".format( + config["experiment_name"], ) torch.save(model.state_dict(), checkpoint_path) - logger.debug('Saved model as {}'.format(checkpoint_path)) + logger.debug("Saved model as {}".format(checkpoint_path)) - if config.get('use_wandb', False): + if config.get("use_wandb", False): wandb.finish() -if __name__ == '__main__': +if __name__ == "__main__": main()