diff --git a/README.md b/README.md
index 680c46e..f0053bc 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+
+# Steps on a local machine
@@ -40,26 +42,6 @@
uv sync --frozen
```
-### Using pip
-
-1. Create and activate a virtual environment:
- ```bash
- python3 -m venv .venv
- source ./.venv/bin/activate
- ```
-
-2. Install dependencies:
-
- **For development:**
- ```bash
- pip install -e ".[dev]"
- ```
-
- **For production:**
- ```bash
- pip install -e .
- ```
-
## Preparing datasets
All pre-processed datasets used in our experiments are available for download from our cloud storage. This is the fastest way to get started.
@@ -91,4 +73,4 @@ The script has 1 input argument: `params` which is the path to the json file wit
-`callbacks` Different additional traning
--`use_wandb` Enable Weights & Biases logging for experiment tracking
\ No newline at end of file
+-`use_wandb` Enable Weights & Biases logging for experiment tracking
diff --git a/configs/inference/sasrec_inference_config.json b/configs/inference/sasrec_inference_config.json
deleted file mode 100644
index c91e21b..0000000
--- a/configs/inference/sasrec_inference_config.json
+++ /dev/null
@@ -1,88 +0,0 @@
-{
- "pred_prefix": "logits",
- "label_prefix": "labels",
- "experiment_name": "sasrec_beauty_grid__0-5__",
- "dataset": {
- "type": "sequence",
- "path_to_data_dir": "./data",
- "name": "Beauty",
- "max_sequence_length": 50,
- "samplers": {
- "type": "next_item_prediction",
- "negative_sampler_type": "random"
- }
- },
- "dataloader": {
- "train": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": true,
- "shuffle": true
- },
- "validation": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": false,
- "shuffle": false
- }
- },
- "model": {
- "type": "sasrec",
- "sequence_prefix": "item",
- "positive_prefix": "positive",
- "negative_prefix": "negative",
- "candidate_prefix": "candidates",
- "embedding_dim": 64,
- "num_heads": 2,
- "num_layers": 2,
- "dim_feedforward": 256,
- "dropout": 0.5,
- "activation": "gelu",
- "layer_norm_eps": 1e-9,
- "initializer_range": 0.02
- },
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- },
- "coverage@5": {
- "type": "coverage",
- "k": 5
- },
- "coverage@10": {
- "type": "coverage",
- "k": 10
- },
- "coverage@20": {
- "type": "coverage",
- "k": 20
- }
- }
-}
diff --git a/configs/train/mclsr_baseline_clothing_sanity.json b/configs/train/mclsr_baseline_clothing_sanity.json
new file mode 100644
index 0000000..0039582
--- /dev/null
+++ b/configs/train/mclsr_baseline_clothing_sanity.json
@@ -0,0 +1,233 @@
+{
+ "experiment_name": "mclsr_Clothing",
+ "use_wandb": true,
+ "best_metric": "validation/ndcg@20",
+ "dataset": {
+ "type": "graph",
+ "use_user_graph": true,
+ "use_item_graph": true,
+ "neighborhood_size": 50,
+ "graph_dir_path": "./data/Clothing",
+ "dataset": {
+ "type": "mclsr",
+ "path_to_data_dir": "./data",
+ "name": "Clothing",
+ "max_sequence_length": 20,
+ "samplers": {
+ "num_negatives_val": 1280,
+ "num_negatives_train": 1280,
+ "type": "mclsr",
+ "negative_sampler_type": "random"
+ }
+ }
+ },
+ "dataloader": {
+ "train": {
+ "type": "torch",
+ "batch_size": 128,
+ "batch_processor": {
+ "type": "basic"
+ },
+ "drop_last": true,
+ "shuffle": true,
+ "num_workers": 4,
+ "pin_memory" : true
+ },
+ "validation": {
+ "type": "torch",
+ "batch_size": 128,
+ "batch_processor": {
+ "type": "basic"
+ },
+ "drop_last": false,
+ "shuffle": false,
+ "num_workers": 4,
+ "pin_memory" : true
+ }
+ },
+ "model": {
+ "type": "mclsr",
+ "sequence_prefix": "item",
+ "user_prefix": "user",
+ "labels_prefix": "labels",
+ "candidate_prefix": "candidates",
+ "embedding_dim": 64,
+ "num_graph_layers": 2,
+ "dropout": 0.3,
+ "layer_norm_eps": 1e-9,
+ "graph_dropout": 0.3,
+ "initializer_range": 0.02,
+ "alpha": 0.5
+ },
+ "optimizer": {
+ "type": "basic",
+ "optimizer": {
+ "type": "adam",
+ "lr": 0.001
+ }
+ },
+ "loss": {
+ "type": "composite",
+ "losses": [
+ {
+ "type": "sampled_softmax",
+ "queries_prefix": "combined_representation",
+ "positive_prefix": "label_representation",
+ "negative_prefix": "negative_representation",
+ "output_prefix": "downstream_loss",
+ "weight": 1.0
+ },
+ {
+ "type": "fps",
+ "fst_embeddings_prefix": "sequential_representation",
+ "snd_embeddings_prefix": "graph_representation",
+ "output_prefix": "contrastive_interest_loss",
+ "weight": 1.0,
+ "temperature": 0.5
+ },
+ {
+ "type": "fps",
+ "fst_embeddings_prefix": "user_graph_user_embeddings",
+ "snd_embeddings_prefix": "common_graph_user_embeddings",
+ "output_prefix": "contrastive_user_feature_loss",
+ "weight": 0.05,
+ "temperature": 0.5
+ },
+ {
+ "type": "fps",
+ "fst_embeddings_prefix": "item_graph_item_embeddings",
+ "snd_embeddings_prefix": "common_graph_item_embeddings",
+ "output_prefix": "contrastive_item_feature_loss",
+ "weight": 0.05,
+ "temperature": 0.5
+ }
+ ],
+ "output_prefix": "loss"
+ },
+ "callback": {
+ "type": "composite",
+ "callbacks": [
+ {
+ "type": "metric",
+ "on_step": 1,
+ "loss_prefix": "loss"
+ },
+ {
+ "type": "metric",
+ "on_step": 1,
+ "loss_prefix": "downstream_loss"
+ },
+ {
+ "type": "metric",
+ "on_step": 1,
+ "loss_prefix": "contrastive_interest_loss"
+ },
+ {
+ "type": "metric",
+ "on_step": 1,
+ "loss_prefix": "contrastive_user_feature_loss"
+ },
+ {
+ "type": "metric",
+ "on_step": 1,
+ "loss_prefix": "contrastive_item_feature_loss"
+ },
+ {
+ "type": "validation",
+ "on_step": 64,
+ "pred_prefix": "predictions",
+ "labels_prefix": "labels",
+ "metrics": {
+ "ndcg@5": {
+ "type": "mclsr-ndcg",
+ "k": 5
+ },
+ "ndcg@10": {
+ "type": "mclsr-ndcg",
+ "k": 10
+ },
+ "ndcg@20": {
+ "type": "mclsr-ndcg",
+ "k": 20
+ },
+ "ndcg@50": {
+ "type": "mclsr-ndcg",
+ "k": 50
+ },
+ "recall@5": {
+ "type": "mclsr-recall",
+ "k": 5
+ },
+ "recall@10": {
+ "type": "mclsr-recall",
+ "k": 10
+ },
+ "recall@20": {
+ "type": "mclsr-recall",
+ "k": 20
+ },
+ "recall@50": {
+ "type": "mclsr-recall",
+ "k": 50
+ },
+ "hit@20": {
+ "type": "mclsr-hit",
+ "k": 20
+ },
+ "hit@50": {
+ "type": "mclsr-hit",
+ "k": 50
+ }
+ }
+ },
+ {
+ "type": "eval",
+ "on_step": 256,
+ "pred_prefix": "predictions",
+ "labels_prefix": "labels",
+ "metrics": {
+ "ndcg@5": {
+ "type": "mclsr-ndcg",
+ "k": 5
+ },
+ "ndcg@10": {
+ "type": "mclsr-ndcg",
+ "k": 10
+ },
+ "ndcg@20": {
+ "type": "mclsr-ndcg",
+ "k": 20
+ },
+ "ndcg@50": {
+ "type": "mclsr-ndcg",
+ "k": 50
+ },
+ "recall@5": {
+ "type": "mclsr-recall",
+ "k": 5
+ },
+ "recall@10": {
+ "type": "mclsr-recall",
+ "k": 10
+ },
+ "recall@20": {
+ "type": "mclsr-recall",
+ "k": 20
+ },
+ "recall@50": {
+ "type": "mclsr-recall",
+ "k": 50
+ },
+ "hit@20": {
+ "type": "mclsr-hit",
+ "k": 20
+ },
+ "hit@50": {
+ "type": "mclsr-hit",
+ "k": 50
+ }
+ }
+ }
+ ]
+ }
+}
\ No newline at end of file
diff --git a/configs/train/sasrec_ce_train_config.json b/configs/train/sasrec_ce_train_config.json
deleted file mode 100644
index ce17ae3..0000000
--- a/configs/train/sasrec_ce_train_config.json
+++ /dev/null
@@ -1,146 +0,0 @@
-{
- "experiment_name": "sasrec_test",
- "best_metric": "eval/ndcg@20",
- "dataset": {
- "type": "scientific",
- "path_to_data_dir": "./data",
- "name": "Beauty",
- "max_sequence_length": 50,
- "samplers": {
- "num_negatives_val": 100,
- "type": "next_item_prediction",
- "negative_sampler_type": "random"
- }
- },
- "dataloader": {
- "train": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": true,
- "shuffle": true
- },
- "validation": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": false,
- "shuffle": false
- }
- },
- "model": {
- "type": "sasrec",
- "sequence_prefix": "item",
- "positive_prefix": "positive",
- "negative_prefix": "negative",
- "candidate_prefix": "candidates",
- "embedding_dim": 64,
- "num_heads": 2,
- "num_layers": 2,
- "dim_feedforward": 256,
- "dropout": 0.3,
- "activation": "gelu",
- "use_ce": true,
- "layer_norm_eps": 1e-9,
- "initializer_range": 0.02
- },
- "optimizer": {
- "type": "basic",
- "optimizer": {
- "type": "adam",
- "lr": 0.001
- },
- "clip_grad_threshold": 5.0
- },
- "loss": {
- "type": "composite",
- "losses": [
- {
- "type": "sasrec",
- "positive_prefix": "positive_embeddings",
- "negative_prefix": "negative_embeddings",
- "representation_prefix": "current_embeddings",
- "output_prefix": "downstream_loss"
- }
- ],
- "output_prefix": "loss"
- },
- "callback": {
- "type": "composite",
- "callbacks": [
- {
- "type": "metric",
- "on_step": 1,
- "loss_prefix": "loss"
- },
- {
- "type": "validation",
- "on_step": 64,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- }
- }
- },
- {
- "type": "eval",
- "on_step": 256,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- }
- }
- }
- ]
- }
-}
diff --git a/configs/train/sasrec_in_batch_train_config.json b/configs/train/sasrec_in_batch_train_config.json
deleted file mode 100644
index 9950f31..0000000
--- a/configs/train/sasrec_in_batch_train_config.json
+++ /dev/null
@@ -1,146 +0,0 @@
-{
- "experiment_name": "sasrec_in_batch_test",
- "best_metric": "eval/ndcg@20",
- "dataset": {
- "type": "scientific",
- "path_to_data_dir": "data",
- "name": "Beauty",
- "max_sequence_length": 50,
- "samplers": {
- "num_negatives_val": 100,
- "type": "next_item_prediction",
- "negative_sampler_type": "random"
- }
- },
- "dataloader": {
- "train": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": true,
- "shuffle": true
- },
- "validation": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": false,
- "shuffle": false
- }
- },
- "model": {
- "type": "sasrec_in_batch",
- "sequence_prefix": "item",
- "positive_prefix": "positive",
- "negative_prefix": "negative",
- "candidate_prefix": "candidates",
- "embedding_dim": 64,
- "num_heads": 2,
- "num_layers": 2,
- "dim_feedforward": 256,
- "dropout": 0.3,
- "activation": "gelu",
- "use_ce": true,
- "layer_norm_eps": 1e-9,
- "initializer_range": 0.02
- },
- "optimizer": {
- "type": "basic",
- "optimizer": {
- "type": "adam",
- "lr": 0.001
- },
- "clip_grad_threshold": 5.0
- },
- "loss": {
- "type": "composite",
- "losses": [
- {
- "type": "sampled_softmax",
- "queries_prefix": "query_embeddings",
- "positive_prefix": "positive_embeddings",
- "negative_prefix": "negative_embeddings",
- "output_prefix": "downstream_loss"
- }
- ],
- "output_prefix": "loss"
- },
- "callback": {
- "type": "composite",
- "callbacks": [
- {
- "type": "metric",
- "on_step": 1,
- "loss_prefix": "loss"
- },
- {
- "type": "validation",
- "on_step": 64,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- }
- }
- },
- {
- "type": "eval",
- "on_step": 256,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- }
- }
- }
- ]
- }
-}
diff --git a/configs/train/sasrec_real_train_config.json b/configs/train/sasrec_real_train_config.json
deleted file mode 100644
index b04a5a0..0000000
--- a/configs/train/sasrec_real_train_config.json
+++ /dev/null
@@ -1,217 +0,0 @@
-{
- "experiment_name": "sasrec_real_beauty",
- "best_metric": "validation/ndcg@100",
- "dataset": {
- "type": "scientific",
- "path_to_data_dir": "./data",
- "name": "Beauty",
- "max_sequence_length": 100,
- "samplers": {
- "num_negatives_train": 1,
- "type": "next_item_prediction",
- "negative_sampler_type": "random"
- }
- },
- "dataloader": {
- "train": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": true,
- "shuffle": true
- },
- "validation": {
- "type": "torch",
- "batch_size": 2048,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": false,
- "shuffle": false
- }
- },
- "model": {
- "type": "sasrec_real",
- "sequence_prefix": "item",
- "positive_prefix": "positive",
- "negative_prefix": "negative",
- "candidate_prefix": "candidates",
- "embedding_dim": 64,
- "num_heads": 2,
- "num_layers": 2,
- "dim_feedforward": 256,
- "dropout": 0.3,
- "activation": "gelu",
- "layer_norm_eps": 1e-9,
- "initializer_range": 0.02
- },
- "optimizer": {
- "type": "basic",
- "optimizer": {
- "type": "adam",
- "lr": 0.001
- },
- "clip_grad_threshold": 5.0
- },
- "loss": {
- "type": "composite",
- "losses": [
- {
- "type": "sasrec_real",
- "positive_prefix": "positive_scores",
- "negative_prefix": "negative_scores",
- "output_prefix": "downstream_loss"
- }
- ],
- "output_prefix": "loss"
- },
- "callback": {
- "type": "composite",
- "callbacks": [
- {
- "type": "metric",
- "on_step": 1,
- "loss_prefix": "loss"
- },
- {
- "type": "validation",
- "on_step": 64,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "ndcg@50": {
- "type": "ndcg",
- "k": 50
- },
- "ndcg@100": {
- "type": "ndcg",
- "k": 100
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- },
- "recall@50": {
- "type": "recall",
- "k": 50
- },
- "recall@100": {
- "type": "recall",
- "k": 100
- },
- "coverage@5": {
- "type": "coverage",
- "k": 5
- },
- "coverage@10": {
- "type": "coverage",
- "k": 10
- },
- "coverage@20": {
- "type": "coverage",
- "k": 20
- },
- "coverage@50": {
- "type": "coverage",
- "k": 50
- },
- "coverage@100": {
- "type": "coverage",
- "k": 100
- }
- }
- },
- {
- "type": "eval",
- "on_step": 256,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "ndcg@50": {
- "type": "ndcg",
- "k": 50
- },
- "ndcg@100": {
- "type": "ndcg",
- "k": 100
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- },
- "recall@50": {
- "type": "recall",
-
- "k": 50
- },
- "recall@100": {
- "type": "recall",
- "k": 100
- },
- "coverage@5": {
- "type": "coverage",
- "k": 5
- },
- "coverage@10": {
- "type": "coverage",
- "k": 10
- },
- "coverage@20": {
- "type": "coverage",
- "k": 20
- },
- "coverage@50": {
- "type": "coverage",
- "k": 50
- },
- "coverage@100": {
- "type": "coverage",
- "k": 100
- }
- }
- }
- ]
- }
-}
diff --git a/configs/train/sasrec_train_grid_config.json b/configs/train/sasrec_train_grid_config.json
deleted file mode 100644
index acf86b6..0000000
--- a/configs/train/sasrec_train_grid_config.json
+++ /dev/null
@@ -1,185 +0,0 @@
-{
- "start_from": 0,
- "experiment_name": "sasrec_beauty_grid",
- "best_metric": "validation/ndcg@20",
- "dataset": {
- "type": "sequence",
- "path_to_data_dir": "./data",
- "name": "Beauty",
- "max_sequence_length": 50,
- "samplers": {
- "type": "next_item_prediction",
- "negative_sampler_type": "random"
- }
- },
- "dataset_params": {
- },
- "dataloader": {
- "train": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": true,
- "shuffle": true
- },
- "validation": {
- "type": "torch",
- "batch_size": 256,
- "batch_processor": {
- "type": "basic"
- },
- "drop_last": false,
- "shuffle": false
- }
- },
- "model": {
- "type": "sasrec",
- "sequence_prefix": "item",
- "positive_prefix": "positive",
- "negative_prefix": "negative",
- "candidate_prefix": "candidates",
- "embedding_dim": 64,
- "num_heads": 2,
- "num_layers": 2,
- "dim_feedforward": 256,
- "activation": "gelu",
- "layer_norm_eps": 1e-9,
- "initializer_range": 0.02
- },
- "model_params": {
- "dropout": [
- 0.1,
- 0.2,
- 0.3,
- 0.4,
- 0.5,
- 0.6,
- 0.7,
- 0.8,
- 0.9]
- },
- "optimizer": {
- "type": "basic",
- "optimizer": {
- "type": "adam",
- "lr": 0.001
- },
- "clip_grad_threshold": 5.0
- },
- "optimizer_params": {
- },
- "loss": {
- "type": "composite",
- "losses": [
- {
- "type": "sasrec",
- "positive_prefix": "positive_scores",
- "negative_prefix": "negative_scores",
- "output_prefix": "downstream_loss"
- }
- ],
- "output_prefix": "loss"
- },
- "loss_params": {
- },
- "callback": {
- "type": "composite",
- "callbacks": [
- {
- "type": "metric",
- "on_step": 1,
- "loss_prefix": "loss"
- },
- {
- "type": "validation",
- "on_step": 64,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- },
- "coverage@5": {
- "type": "coverage",
- "k": 5
- },
- "coverage@10": {
- "type": "coverage",
- "k": 10
- },
- "coverage@20": {
- "type": "coverage",
- "k": 20
- }
- }
- },
- {
- "type": "eval",
- "on_step": 256,
- "pred_prefix": "logits",
- "labels_prefix": "labels",
- "metrics": {
- "ndcg@5": {
- "type": "ndcg",
- "k": 5
- },
- "ndcg@10": {
- "type": "ndcg",
- "k": 10
- },
- "ndcg@20": {
- "type": "ndcg",
- "k": 20
- },
- "recall@5": {
- "type": "recall",
- "k": 5
- },
- "recall@10": {
- "type": "recall",
- "k": 10
- },
- "recall@20": {
- "type": "recall",
- "k": 20
- },
- "coverage@5": {
- "type": "coverage",
- "k": 5
- },
- "coverage@10": {
- "type": "coverage",
- "k": 10
- },
- "coverage@20": {
- "type": "coverage",
- "k": 20
- }
- }
- }
- ]
- }
-}
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..5b0ca5a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+numpy~=1.26
+torch~=2.4
+tqdm~=4.66
+scipy~=1.14
+pandas~=2.2
+polars~=1.27
+matplotlib~=3.9
+tensorboard~=2.19
+wandb~=0.19
+jupyterlab~=4.4
+ipykernel~=6.29
+notebook~=7.4
diff --git a/src/irec/callbacks/base.py b/src/irec/callbacks/base.py
index 81271f5..77895d6 100644
--- a/src/irec/callbacks/base.py
+++ b/src/irec/callbacks/base.py
@@ -30,7 +30,7 @@ def __call__(self, inputs, step_num):
raise NotImplementedError
-class MetricCallback(BaseCallback, config_name='metric'):
+class MetricCallback(BaseCallback, config_name="metric"):
def __init__(
self,
model,
@@ -56,43 +56,39 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- model=kwargs['model'],
- train_dataloader=kwargs['train_dataloader'],
- validation_dataloader=kwargs['validation_dataloader'],
- eval_dataloader=kwargs['eval_dataloader'],
- optimizer=kwargs['optimizer'],
- on_step=config['on_step'],
- metrics=config.get('metrics', None),
- loss_prefix=config['loss_prefix'],
+ model=kwargs["model"],
+ train_dataloader=kwargs["train_dataloader"],
+ validation_dataloader=kwargs["validation_dataloader"],
+ eval_dataloader=kwargs["eval_dataloader"],
+ optimizer=kwargs["optimizer"],
+ on_step=config["on_step"],
+ metrics=config.get("metrics", None),
+ loss_prefix=config["loss_prefix"],
)
def __call__(self, inputs, step_num):
if step_num % self._on_step == 0:
for metric_name, metric_function in self._metrics.items():
metric_value = metric_function(
- ground_truth=inputs[
- self._model.schema['ground_truth_prefix']
- ],
- predictions=inputs[
- self._model.schema['predictions_prefix']
- ],
+ ground_truth=inputs[self._model.schema["ground_truth_prefix"]],
+ predictions=inputs[self._model.schema["predictions_prefix"]],
)
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar(
- 'train/{}'.format(metric_name),
+ "train/{}".format(metric_name),
metric_value,
step_num,
)
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar(
- 'train/{}'.format(self._loss_prefix),
+ "train/{}".format(self._loss_prefix),
inputs[self._loss_prefix],
step_num,
)
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush()
-class CheckpointCallback(BaseCallback, config_name='checkpoint'):
+class CheckpointCallback(BaseCallback, config_name="checkpoint"):
def __init__(
self,
model,
@@ -115,7 +111,7 @@ def __init__(
self._save_path = Path(os.path.join(save_path, model_name))
if self._save_path.exists():
logger.warning(
- 'Checkpoint path `{}` is already exists!'.format(
+ "Checkpoint path `{}` is already exists!".format(
self._save_path,
),
)
@@ -125,31 +121,31 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- model=kwargs['model'],
- train_dataloader=kwargs['train_dataloader'],
- validation_dataloader=kwargs['validation_dataloader'],
- eval_dataloader=kwargs['eval_dataloader'],
- optimizer=kwargs['optimizer'],
- on_step=config['on_step'],
- save_path=config['save_path'],
- model_name=config['model_name'],
+ model=kwargs["model"],
+ train_dataloader=kwargs["train_dataloader"],
+ validation_dataloader=kwargs["validation_dataloader"],
+ eval_dataloader=kwargs["eval_dataloader"],
+ optimizer=kwargs["optimizer"],
+ on_step=config["on_step"],
+ save_path=config["save_path"],
+ model_name=config["model_name"],
)
def __call__(self, inputs, step_num):
if step_num % self._on_step == 0:
- logger.debug('Saving model state on step {}...'.format(step_num))
+ logger.debug("Saving model state on step {}...".format(step_num))
torch.save(
{
- 'step_num': step_num,
- 'model_state_dict': self._model.state_dict(),
- 'optimizer_state_dict': self._optimizer.state_dict(),
+ "step_num": step_num,
+ "model_state_dict": self._model.state_dict(),
+ "optimizer_state_dict": self._optimizer.state_dict(),
},
os.path.join(
self._save_path,
- 'checkpoint_{}.pth'.format(step_num),
+ "checkpoint_{}.pth".format(step_num),
),
)
- logger.debug('Saving done!')
+ logger.debug("Saving done!")
class InferenceCallback(BaseCallback):
@@ -183,24 +179,24 @@ def __init__(
def create_from_config(cls, config, **kwargs):
metrics = {
metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs)
- for metric_name, metric_cfg in config['metrics'].items()
+ for metric_name, metric_cfg in config["metrics"].items()
}
return cls(
- model=kwargs['model'],
- train_dataloader=kwargs['train_dataloader'],
- validation_dataloader=kwargs['validation_dataloader'],
- eval_dataloader=kwargs['eval_dataloader'],
- optimizer=kwargs['optimizer'],
- on_step=config['on_step'],
+ model=kwargs["model"],
+ train_dataloader=kwargs["train_dataloader"],
+ validation_dataloader=kwargs["validation_dataloader"],
+ eval_dataloader=kwargs["eval_dataloader"],
+ optimizer=kwargs["optimizer"],
+ on_step=config["on_step"],
metrics=metrics,
- pred_prefix=config['pred_prefix'],
- labels_prefix=config['labels_prefix'],
+ pred_prefix=config["pred_prefix"],
+ labels_prefix=config["labels_prefix"],
)
def __call__(self, inputs, step_num):
if step_num % self._on_step == 0: # TODO Add time monitoring
- logger.debug(f'Running {self._get_name()} on step {step_num}...')
+ logger.debug(f"Running {self._get_name()} on step {step_num}...")
running_params = {}
for metric_name, metric_function in self._metrics.items():
running_params[metric_name] = []
@@ -239,16 +235,16 @@ def __call__(self, inputs, step_num):
)
for label, value in running_params.items():
- inputs[f'{self._get_name()}/{label}'] = np.mean(value)
+ inputs[f"{self._get_name()}/{label}"] = np.mean(value)
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar(
- f'{self._get_name()}/{label}',
+ f"{self._get_name()}/{label}",
np.mean(value),
step_num,
)
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush()
logger.debug(
- f'Running {self._get_name()} on step {step_num} is done!',
+ f"Running {self._get_name()} on step {step_num} is done!",
)
def _get_name(self):
@@ -258,55 +254,55 @@ def _get_dataloader(self):
raise NotImplementedError
-class ValidationCallback(InferenceCallback, config_name='validation'):
+class ValidationCallback(InferenceCallback, config_name="validation"):
@classmethod
def create_from_config(cls, config, **kwargs):
metrics = {
metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs)
- for metric_name, metric_cfg in config['metrics'].items()
+ for metric_name, metric_cfg in config["metrics"].items()
}
return cls(
- model=kwargs['model'],
- train_dataloader=kwargs['train_dataloader'],
- validation_dataloader=kwargs['validation_dataloader'],
- eval_dataloader=kwargs['eval_dataloader'],
- optimizer=kwargs['optimizer'],
- on_step=config['on_step'],
+ model=kwargs["model"],
+ train_dataloader=kwargs["train_dataloader"],
+ validation_dataloader=kwargs["validation_dataloader"],
+ eval_dataloader=kwargs["eval_dataloader"],
+ optimizer=kwargs["optimizer"],
+ on_step=config["on_step"],
metrics=metrics,
- pred_prefix=config['pred_prefix'],
- labels_prefix=config['labels_prefix'],
+ pred_prefix=config["pred_prefix"],
+ labels_prefix=config["labels_prefix"],
)
def _get_dataloader(self):
return self._validation_dataloader
-class EvalCallback(InferenceCallback, config_name='eval'):
+class EvalCallback(InferenceCallback, config_name="eval"):
@classmethod
def create_from_config(cls, config, **kwargs):
metrics = {
metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs)
- for metric_name, metric_cfg in config['metrics'].items()
+ for metric_name, metric_cfg in config["metrics"].items()
}
return cls(
- model=kwargs['model'],
- train_dataloader=kwargs['train_dataloader'],
- validation_dataloader=kwargs['validation_dataloader'],
- eval_dataloader=kwargs['eval_dataloader'],
- optimizer=kwargs['optimizer'],
- on_step=config['on_step'],
+ model=kwargs["model"],
+ train_dataloader=kwargs["train_dataloader"],
+ validation_dataloader=kwargs["validation_dataloader"],
+ eval_dataloader=kwargs["eval_dataloader"],
+ optimizer=kwargs["optimizer"],
+ on_step=config["on_step"],
metrics=metrics,
- pred_prefix=config['pred_prefix'],
- labels_prefix=config['labels_prefix'],
+ pred_prefix=config["pred_prefix"],
+ labels_prefix=config["labels_prefix"],
)
def _get_dataloader(self):
return self._eval_dataloader
-class CompositeCallback(BaseCallback, config_name='composite'):
+class CompositeCallback(BaseCallback, config_name="composite"):
def __init__(
self,
model,
@@ -328,14 +324,14 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- model=kwargs['model'],
- train_dataloader=kwargs['train_dataloader'],
- validation_dataloader=kwargs['validation_dataloader'],
- eval_dataloader=kwargs['eval_dataloader'],
- optimizer=kwargs['optimizer'],
+ model=kwargs["model"],
+ train_dataloader=kwargs["train_dataloader"],
+ validation_dataloader=kwargs["validation_dataloader"],
+ eval_dataloader=kwargs["eval_dataloader"],
+ optimizer=kwargs["optimizer"],
callbacks=[
BaseCallback.create_from_config(cfg, **kwargs)
- for cfg in config['callbacks']
+ for cfg in config["callbacks"]
],
)
diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py
index 06fdefc..5e09f46 100644
--- a/src/irec/dataloader/base.py
+++ b/src/irec/dataloader/base.py
@@ -13,7 +13,7 @@ class BaseDataloader(metaclass=MetaParent):
pass
-class TorchDataloader(BaseDataloader, config_name='torch'):
+class TorchDataloader(BaseDataloader, config_name="torch"):
def __init__(self, dataloader):
self._dataloader = dataloader
@@ -27,17 +27,32 @@ def __len__(self):
def create_from_config(cls, config, **kwargs):
create_config = copy.deepcopy(config)
batch_processor = BaseBatchProcessor.create_from_config(
- create_config.pop('batch_processor')
- if 'batch_processor' in create_config
- else {'type': 'identity'},
+ (
+ create_config.pop("batch_processor")
+ if "batch_processor" in create_config
+ else {"type": "identity"}
+ ),
)
create_config.pop(
- 'type',
+ "type",
) # For passing as **config in torch DataLoader
+
+ pin_memory = create_config.pop("pin_memory", True)
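+ # pin_memory is popped so it is not passed twice: it goes to the
+ # DataLoader explicitly below, while the rest of the config is
+ # forwarded via **create_config. Defaulting to True enables
+ # page-locked host memory for faster CPU-to-GPU transfers.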
+
return cls(
dataloader=DataLoader(
- kwargs['dataset'],
+ kwargs["dataset"],
collate_fn=batch_processor,
+ pin_memory=pin_memory,
**create_config,
),
)
diff --git a/src/irec/dataloader/batch_processors.py b/src/irec/dataloader/batch_processors.py
index a8dbdde..f2f3927 100644
--- a/src/irec/dataloader/batch_processors.py
+++ b/src/irec/dataloader/batch_processors.py
@@ -1,4 +1,5 @@
import torch
+import itertools
from irec.utils import MetaParent
@@ -7,32 +8,59 @@ def __call__(self, batch):
raise NotImplementedError
-class IdentityBatchProcessor(BaseBatchProcessor, config_name='identity'):
+class IdentityBatchProcessor(BaseBatchProcessor, config_name="identity"):
def __call__(self, batch):
return torch.tensor(batch)
-class BasicBatchProcessor(BaseBatchProcessor, config_name='basic'):
+class BasicBatchProcessor(BaseBatchProcessor, config_name="basic"):
def __call__(self, batch):
processed_batch = {}
for key in batch[0].keys():
- if key.endswith('.ids'):
- prefix = key.split('.')[0]
- assert '{}.length'.format(prefix) in batch[0]
+ if key.endswith(".ids"):
+ prefix = key.split(".")[0]
+ length_key = f"{prefix}.length"
+ assert length_key in batch[0]
- processed_batch[f'{prefix}.ids'] = []
- processed_batch[f'{prefix}.length'] = []
+ # --- OLD SLOW IMPLEMENTATION (Python loop with manual .extend) ---
+ # processed_batch[f'{prefix}.ids'] = []
+ # processed_batch[f'{prefix}.length'] = []
+ # for sample in batch:
+ # processed_batch[f'{prefix}.ids'].extend(
+ # sample[f'{prefix}.ids'],
+ # )
+ # processed_batch[f'{prefix}.length'].append(
+ # sample[f'{prefix}.length'],
+ # )
- for sample in batch:
- processed_batch[f'{prefix}.ids'].extend(
- sample[f'{prefix}.ids'],
- )
- processed_batch[f'{prefix}.length'].append(
- sample[f'{prefix}.length'],
- )
+ # --- NEW OPTIMIZED IMPLEMENTATION (Books-Scale Ready) ---
+ """
+ Optimization Strategy: C-level Flattening via itertools.
+
+ Justification for Amazon Books scale:
+ 1. Avoiding Reallocations: Python's list.extend() repeatedly triggers
+ memory reallocation as the list grows. For large batches on a
+ 9-million-interaction dataset, this creates significant overhead.
+ 2. itertools.chain.from_iterable: This is implemented in C. It creates
+ a flat iterator over the sequence slices without creating
+ intermediate Python list objects, which is much faster.
+ 3. List Comprehension: Collecting lengths via a comprehension is
+ consistently faster than manual .append() calls in a for-loop.
+ """
+ # Efficiently flatten all sequence IDs into one long list
+ ids_iter = itertools.chain.from_iterable(s[key] for s in batch)
+ processed_batch[key] = torch.tensor(list(ids_iter), dtype=torch.long)
+ # Efficiently collect all lengths into a tensor
+ lengths_list = [s[length_key] for s in batch]
+ processed_batch[length_key] = torch.tensor(
+ lengths_list, dtype=torch.long
+ )
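+ # e.g. two samples {"item.ids": [1, 2], "item.length": 2} and
+ # {"item.ids": [3], "item.length": 1} collate to
+ # item.ids -> tensor([1, 2, 3]) and item.length -> tensor([2, 1]);
+ # the lengths tensor lets downstream code re-split the flat ids.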
+
+ # Safety net: convert any values that are still Python lists; the
+ # .ids/.length keys handled above are already tensors.
for part, values in processed_batch.items():
- processed_batch[part] = torch.tensor(values, dtype=torch.long)
+ if not isinstance(processed_batch[part], torch.Tensor):
+ processed_batch[part] = torch.tensor(values, dtype=torch.long)
return processed_batch
diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py
index b07d00c..5936a31 100644
--- a/src/irec/dataset/base.py
+++ b/src/irec/dataset/base.py
@@ -33,7 +33,8 @@ def num_items(self):
@property
def max_sequence_length(self):
return self._max_sequence_length
-
+
+
class BaseSequenceDataset(BaseDataset):
def __init__(
self,
@@ -52,7 +53,7 @@ def __init__(
self._max_sequence_length = max_sequence_length
@staticmethod
- def _create_sequences(data, max_sample_len): # TODO
+ def _create_sequences(data, max_sample_len): # TODO
user_sequences = []
item_sequences = []
@@ -61,10 +62,8 @@ def _create_sequences(data, max_sample_len): # TODO
max_sequence_length = 0
for sample in data:
- sample = sample.strip('\n').split(' ')
- item_ids = [int(item_id) for item_id in sample[1:]][
- -max_sample_len:
- ]
+ sample = sample.strip("\n").split(" ")
+ item_ids = [int(item_id) for item_id in sample[1:]][-max_sample_len:]
user_id = int(sample[0])
max_user_id = max(max_user_id, user_id)
@@ -92,12 +91,13 @@ def get_samplers(self):
@property
def meta(self):
return {
- 'num_users': self.num_users,
- 'num_items': self.num_items,
- 'max_sequence_length': self.max_sequence_length,
+ "num_users": self.num_users,
+ "num_items": self.num_items,
+ "max_sequence_length": self.max_sequence_length,
}
-
-class SequenceDataset(BaseDataset, config_name='sequence'):
+
+
+class SequenceDataset(BaseDataset, config_name="sequence"):
def __init__(
self,
train_sampler,
@@ -117,24 +117,24 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
data_dir_path = os.path.join(
- config['path_to_data_dir'],
- config['name'],
+ config["path_to_data_dir"],
+ config["name"],
)
common_params_for_creation = {
- 'dir_path': data_dir_path,
- 'max_sequence_length': config['max_sequence_length'],
- 'use_cached': config.get('use_cached', False),
+ "dir_path": data_dir_path,
+ "max_sequence_length": config["max_sequence_length"],
+ "use_cached": config.get("use_cached", False),
}
train_dataset, train_max_user_id, train_max_item_id, train_seq_len = (
- cls._create_dataset(part='train', **common_params_for_creation)
+ cls._create_dataset(part="train", **common_params_for_creation)
)
validation_dataset, valid_max_user_id, valid_max_item_id, valid_seq_len = (
- cls._create_dataset(part='valid', **common_params_for_creation)
+ cls._create_dataset(part="valid", **common_params_for_creation)
)
test_dataset, test_max_user_id, test_max_item_id, test_seq_len = (
- cls._create_dataset(part='test', **common_params_for_creation)
+ cls._create_dataset(part="test", **common_params_for_creation)
)
max_user_id = max(
@@ -145,11 +145,11 @@ def create_from_config(cls, config, **kwargs):
)
max_seq_len = max([train_seq_len, valid_seq_len, test_seq_len])
- logger.info('Train dataset size: {}'.format(len(train_dataset)))
- logger.info('Test dataset size: {}'.format(len(test_dataset)))
- logger.info('Max user id: {}'.format(max_user_id))
- logger.info('Max item id: {}'.format(max_item_id))
- logger.info('Max sequence length: {}'.format(max_seq_len))
+ logger.info("Train dataset size: {}".format(len(train_dataset)))
+ logger.info("Test dataset size: {}".format(len(test_dataset)))
+ logger.info("Max user id: {}".format(max_user_id))
+ logger.info("Max item id: {}".format(max_item_id))
+ logger.info("Max sequence length: {}".format(max_seq_len))
train_interactions = sum(
list(map(lambda x: len(x), train_dataset)),
@@ -161,15 +161,15 @@ def create_from_config(cls, config, **kwargs):
test_dataset,
) # each new interaction as a sample
logger.info(
- '{} dataset sparsity: {}'.format(
- config['name'],
+ "{} dataset sparsity: {}".format(
+ config["name"],
(train_interactions + valid_interactions + test_interactions)
/ max_user_id
/ max_item_id,
),
)
- samplers_config = config['samplers']
+ samplers_config = config["samplers"]
train_sampler = TrainSampler.create_from_config(
samplers_config,
dataset=train_dataset,
@@ -206,31 +206,30 @@ def _create_dataset(
max_sequence_length=None,
use_cached=False,
):
- cache_path = os.path.join(dir_path, '{}.pkl'.format(part))
+ cache_path = os.path.join(dir_path, "{}.pkl".format(part))
if use_cached and os.path.exists(cache_path):
- logger.info(
- 'Loading cached dataset from {}'.format(cache_path)
- )
- with open(cache_path, 'rb') as f:
+ logger.info("Loading cached dataset from {}".format(cache_path))
+ with open(cache_path, "rb") as f:
return pickle.load(f)
-
- return cls._build_and_cache_dataset(dir_path, part, max_sequence_length, cache_path, use_cached)
+ return cls._build_and_cache_dataset(
+ dir_path, part, max_sequence_length, cache_path, use_cached
+ )
@classmethod
- def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_path, use_cached):
+ def _build_and_cache_dataset(
+ cls, dir_path, part, max_sequence_length, cache_path, use_cached
+ ):
logger.info(
- 'Cache is forcefully ignored.'
+ "Cache is forcefully ignored."
if not use_cached
- else 'No cached dataset has been found.'
- )
- dataset_path = os.path.join(dir_path, '{}.txt'.format(part))
- logger.info(
- 'Creating a dataset from {}...'.format(dataset_path)
+ else "No cached dataset has been found."
)
+ dataset_path = os.path.join(dir_path, "{}.txt".format(part))
+ logger.info("Creating a dataset from {}...".format(dataset_path))
- with open(dataset_path, 'r') as f:
+ with open(dataset_path, "r") as f:
data = f.readlines()
sequence_info = cls._create_sequences(data, max_sequence_length)
@@ -246,22 +245,22 @@ def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_pat
for user_id, item_ids in zip(user_sequences, item_sequences):
dataset.append(
{
- 'user.ids': [user_id],
- 'user.length': 1,
- 'item.ids': item_ids,
- 'item.length': len(item_ids),
+ "user.ids": [user_id],
+ "user.length": 1,
+ "item.ids": item_ids,
+ "item.length": len(item_ids),
},
)
- logger.info('{} dataset size: {}'.format(part, len(dataset)))
+ logger.info("{} dataset size: {}".format(part, len(dataset)))
logger.info(
- '{} dataset max sequence length: {}'.format(
+ "{} dataset max sequence length: {}".format(
part,
max_sequence_len,
),
)
- with open(cache_path, 'wb') as dataset_file:
+ with open(cache_path, "wb") as dataset_file:
pickle.dump(
(dataset, max_user_id, max_item_id, max_sequence_len),
dataset_file,
@@ -270,7 +269,7 @@ def _build_and_cache_dataset(cls, dir_path, part, max_sequence_length, cache_pat
return dataset, max_user_id, max_item_id, max_sequence_len
@staticmethod
- def _create_sequences(data, max_sample_len): # TODO
+ def _create_sequences(data, max_sample_len): # TODO
user_sequences = []
item_sequences = []
@@ -279,10 +278,8 @@ def _create_sequences(data, max_sample_len): # TODO
max_sequence_length = 0
for sample in data:
- sample = sample.strip('\n').split(' ')
- item_ids = [int(item_id) for item_id in sample[1:]][
- -max_sample_len:
- ]
+ sample = sample.strip("\n").split(" ")
+ item_ids = [int(item_id) for item_id in sample[1:]][-max_sample_len:]
user_id = int(sample[0])
max_user_id = max(max_user_id, user_id)
@@ -307,7 +304,8 @@ def get_samplers(self):
self._test_sampler,
)
-class GraphDataset(BaseDataset, config_name='graph'):
+
+class GraphDataset(BaseDataset, config_name="graph"):
def __init__(
self,
dataset,
@@ -315,7 +313,7 @@ def __init__(
use_train_data_only=True,
use_user_graph=False,
use_item_graph=False,
- neighborhood_size=None
+ neighborhood_size=None,
):
self._dataset = dataset
self._graph_dir_path = graph_dir_path
@@ -327,11 +325,11 @@ def __init__(
self._num_users = dataset.num_users
self._num_items = dataset.num_items
- train_sampler, validation_sampler, test_sampler = (
- dataset.get_samplers()
- )
+ train_sampler, validation_sampler, test_sampler = dataset.get_samplers()
- interactions_data = self._collect_interactions(train_sampler, validation_sampler, test_sampler)
+ interactions_data = self._collect_interactions(
+ train_sampler, validation_sampler, test_sampler
+ )
train_interactions = interactions_data["train_interactions"]
train_user_interactions = interactions_data["train_user_interactions"]
train_item_interactions = interactions_data["train_item_interactions"]
@@ -343,49 +341,63 @@ def __init__(
self._train_item_interactions = np.array(train_item_interactions)
self._graph = self._build_or_load_bipartite_graph(
- graph_dir_path,
- train_user_interactions,
- train_item_interactions
+ graph_dir_path, train_user_interactions, train_item_interactions
)
-
self._user_graph = (
self._build_or_load_similarity_graph(
- 'user',
- self._train_user_interactions,
- self._train_item_interactions,
- train_item_2_users,
- train_user_2_items
- )
- if self._use_user_graph
+ "user",
+ self._train_user_interactions,
+ self._train_item_interactions,
+ train_item_2_users,
+ train_user_2_items,
+ )
+ if self._use_user_graph
else None
)
self._item_graph = (
self._build_or_load_similarity_graph(
- 'item',
- self._train_user_interactions,
- self._train_item_interactions,
- train_item_2_users,
- train_user_2_items
- )
- if self._use_item_graph
+ "item",
+ self._train_user_interactions,
+ self._train_item_interactions,
+ train_item_2_users,
+ train_user_2_items,
+ )
+ if self._use_item_graph
else None
)
def _build_or_load_similarity_graph(
- self,
- entity_type,
- train_user_interactions,
- train_item_interactions,
- train_item_2_users,
- train_user_2_items
+ self,
+ entity_type,
+ train_user_interactions,
+ train_item_interactions,
+ train_item_2_users,
+ train_user_2_items,
):
- if entity_type not in ['user', 'item']:
+ if entity_type not in ["user", "item"]:
raise ValueError("entity_type must be either 'user' or 'item'")
+ # The old fixed filename forced manual deletion of the cached .npz
+ # whenever the graph settings changed:
+ # path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type))
+ # Encoding the settings in the cache filename avoids that:
+
+ # neighborhood_size constrains the number of edges kept per user/item
+ # node in the graph. k=50 means that for each user we gather all
+ # candidate neighbors, sort them by co-occurrence count, and keep only
+ # the top 50; all other connections are removed from the graph.
+ k_suffix = (
+ f"k{self._neighborhood_size}"
+ if self._neighborhood_size is not None
+ else "full"
+ )
+ train_suffix = "trainOnly" if self._use_train_data_only else "withValTest"
+ filename = f"{entity_type}_graph_{k_suffix}_{train_suffix}.npz"
+ path_to_graph = os.path.join(self._graph_dir_path, filename)
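+ # e.g. entity_type="user", neighborhood_size=50, use_train_data_only=True
+ # resolves to "user_graph_k50_trainOnly.npz"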
- path_to_graph = os.path.join(self._graph_dir_path, '{}_graph.npz'.format(entity_type))
- is_user_graph = (entity_type == 'user')
+ is_user_graph = entity_type == "user"
num_entities = self._num_users if is_user_graph else self._num_items
if os.path.exists(path_to_graph):
@@ -394,19 +406,28 @@ def _build_or_load_similarity_graph(
interactions_fst = []
interactions_snd = []
visited_user_item_pairs = set()
- visited_entity_pairs = set()
+ # Entity-pair deduplication is intentionally dropped here: per Section 3.2
+ # (Graph Construction), "the weight of each edge denotes the number of
+ # co-action behaviors between user i and user j", so repeated
+ # co-occurrences must accumulate into the edge weight.
+
+ # visited_entity_pairs = set()
for user_id, item_id in tqdm(
zip(train_user_interactions, train_item_interactions),
- desc='Building {}-{} graph'.format(entity_type, entity_type) # TODO need?
+ desc="Building {}-{} graph".format(
+ entity_type, entity_type
+ ), # TODO need?
):
if (user_id, item_id) in visited_user_item_pairs:
continue
- visited_user_item_pairs.add((user_id, item_id))
+ visited_user_item_pairs.add((user_id, item_id))
# TODO look here at review
source_entity = user_id if is_user_graph else item_id
- connection_map = train_item_2_users if is_user_graph else train_user_2_items
+ connection_map = (
+ train_item_2_users if is_user_graph else train_user_2_items
+ )
connection_point = item_id if is_user_graph else user_id
for connected_entity in connection_map[connection_point]:
@@ -414,38 +435,38 @@ def _build_or_load_similarity_graph(
continue
pair_key = (source_entity, connected_entity)
- if pair_key in visited_entity_pairs:
- continue
-
- visited_entity_pairs.add(pair_key)
+ # if pair_key in visited_entity_pairs:
+ # continue
+
+ # visited_entity_pairs.add(pair_key)
interactions_fst.append(source_entity)
interactions_snd.append(connected_entity)
connections = csr_matrix(
- (np.ones(len(interactions_fst)),
- (
- interactions_fst,
- interactions_snd
- )
- ),
- shape=(num_entities + 2, num_entities + 2)
+ (np.ones(len(interactions_fst)), (interactions_fst, interactions_snd)),
+ shape=(num_entities + 2, num_entities + 2),
)
if self._neighborhood_size is not None:
- connections = self._filter_matrix_by_top_k(connections, self._neighborhood_size)
+ connections = self._filter_matrix_by_top_k(
+ connections, self._neighborhood_size
+ )
graph_matrix = self.get_sparse_graph_layer(
- connections,
- num_entities + 2,
- num_entities + 2,
- biparite=False
+ connections, num_entities + 2, num_entities + 2, biparite=False
)
sp.save_npz(path_to_graph, graph_matrix)
return self._convert_sp_mat_to_sp_tensor(graph_matrix).coalesce().to(DEVICE)
- def _build_or_load_bipartite_graph(self, graph_dir_path, train_user_interactions, train_item_interactions):
- path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz')
+ def _build_or_load_bipartite_graph(
+ self, graph_dir_path, train_user_interactions, train_item_interactions
+ ):
+ # path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz')
+ train_suffix = "trainOnly" if self._use_train_data_only else "withValTest"
+ filename = f"general_graph_{train_suffix}.npz"
+ path_to_graph = os.path.join(graph_dir_path, filename)
+
if os.path.exists(path_to_graph):
graph_matrix = sp.load_npz(path_to_graph)
else:
@@ -481,8 +502,8 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler)
for sampler in samplers_to_process:
for sample in sampler.dataset:
- user_id = sample['user.ids'][0]
- for item_id in sample['item.ids']:
+ user_id = sample["user.ids"][0]
+ for item_id in sample["item.ids"]:
if (user_id, item_id) not in visited_user_item_pairs:
train_interactions.append((user_id, item_id))
train_user_interactions.append(user_id)
@@ -492,7 +513,7 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler)
train_item_2_users[item_id].add(user_id)
visited_user_item_pairs.add((user_id, item_id))
-
+
return {
"train_interactions": train_interactions,
"train_user_interactions": train_user_interactions,
@@ -503,15 +524,49 @@ def _collect_interactions(self, train_sampler, validation_sampler, test_sampler)
@classmethod
def create_from_config(cls, config):
- dataset = BaseDataset.create_from_config(config['dataset'])
+ dataset = BaseDataset.create_from_config(config["dataset"])
return cls(
dataset=dataset,
- graph_dir_path=config['graph_dir_path'],
- use_user_graph=config.get('use_user_graph', False),
- use_item_graph=config.get('use_item_graph', False),
- neighborhood_size=config.get('neighborhood_size', None),
+ graph_dir_path=config["graph_dir_path"],
+ use_user_graph=config.get("use_user_graph", False),
+ use_item_graph=config.get("use_item_graph", False),
+ neighborhood_size=config.get("neighborhood_size", None),
)
@staticmethod
def get_sparse_graph_layer(
sparse_matrix,
@@ -523,27 +578,52 @@ def get_sparse_graph_layer(
adj_mat = sparse_matrix.tocsr()
else:
R = sparse_matrix.tocsr()
-
+
upper_right = R
lower_left = R.T
-
+
upper_left = sp.csr_matrix((fst_dim, fst_dim))
lower_right = sp.csr_matrix((snd_dim, snd_dim))
-
- adj_mat = sp.bmat([
- [upper_left, upper_right],
- [lower_left, lower_right]
- ])
- assert adj_mat.shape == (fst_dim + snd_dim, fst_dim + snd_dim), (
- f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}"
- )
+
+ adj_mat = sp.bmat([[upper_left, upper_right], [lower_left, lower_right]])
+ assert adj_mat.shape == (
+ fst_dim + snd_dim,
+ fst_dim + snd_dim,
+ ), f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}"
+
+ # --- OLD IMPLEMENTATION (Slow & Memory Intensive for Large Corpus) ---
+ # rowsum = np.array(adj_mat.sum(1))
+ # d_inv = np.power(rowsum, -0.5).flatten()
+ # d_inv[np.isinf(d_inv)] = 0.
+ # d_mat_inv = sp.diags(d_inv)
+ # norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv)
+ # return norm_adj.tocsr()
+
+ # --- NEW OPTIMIZED IMPLEMENTATION ---
+ """
+ Optimization Strategy: Vectorized Symmetric Normalization (D^-0.5 * A * D^-0.5).
- rowsum = np.array(adj_mat.sum(1))
- d_inv = np.power(rowsum, -0.5).flatten()
- d_inv[np.isinf(d_inv)] = 0.
- d_mat_inv = sp.diags(d_inv)
+ Justification for Amazon Books Scale:
+ 1. Memory Efficiency: sp.diags materializes an explicit diagonal
+ matrix 'd_mat_inv', and each .dot() allocates a large intermediate
+ sparse matrix; for 800k+ nodes these intermediates waste significant RAM.
+ 2. Computational Speed: Traditional matrix-matrix multiplication (.dot)
+ in Scipy sparse has higher overhead compared to row/column-wise scaling.
+ 3. Implementation: We perform element-wise multiplication of the sparse
+ matrix by 1D degree vectors. Multiplying a sparse matrix by a column
+ vector scales rows, and by a row vector scales columns.
- norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv)
+ This results in an identical Laplacian matrix but is calculated in O(E)
+ time with minimal memory footprint, where E is the number of edges.
+ """
+ rowsum = np.array(adj_mat.sum(1)).flatten()
+ d_inv = np.power(rowsum, -0.5)
+ d_inv[np.isinf(d_inv)] = 0.0
+
+ # Scaling rows: multiply by column vector [N, 1]
+ # Scaling columns: multiply by row vector [1, N]
+ norm_adj = adj_mat.multiply(d_inv[:, np.newaxis]).multiply(d_inv)
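+ # Worked example: for A = [[0, 2], [2, 0]], rowsum = [2, 2] and
+ # d_inv = [2**-0.5, 2**-0.5], so each edge weight becomes
+ # (2**-0.5) * 2 * (2**-0.5) = 1, matching D^-0.5 @ A @ D^-0.5.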
+
return norm_adj.tocsr()
@staticmethod
@@ -557,18 +637,52 @@ def _convert_sp_mat_to_sp_tensor(X):
@staticmethod
def _filter_matrix_by_top_k(matrix, k):
- mat = matrix.tolil()
+ # --- OLD IMPLEMENTATION (Extremely slow conversion to LIL for large datasets) ---
+ # mat = matrix.tolil()
+ # for i in range(mat.shape[0]):
+ # if len(mat.rows[i]) <= k:
+ # continue
+ # data = np.array(mat.data[i])
+ # top_k_indices = np.argpartition(data, -k)[-k:]
+ # mat.data[i] = [mat.data[i][j] for j in top_k_indices]
+ # mat.rows[i] = [mat.rows[i][j] for j in top_k_indices]
+ # return mat.tocsr()
+
+ # --- NEW OPTIMIZED IMPLEMENTATION ---
+ """
+ Optimization Strategy: Direct CSR Array Manipulation with NumPy Partitioning.
+
+ Justification for Amazon Books Scale:
+ 1. Avoids LIL conversion: Converting a 450k x 300k matrix to LIL format
+ (List of Lists) is extremely memory-intensive and slow.
+ 2. In-place filtering: By accessing the CSR 'data' and 'indptr' arrays directly,
+ we perform the Top-K filtering with zero additional memory allocation
+ for the matrix structure itself.
+ 3. Algorithmic Speed: np.partition finds the threshold value in O(n) average
+ time. Slicing the underlying NumPy arrays is performed at near-C speed,
+ making this orders of magnitude faster than Python-level list operations.
+ """
+ mat = matrix.tocsr()
for i in range(mat.shape[0]):
- if len(mat.rows[i]) <= k:
- continue
- data = np.array(mat.data[i])
-
- top_k_indices = np.argpartition(data, -k)[-k:]
- mat.data[i] = [mat.data[i][j] for j in top_k_indices]
- mat.rows[i] = [mat.rows[i][j] for j in top_k_indices]
+ start = mat.indptr[i]
+ end = mat.indptr[i + 1]
+
+ # Only process rows that actually exceed the neighborhood size
+ if end - start > k:
+ row_slice = mat.data[start:end]
+
+ # Find the threshold value (the k-th largest element)
+ # np.partition is faster than a full sort: it puts the top-k values at the end
+ threshold = np.partition(row_slice, -k)[-k]
+
+ # Effectively prune edges by zeroing out everything below the threshold
+ # This keeps exactly k (or slightly more if there are ties) elements
+ row_slice[row_slice < threshold] = 0
- return mat.tocsr()
+ # Post-processing: remove the explicitly zeroed elements from the sparse structure
+ mat.eliminate_zeros()
+ return mat
def get_samplers(self):
return self._dataset.get_samplers()
@@ -576,50 +690,52 @@ def get_samplers(self):
@property
def meta(self):
meta = {
- 'user_graph': self._user_graph,
- 'item_graph': self._item_graph,
- 'graph': self._graph,
+ "user_graph": self._user_graph,
+ "item_graph": self._item_graph,
+ "graph": self._graph,
**self._dataset.meta,
}
return meta
-class ScientificDataset(BaseSequenceDataset, config_name='scientific'):
+class ScientificDataset(BaseSequenceDataset, config_name="scientific"):
@classmethod
def create_from_config(cls, config, **kwargs):
data_dir_path = os.path.join(
- config['path_to_data_dir'],
- config['name'],
+ config["path_to_data_dir"],
+ config["name"],
)
- max_sequence_length = config['max_sequence_length']
+ max_sequence_length = config["max_sequence_length"]
- dataset_path = os.path.join(data_dir_path, '{}.txt'.format('all_data'))
- with open(dataset_path, 'r') as f:
+ dataset_path = os.path.join(data_dir_path, "{}.txt".format("all_data"))
+ with open(dataset_path, "r") as f:
lines = f.readlines()
- datasets, max_user_id, max_item_id = cls._parse_and_split_data(lines, max_sequence_length)
+ datasets, max_user_id, max_item_id = cls._parse_and_split_data(
+ lines, max_sequence_length
+ )
- train_dataset = datasets['train']
- validation_dataset = datasets['validation']
- test_dataset = datasets['test']
+ train_dataset = datasets["train"]
+ validation_dataset = datasets["validation"]
+ test_dataset = datasets["test"]
cls._log_stats(
- train_dataset,
- test_dataset,
- max_user_id,
- max_item_id,
- max_sequence_length,
- config['name']
- )
+ train_dataset,
+ test_dataset,
+ max_user_id,
+ max_item_id,
+ max_sequence_length,
+ config["name"],
+ )
train_sampler, validation_sampler, test_sampler = cls._create_samplers(
- config['samplers'],
- train_dataset,
- validation_dataset,
- test_dataset,
- max_user_id,
- max_item_id
+ config["samplers"],
+ train_dataset,
+ validation_dataset,
+ test_dataset,
+ max_user_id,
+ max_item_id,
)
return cls(
@@ -630,71 +746,99 @@ def create_from_config(cls, config, **kwargs):
num_items=max_item_id,
max_sequence_length=max_sequence_length,
)
-
+
@staticmethod
- def _create_samplers(sampler_config, train_dataset, validation_dataset, test_dataset, num_users, num_items):
+ def _create_samplers(
+ sampler_config,
+ train_dataset,
+ validation_dataset,
+ test_dataset,
+ num_users,
+ num_items,
+ ):
train_sampler = TrainSampler.create_from_config(
- sampler_config, dataset=train_dataset, num_users=num_users, num_items=num_items
+ sampler_config,
+ dataset=train_dataset,
+ num_users=num_users,
+ num_items=num_items,
)
validation_sampler = EvalSampler.create_from_config(
- sampler_config, dataset=validation_dataset, num_users=num_users, num_items=num_items
+ sampler_config,
+ dataset=validation_dataset,
+ num_users=num_users,
+ num_items=num_items,
)
test_sampler = EvalSampler.create_from_config(
- sampler_config, dataset=test_dataset, num_users=num_users, num_items=num_items
+ sampler_config,
+ dataset=test_dataset,
+ num_users=num_users,
+ num_items=num_items,
)
return train_sampler, validation_sampler, test_sampler
-
+
@staticmethod
- def _log_stats(train_dataset, test_dataset, max_user_id, max_item_id, max_len, name):
- logger.info('Train dataset size: {}'.format(len(train_dataset)))
- logger.info('Test dataset size: {}'.format(len(test_dataset)))
- logger.info('Max user id: {}'.format(max_user_id))
- logger.info('Max item id: {}'.format(max_item_id))
- logger.info('Max sequence length: {}'.format(max_len))
-
+ def _log_stats(
+ train_dataset, test_dataset, max_user_id, max_item_id, max_len, name
+ ):
+ logger.info("Train dataset size: {}".format(len(train_dataset)))
+ logger.info("Test dataset size: {}".format(len(test_dataset)))
+ logger.info("Max user id: {}".format(max_user_id))
+ logger.info("Max item id: {}".format(max_item_id))
+ logger.info("Max sequence length: {}".format(max_len))
+
if max_user_id > 0 and max_item_id > 0:
- sparsity = (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id
- logger.info('{} dataset sparsity: {}'.format(name, sparsity))
+ sparsity = (
+ (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id
+ )
+ logger.info("{} dataset sparsity: {}".format(name, sparsity))
@staticmethod
def _parse_and_split_data(lines, max_sequence_length):
- datasets = {'train': [], 'validation': [], 'test': []}
+ datasets = {"train": [], "validation": [], "test": []}
- user_ids, item_sequences, max_user_id, max_item_id, _ = \
+ user_ids, item_sequences, max_user_id, max_item_id, _ = (
BaseSequenceDataset._create_sequences(lines)
+ )
for user_id, item_ids in zip(user_ids, item_sequences):
-
+
assert len(item_ids) >= 5
split_slices = {
- 'train': slice(None, -2),
- 'validation': slice(None, -1),
- 'test': slice(None, None)
+ "train": slice(None, -2),
+ "validation": slice(None, -1),
+ "test": slice(None, None),
}
-
+
for part_name, part_slice in split_slices.items():
sliced_items = item_ids[part_slice]
final_items = sliced_items[-max_sequence_length:]
-
- assert len(item_ids[-max_sequence_length:]) == len(set(item_ids[-max_sequence_length:]),)
- datasets[part_name].append({
- 'user.ids': [user_id], 'user.length': 1,
- 'item.ids': final_items, 'item.length': len(final_items),
- })
+ assert len(item_ids[-max_sequence_length:]) == len(
+ set(item_ids[-max_sequence_length:]),
+ )
+
+ datasets[part_name].append(
+ {
+ "user.ids": [user_id],
+ "user.length": 1,
+ "item.ids": final_items,
+ "item.length": len(final_items),
+ }
+ )
return datasets, max_user_id, max_item_id
-class MCLSRDataset(BaseSequenceDataset, config_name='mclsr'):
+
+class MCLSRDataset(BaseSequenceDataset, config_name="mclsr"):
@staticmethod
def _create_sequences_from_file(filepath, max_len=None):
sequences = {}
max_user, max_item = 0, 0
-
- with open(filepath, 'r') as f:
+
+ with open(filepath, "r") as f:
for line in f:
- parts = line.strip().split(' ')
+ parts = line.strip().split(" ")
user_id = int(parts[0])
item_ids = [int(i) for i in parts[1:]]
if max_len:
@@ -704,40 +848,98 @@ def _create_sequences_from_file(filepath, max_len=None):
if item_ids:
max_item = max(max_item, max(item_ids))
return sequences, max_user, max_item
-
+
@classmethod
def _create_evaluation_sets(cls, data_dir, max_seq_len):
- valid_hist, u2, i2 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_history.txt'), max_seq_len)
- valid_trg, u3, i3 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_target.txt'))
+ valid_hist, u2, i2 = cls._create_sequences_from_file(
+ os.path.join(data_dir, "valid_history.txt"), max_seq_len
+ )
+ valid_trg, u3, i3 = cls._create_sequences_from_file(
+ os.path.join(data_dir, "valid_target.txt")
+ )
- validation_dataset = [{'user.ids': [uid], 'history': valid_hist[uid], 'target': valid_trg[uid]} for uid in valid_hist if uid in valid_trg]
-
- test_hist, u4, i4 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_history.txt'), max_seq_len)
- test_trg, u5, i5 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_target.txt'))
+ validation_dataset = [
+ {"user.ids": [uid], "history": valid_hist[uid], "target": valid_trg[uid]}
+ for uid in valid_hist
+ if uid in valid_trg
+ ]
- test_dataset = [{'user.ids': [uid], 'history': test_hist[uid], 'target': test_trg[uid]} for uid in test_hist if uid in test_trg]
+ test_hist, u4, i4 = cls._create_sequences_from_file(
+ os.path.join(data_dir, "test_history.txt"), max_seq_len
+ )
+ test_trg, u5, i5 = cls._create_sequences_from_file(
+ os.path.join(data_dir, "test_target.txt")
+ )
+
+ test_dataset = [
+ {"user.ids": [uid], "history": test_hist[uid], "target": test_trg[uid]}
+ for uid in test_hist
+ if uid in test_trg
+ ]
- return validation_dataset, test_dataset, max(u2, u3, u4, u5), max(i2, i3, i4, i5)
+ return (
+ validation_dataset,
+ test_dataset,
+ max(u2, u3, u4, u5),
+ max(i2, i3, i4, i5),
+ )
@classmethod
def create_from_config(cls, config, **kwargs):
- data_dir = os.path.join(config['path_to_data_dir'], config['name'])
- max_seq_len = config.get('max_sequence_length')
+ data_dir = os.path.join(config["path_to_data_dir"], config["name"])
+ max_seq_len = config.get("max_sequence_length")
- train_sequences, u1, i1 = cls._create_sequences_from_file(os.path.join(data_dir, 'train_mclsr.txt'), max_seq_len)
- train_dataset = [{'user.ids': [uid], 'user.length': 1, 'item.ids': seq, 'item.length': len(seq)} for uid, seq in train_sequences.items()]
+ train_sequences, u1, i1 = cls._create_sequences_from_file(
+ os.path.join(data_dir, "train_mclsr.txt"), max_seq_len
+ )
+ train_dataset = [
+ {
+ "user.ids": [uid],
+ "user.length": 1,
+ "item.ids": seq,
+ "item.length": len(seq),
+ }
+ for uid, seq in train_sequences.items()
+ ]
user_to_all_seen_items = defaultdict(set)
- for sample in train_dataset: user_to_all_seen_items[sample['user.ids'][0]].update(sample['item.ids'])
- kwargs['user_to_all_seen_items'] = user_to_all_seen_items
+ for sample in train_dataset:
+ user_to_all_seen_items[sample["user.ids"][0]].update(sample["item.ids"])
+ kwargs["user_to_all_seen_items"] = user_to_all_seen_items
- validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets(data_dir, max_seq_len)
+ validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets(
+ data_dir, max_seq_len
+ )
num_users = max(u1, u_eval)
num_items = max(i1, i_eval)
-
- train_sampler = TrainSampler.create_from_config(config['samplers'], dataset=train_dataset, num_users=num_users, num_items=num_items, **kwargs)
- validation_sampler = EvalSampler.create_from_config(config['samplers'], dataset=validation_dataset, num_users=num_users, num_items=num_items, **kwargs)
- test_sampler = EvalSampler.create_from_config(config['samplers'], dataset=test_dataset, num_users=num_users, num_items=num_items, **kwargs)
- return cls(train_sampler, validation_sampler, test_sampler, num_users, num_items, max_seq_len)
-
+ train_sampler = TrainSampler.create_from_config(
+ config["samplers"],
+ dataset=train_dataset,
+ num_users=num_users,
+ num_items=num_items,
+ **kwargs,
+ )
+ validation_sampler = EvalSampler.create_from_config(
+ config["samplers"],
+ dataset=validation_dataset,
+ num_users=num_users,
+ num_items=num_items,
+ **kwargs,
+ )
+ test_sampler = EvalSampler.create_from_config(
+ config["samplers"],
+ dataset=test_dataset,
+ num_users=num_users,
+ num_items=num_items,
+ **kwargs,
+ )
+
+ return cls(
+ train_sampler,
+ validation_sampler,
+ test_sampler,
+ num_users,
+ num_items,
+ max_seq_len,
+ )
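
A note on the on-disk contract assumed by `_create_sequences_from_file` above: each line of `train_mclsr.txt` (and the `*_history.txt` / `*_target.txt` files) holds one user as space-separated integer ids. A minimal sketch of that format, with illustrative values:

```python
# Assumed line format: "<user_id> <item_id_1> <item_id_2> ..."
line = "7 12 5 99"
parts = line.strip().split(" ")
user_id, item_ids = int(parts[0]), [int(i) for i in parts[1:]]
assert (user_id, item_ids) == (7, [12, 5, 99])
```
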
diff --git a/src/irec/dataset/negative_samplers/__init__.py b/src/irec/dataset/negative_samplers/__init__.py
index 04b82f9..654f7e1 100644
--- a/src/irec/dataset/negative_samplers/__init__.py
+++ b/src/irec/dataset/negative_samplers/__init__.py
@@ -3,7 +3,7 @@
from .random import RandomNegativeSampler
__all__ = [
- 'BaseNegativeSampler',
- 'PopularNegativeSampler',
- 'RandomNegativeSampler',
+ "BaseNegativeSampler",
+ "PopularNegativeSampler",
+ "RandomNegativeSampler",
]
diff --git a/src/irec/dataset/negative_samplers/base.py b/src/irec/dataset/negative_samplers/base.py
index b4d1224..3360179 100644
--- a/src/irec/dataset/negative_samplers/base.py
+++ b/src/irec/dataset/negative_samplers/base.py
@@ -11,8 +11,8 @@ def __init__(self, dataset, num_users, num_items):
self._seen_items = defaultdict(set)
for sample in self._dataset:
- user_id = sample['user.ids'][0]
- items = list(sample['item.ids'])
+ user_id = sample["user.ids"][0]
+ items = list(sample["item.ids"])
self._seen_items[user_id].update(items)
def generate_negative_samples(self, sample, num_negatives):
diff --git a/src/irec/dataset/negative_samplers/popular.py b/src/irec/dataset/negative_samplers/popular.py
index ca91bf6..f04822e 100644
--- a/src/irec/dataset/negative_samplers/popular.py
+++ b/src/irec/dataset/negative_samplers/popular.py
@@ -1,44 +1,70 @@
+import numpy as np
from irec.dataset.negative_samplers.base import BaseNegativeSampler
-
from collections import Counter
-class PopularNegativeSampler(BaseNegativeSampler, config_name='popular'):
+class PopularNegativeSampler(BaseNegativeSampler, config_name="popular"):
def __init__(self, dataset, num_users, num_items):
super().__init__(
dataset=dataset,
num_users=num_users,
num_items=num_items,
)
-
- self._popular_items = self._items_by_popularity()
+ self._item_ids, self._probs = self._calculate_item_probabilities()
@classmethod
def create_from_config(cls, _, **kwargs):
return cls(
- dataset=kwargs['dataset'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
+ dataset=kwargs["dataset"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
)
- def _items_by_popularity(self):
- popularity = Counter()
-
+ def _calculate_item_probabilities(self):
+ """
+ Calculates sampling probabilities proportional to item popularity.
+ This distribution is required to provide non-zero p_j values for LogQ correction.
+ """
+ counts = Counter()
for sample in self._dataset:
- for item_id in sample['item.ids']:
- popularity[item_id] += 1
+ for item_id in sample["item.ids"]:
+ counts[item_id] += 1
+
+ items = np.array(list(counts.keys()))
+ freqs = np.array(list(counts.values()), dtype=np.float32)
+ probabilities = freqs / freqs.sum()
- popular_items = sorted(popularity, key=popularity.get, reverse=True)
- return popular_items
+ return items, probabilities
def generate_negative_samples(self, sample, num_negatives):
- user_id = sample['user.ids'][0]
- popularity_idx = 0
- negatives = []
+ """
+ Stochastic sampling proportional to popularity.
+
+ Justification:
+ The original implementation always picked the same Top-K popular items.
+ For LogQ correction (Yi et al., Google 2019), we need a stochastic
+ sampling process where p_j > 0 for all items in the distribution.
+ This allows the model to see a diverse set of negatives across epochs
+ while penalizing popular items correctly via the log(p_j) term.
+ """
+ user_id = sample["user.ids"][0]
+ seen = self._seen_items[user_id]
+
+ negatives = set()
while len(negatives) < num_negatives:
- negative_idx = self._popular_items[popularity_idx]
- if negative_idx not in self._seen_items[user_id]:
- negatives.append(negative_idx)
- popularity_idx += 1
+ # Sample items based on the pre-calculated frequency distribution
+ sampled_ids = np.random.choice(
+ self._item_ids,
+ size=num_negatives - len(negatives),
+ p=self._probs,
+ replace=True,
+ )
+
+ # Filter out items already seen by the user (False Negatives)
+ for idx in sampled_ids:
+ if idx not in seen:
+ negatives.add(idx)
+ if len(negatives) == num_negatives:
+ break
- return negatives
+ return list(negatives)
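
The docstrings above motivate sampling with p_j > 0 for every item; the log(p_j) term itself is applied on the loss side. A minimal sketch of that correction (the function and tensor names are illustrative assumptions, not part of this patch):

```python
import torch
import torch.nn.functional as F

def logq_corrected_loss(pos_scores, neg_scores, pos_probs, neg_probs):
    """Sampled softmax with LogQ correction (Yi et al., 2019) -- sketch only.

    Subtracting log(p_j) from each logit cancels the bias introduced by
    drawing negatives proportionally to their popularity.
    """
    pos = pos_scores - torch.log(pos_probs)  # (batch_size,)
    neg = neg_scores - torch.log(neg_probs)  # (batch_size, num_negatives)
    logits = torch.cat([pos.unsqueeze(-1), neg], dim=-1)
    targets = torch.zeros(logits.shape[0], dtype=torch.long)  # positive sits at column 0
    return F.cross_entropy(logits, targets)
```
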
diff --git a/src/irec/dataset/negative_samplers/random.py b/src/irec/dataset/negative_samplers/random.py
index 79e245a..d7365a8 100644
--- a/src/irec/dataset/negative_samplers/random.py
+++ b/src/irec/dataset/negative_samplers/random.py
@@ -3,26 +3,41 @@
import numpy as np
-class RandomNegativeSampler(BaseNegativeSampler, config_name='random'):
+class RandomNegativeSampler(BaseNegativeSampler, config_name="random"):
@classmethod
def create_from_config(cls, _, **kwargs):
return cls(
- dataset=kwargs['dataset'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
+ dataset=kwargs["dataset"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
)
def generate_negative_samples(self, sample, num_negatives):
- user_id = sample['user.ids'][0]
- all_items = list(range(1, self._num_items + 1))
- np.random.shuffle(all_items)
-
- negatives = []
- running_idx = 0
- while len(negatives) < num_negatives and running_idx < len(all_items):
- negative_idx = all_items[running_idx]
- if negative_idx not in self._seen_items[user_id]:
- negatives.append(negative_idx)
- running_idx += 1
-
- return negatives
+ """
+ Optimized via Rejection Sampling (O(k) complexity).
+
+ Mathematical Proof of Equivalence:
+ Let V be the set of all items and H be the user's history.
+ We need a uniform random sample S ⊂ (V \ H) such that |S| = k.
+
+ 1. Shuffle Approach (Previous): Generates a random permutation of V,
+ then filters H. Complexity: O(|V|).
+ 2. Rejection Sampling (Current): Independently draws i ~ Uniform(V)
+ and accepts i if i ∉ H and i ∉ S. Complexity: O(k * 1/p),
+ where p = (|V| - |H|) / |V|.
+
+ Since |H| << |V|, the probability p ≈ 1, making the expected complexity
+ effectively O(k). Both methods yield an identical uniform distribution
+ over the valid item space.
+ """
+ user_id = sample["user.ids"][0]
+ seen = self._seen_items[user_id]
+
+ negatives = set()
+ while len(negatives) < num_negatives:
+ # Drawing a random index is O(1)
+ negative_idx = np.random.randint(1, self._num_items + 1)
+ if negative_idx not in seen:
+ negatives.add(negative_idx)
+
+ return list(negatives)
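
One caveat the proof above glosses over: when |H| approaches |V|, the acceptance probability p tends to zero and the `while` loop can stall. A hedged variant with a safety cap and an explicit fallback (illustrative, not part of this patch):

```python
import numpy as np

def sample_unseen(num_items, seen, k, max_attempts=100_000):
    """Rejection sampling with a fallback for users with dense histories (sketch)."""
    negatives, attempts = set(), 0
    while len(negatives) < k and attempts < max_attempts:
        candidate = int(np.random.randint(1, num_items + 1))  # uniform over [1, num_items]
        if candidate not in seen:
            negatives.add(candidate)
        attempts += 1
    if len(negatives) < k:  # rare: enumerate the unseen pool once
        pool = [i for i in range(1, num_items + 1)
                if i not in seen and i not in negatives]
        for extra in np.random.permutation(pool)[: k - len(negatives)]:
            negatives.add(int(extra))
    return list(negatives)
```
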
diff --git a/src/irec/dataset/samplers/base.py b/src/irec/dataset/samplers/base.py
index 158bede..7017b71 100644
--- a/src/irec/dataset/samplers/base.py
+++ b/src/irec/dataset/samplers/base.py
@@ -33,16 +33,16 @@ def __len__(self):
return len(self._dataset)
def __getitem__(self, index):
- sample = copy.deepcopy(self._dataset[index])
+ sample = self._dataset[index]
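+        # NOTE: no deepcopy here -- the slices below build fresh lists, and the
+        # shared "user.ids" reference is assumed to be read-only downstream.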
- item_sequence = sample['item.ids'][:-1]
- next_item = sample['item.ids'][-1]
+ item_sequence = sample["item.ids"][:-1]
+ next_item = sample["item.ids"][-1]
return {
- 'user.ids': sample['user.ids'],
- 'user.length': sample['user.length'],
- 'item.ids': item_sequence,
- 'item.length': len(item_sequence),
- 'labels.ids': [next_item],
- 'labels.length': 1,
+ "user.ids": sample["user.ids"],
+ "user.length": sample["user.length"],
+ "item.ids": item_sequence,
+ "item.length": len(item_sequence),
+ "labels.ids": [next_item],
+ "labels.length": 1,
}
diff --git a/src/irec/dataset/samplers/mclsr.py b/src/irec/dataset/samplers/mclsr.py
index b957cde..3a9276a 100644
--- a/src/irec/dataset/samplers/mclsr.py
+++ b/src/irec/dataset/samplers/mclsr.py
@@ -4,8 +4,16 @@
import random
-class MCLSRTrainSampler(TrainSampler, config_name='mclsr'):
- def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_negatives, **kwargs):
+class MCLSRTrainSampler(TrainSampler, config_name="mclsr"):
+ def __init__(
+ self,
+ dataset,
+ num_users,
+ num_items,
+ user_to_all_seen_items,
+ num_negatives,
+ **kwargs,
+ ):
super().__init__()
self._dataset = dataset
self._num_users = num_users
@@ -16,65 +24,81 @@ def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_ne
@classmethod
def create_from_config(cls, config, **kwargs):
- num_negatives = config['num_negatives_train']
+ num_negatives = config["num_negatives_train"]
print(num_negatives)
return cls(
- dataset=kwargs['dataset'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
+ dataset=kwargs["dataset"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
num_negatives=num_negatives,
- user_to_all_seen_items=kwargs['user_to_all_seen_items'],
+ user_to_all_seen_items=kwargs["user_to_all_seen_items"],
)
-
def __getitem__(self, index):
sample = self._dataset[index]
- user_id = sample['user.ids'][0]
- item_sequence = sample['item.ids'][:-1]
- positive_item = sample['item.ids'][-1]
+ user_id = sample["user.ids"][0]
+ item_sequence = sample["item.ids"][:-1]
+ positive_item = sample["item.ids"][-1]
user_seen = self._user_to_all_seen_items[user_id]
- unseen_items = list(self._all_items_set - user_seen)
-
- negatives = random.sample(unseen_items, self._num_negatives)
-
+
+ # --- OPTIMIZATION: Rejection Sampling ---
+ # Mathematically equivalent to: random.sample(list(all_items - user_seen), k)
+ # Logic: Instead of allocating a huge list of unseen items (O(N) memory),
+ # we draw random indices until we have the required number of unique negatives.
+ # Since |user_seen| << |num_items|, the collision probability is near zero.
+ negatives = set()
+ while len(negatives) < self._num_negatives:
+ # Draw a random item index from the range [1, num_items]
+ candidate = random.randint(1, self._num_items)
+
+ # Rejection step: Only accept if the user has never interacted with it.
+ # This ensures we only sample from the "unseen" pool.
+ if candidate not in user_seen:
+ negatives.add(candidate)
+
+ # Convert back to list to match the expected format for BatchProcessor
+ negatives = list(negatives)
+ # ----------------------------------------
return {
- 'user.ids': [user_id],
- 'user.length': sample['user.length'],
- 'item.ids': item_sequence,
- 'item.length': len(item_sequence),
- 'labels.ids': [positive_item],
- 'labels.length': 1,
- 'negatives.ids': negatives,
- 'negatives.length': len(negatives),
+ "user.ids": [user_id],
+ "user.length": sample["user.length"],
+ "item.ids": item_sequence,
+ "item.length": len(item_sequence),
+ "labels.ids": [positive_item],
+ "labels.length": 1,
+ "negatives.ids": negatives,
+ "negatives.length": len(negatives),
}
-class MCLSRPredictionEvalSampler(EvalSampler, config_name='mclsr'):
+class MCLSRPredictionEvalSampler(EvalSampler, config_name="mclsr"):
def __init__(self, dataset, num_users, num_items):
super().__init__(dataset, num_users, num_items)
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- dataset=kwargs['dataset'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
+ dataset=kwargs["dataset"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
)
-
+
def __getitem__(self, index):
sample = self._dataset[index]
- history_sequence = sample['history']
- target_items = sample['target']
+ history_sequence = sample["history"]
+ target_items = sample["target"]
return {
- 'user.ids': sample['user.ids'],
- 'user.length': 1,
- 'item.ids': history_sequence,
- 'item.length': len(history_sequence),
- 'labels.ids': target_items,
- 'labels.length': len(target_items),
+ "user.ids": sample["user.ids"],
+ "user.length": 1,
+ "item.ids": history_sequence,
+ "item.length": len(history_sequence),
+ "labels.ids": target_items,
+ "labels.length": len(target_items),
}
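
`MCLSRPredictionEvalSampler` above emits multi-item targets as a flat id list plus per-user lengths; the metrics later reassemble them by walking an offset. A self-contained sketch of that convention, with illustrative values:

```python
import torch

# Two users: the first has 3 target items, the second has 1.
labels_flat = torch.tensor([4, 8, 15, 23])  # concatenated "labels.ids"
labels_lengths = torch.tensor([3, 1])       # "labels.length" per user

offset, per_user = 0, []
for n in labels_lengths.tolist():
    per_user.append(labels_flat[offset:offset + n])
    offset += n
assert [t.tolist() for t in per_user] == [[4, 8, 15], [23]]
```
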
diff --git a/src/irec/dataset/samplers/next_item_prediction.py b/src/irec/dataset/samplers/next_item_prediction.py
index e263062..f757b4b 100644
--- a/src/irec/dataset/samplers/next_item_prediction.py
+++ b/src/irec/dataset/samplers/next_item_prediction.py
@@ -6,7 +6,7 @@
class NextItemPredictionTrainSampler(
TrainSampler,
- config_name='next_item_prediction',
+ config_name="next_item_prediction",
):
def __init__(
self,
@@ -26,61 +26,59 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
negative_sampler = BaseNegativeSampler.create_from_config(
- {'type': config['negative_sampler_type']},
+ {"type": config["negative_sampler_type"]},
**kwargs,
)
return cls(
- dataset=kwargs['dataset'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
+ dataset=kwargs["dataset"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
negative_sampler=negative_sampler,
- num_negatives=config.get('num_negatives_train', 0),
+ num_negatives=config.get("num_negatives_train", 0),
)
def __getitem__(self, index):
- sample = copy.deepcopy(self._dataset[index])
+ sample = self._dataset[index]
- item_sequence = sample['item.ids'][:-1]
- next_item_sequence = sample['item.ids'][1:]
+ item_sequence = sample["item.ids"][:-1]
+ next_item_sequence = sample["item.ids"][1:]
if self._num_negatives == 0:
return {
- 'user.ids': sample['user.ids'],
- 'user.length': sample['user.length'],
- 'item.ids': item_sequence,
- 'item.length': len(item_sequence),
- 'positive.ids': next_item_sequence,
- 'positive.length': len(next_item_sequence),
+ "user.ids": sample["user.ids"],
+ "user.length": sample["user.length"],
+ "item.ids": item_sequence,
+ "item.length": len(item_sequence),
+ "positive.ids": next_item_sequence,
+ "positive.length": len(next_item_sequence),
}
else:
- negative_sequence = (
- self._negative_sampler.generate_negative_samples(
- sample,
- self._num_negatives,
- )
+ negative_sequence = self._negative_sampler.generate_negative_samples(
+ sample,
+ self._num_negatives,
)
return {
- 'user.ids': sample['user.ids'],
- 'user.length': sample['user.length'],
- 'item.ids': item_sequence,
- 'item.length': len(item_sequence),
- 'positive.ids': next_item_sequence,
- 'positive.length': len(next_item_sequence),
- 'negative.ids': negative_sequence,
- 'negative.length': len(negative_sequence),
+ "user.ids": sample["user.ids"],
+ "user.length": sample["user.length"],
+ "item.ids": item_sequence,
+ "item.length": len(item_sequence),
+ "positive.ids": next_item_sequence,
+ "positive.length": len(next_item_sequence),
+ "negative.ids": negative_sequence,
+ "negative.length": len(negative_sequence),
}
class NextItemPredictionEvalSampler(
EvalSampler,
- config_name='next_item_prediction',
+ config_name="next_item_prediction",
):
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- dataset=kwargs['dataset'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
+ dataset=kwargs["dataset"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
)
diff --git a/src/irec/dataset/sasrec.py b/src/irec/dataset/sasrec.py
index 0cba986..0c1dd97 100644
--- a/src/irec/dataset/sasrec.py
+++ b/src/irec/dataset/sasrec.py
@@ -3,35 +3,41 @@
from .base import BaseSequenceDataset, SequenceDataset, MCLSRDataset
from .samplers import TrainSampler, EvalSampler
-class SASRecDataset(BaseSequenceDataset, config_name='sasrec_comparison'):
+
+class SASRecDataset(BaseSequenceDataset, config_name="sasrec_comparison"):
@classmethod
def create_from_config(cls, config, **kwargs):
- data_dir = os.path.join(config['path_to_data_dir'], config['name'])
- max_seq_len = config.get('max_sequence_length')
+ data_dir = os.path.join(config["path_to_data_dir"], config["name"])
+ max_seq_len = config.get("max_sequence_length")
train_dataset, u1, i1, _ = SequenceDataset._create_dataset(
- dir_path=data_dir,
- part='train_sasrec',
- max_sequence_length=max_seq_len
+ dir_path=data_dir, part="train_sasrec", max_sequence_length=max_seq_len
)
- validation_dataset, test_dataset, u_eval, i_eval = \
+ validation_dataset, test_dataset, u_eval, i_eval = (
MCLSRDataset._create_evaluation_sets(data_dir, max_seq_len)
+ )
num_users = max(u1, u_eval)
num_items = max(i1, i_eval)
-
+
train_sampler = TrainSampler.create_from_config(
- config['train_sampler'],
- dataset=train_dataset, num_users=num_users, num_items=num_items
+ config["train_sampler"],
+ dataset=train_dataset,
+ num_users=num_users,
+ num_items=num_items,
)
validation_sampler = EvalSampler.create_from_config(
- config['eval_sampler'],
- dataset=validation_dataset, num_users=num_users, num_items=num_items
+ config["eval_sampler"],
+ dataset=validation_dataset,
+ num_users=num_users,
+ num_items=num_items,
)
test_sampler = EvalSampler.create_from_config(
- config['eval_sampler'],
- dataset=test_dataset, num_users=num_users, num_items=num_items
+ config["eval_sampler"],
+ dataset=test_dataset,
+ num_users=num_users,
+ num_items=num_items,
)
return cls(
@@ -40,5 +46,5 @@ def create_from_config(cls, config, **kwargs):
test_sampler=test_sampler,
num_users=num_users,
num_items=num_items,
- max_sequence_length=max_seq_len
+ max_sequence_length=max_seq_len,
)
diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py
index 10f689e..f8a0762 100644
--- a/src/irec/loss/base.py
+++ b/src/irec/loss/base.py
@@ -1,12 +1,14 @@
-import copy
-
from irec.utils import (
MetaParent,
maybe_to_list,
)
+import copy
import torch
import torch.nn as nn
+import pickle
+import os
+import logging
class BaseLoss(metaclass=MetaParent):
@@ -17,12 +19,12 @@ class TorchLoss(BaseLoss, nn.Module):
pass
-class IdentityLoss(BaseLoss, config_name='identity'):
+class IdentityLoss(BaseLoss, config_name="identity"):
def __call__(self, inputs):
return inputs
-class CompositeLoss(TorchLoss, config_name='composite'):
+class CompositeLoss(TorchLoss, config_name="composite"):
def __init__(self, losses, weights=None, output_prefix=None):
super().__init__()
self._losses = losses
@@ -34,8 +36,8 @@ def create_from_config(cls, config, **kwargs):
losses = []
weights = []
- for loss_cfg in copy.deepcopy(config)['losses']:
- weight = loss_cfg.pop('weight') if 'weight' in loss_cfg else 1.0
+ for loss_cfg in copy.deepcopy(config)["losses"]:
+ weight = loss_cfg.pop("weight") if "weight" in loss_cfg else 1.0
loss_function = BaseLoss.create_from_config(loss_cfg)
weights.append(weight)
@@ -44,7 +46,7 @@ def create_from_config(cls, config, **kwargs):
return cls(
losses=losses,
weights=weights,
- output_prefix=config.get('output_prefix'),
+ output_prefix=config.get("output_prefix"),
)
def forward(self, inputs):
@@ -58,7 +60,7 @@ def forward(self, inputs):
return total_loss
-class FpsLoss(TorchLoss, config_name='fps'):
+class FpsLoss(TorchLoss, config_name="fps"):
def __init__(
self,
fst_embeddings_prefix,
@@ -73,7 +75,7 @@ def __init__(
self._snd_embeddings_prefix = snd_embeddings_prefix
self._tau = tau
self._loss_function = nn.CrossEntropyLoss(
- reduction='mean' if use_mean else 'sum',
+ reduction="mean" if use_mean else "sum",
)
self._normalize_embeddings = normalize_embeddings
self._output_prefix = output_prefix
@@ -82,21 +84,17 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- fst_embeddings_prefix=config['fst_embeddings_prefix'],
- snd_embeddings_prefix=config['snd_embeddings_prefix'],
- tau=config.get('temperature', 1.0),
- normalize_embeddings=config.get('normalize_embeddings', False),
- use_mean=config.get('use_mean', True),
- output_prefix=config.get('output_prefix')
+ fst_embeddings_prefix=config["fst_embeddings_prefix"],
+ snd_embeddings_prefix=config["snd_embeddings_prefix"],
+ tau=config.get("temperature", 1.0),
+ normalize_embeddings=config.get("normalize_embeddings", False),
+ use_mean=config.get("use_mean", True),
+ output_prefix=config.get("output_prefix"),
)
def forward(self, inputs):
- fst_embeddings = inputs[
- self._fst_embeddings_prefix
- ] # (x, embedding_dim)
- snd_embeddings = inputs[
- self._snd_embeddings_prefix
- ] # (x, embedding_dim)
+ fst_embeddings = inputs[self._fst_embeddings_prefix] # (x, embedding_dim)
+ snd_embeddings = inputs[self._snd_embeddings_prefix] # (x, embedding_dim)
batch_size = fst_embeddings.shape[0]
@@ -123,7 +121,9 @@ def forward(self, inputs):
torch.diag(similarity_scores, -batch_size),
),
dim=0,
- ).reshape(2 * batch_size, 1) # (2 * x, 1)
+ ).reshape(
+ 2 * batch_size, 1
+ ) # (2 * x, 1)
assert torch.allclose(
torch.diag(similarity_scores, batch_size),
torch.diag(similarity_scores, -batch_size),
@@ -160,14 +160,9 @@ def forward(self, inputs):
return loss
-class SASRecLoss(TorchLoss, config_name='sasrec'):
+class SASRecLoss(TorchLoss, config_name="sasrec"):
- def __init__(
- self,
- positive_prefix,
- negative_prefix,
- output_prefix=None
- ):
+ def __init__(self, positive_prefix, negative_prefix, output_prefix=None):
super().__init__()
self._positive_prefix = positive_prefix
self._negative_prefix = negative_prefix
@@ -190,7 +185,7 @@ def forward(self, inputs):
return loss
-class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'):
+class SamplesSoftmaxLoss(TorchLoss, config_name="sampled_softmax"):
def __init__(
self,
queries_prefix,
@@ -205,9 +200,7 @@ def __init__(
self._output_prefix = output_prefix
def forward(self, inputs):
- queries_embeddings = inputs[
- self._queries_prefix
- ] # (batch_size, embedding_dim)
+ queries_embeddings = inputs[self._queries_prefix] # (batch_size, embedding_dim)
positive_embeddings = inputs[
self._positive_prefix
] # (batch_size, embedding_dim)
@@ -217,15 +210,17 @@ def forward(self, inputs):
# b -- batch_size, d -- embedding_dim
positive_scores = torch.einsum(
- 'bd,bd->b',
+ "bd,bd->b",
queries_embeddings,
positive_embeddings,
- ).unsqueeze(-1) # (batch_size, 1)
+ ).unsqueeze(
+ -1
+ ) # (batch_size, 1)
if negative_embeddings.dim() == 2: # (num_negatives, embedding_dim)
# b -- batch_size, n -- num_negatives, d -- embedding_dim
negative_scores = torch.einsum(
- 'bd,nd->bn',
+ "bd,nd->bn",
queries_embeddings,
negative_embeddings,
) # (batch_size, num_negatives)
@@ -235,7 +230,7 @@ def forward(self, inputs):
) # (batch_size, num_negatives, embedding_dim)
# b -- batch_size, n -- num_negatives, d -- embedding_dim
negative_scores = torch.einsum(
- 'bd,bnd->bn',
+ "bd,bnd->bn",
queries_embeddings,
negative_embeddings,
) # (batch_size, num_negatives)
@@ -257,7 +252,7 @@ def forward(self, inputs):
return loss
-class MCLSRLoss(TorchLoss, config_name='mclsr'):
+class MCLSRLoss(TorchLoss, config_name="mclsr"):
def __init__(
self,
all_scores_prefix,
@@ -294,12 +289,11 @@ def forward(self, inputs):
assert torch.allclose(all_scores[0, 0], positive_scores[0])
assert torch.allclose(all_scores[-1, -1], positive_scores[-1])
- # Maybe try mean over sequence TODO
loss = torch.sum(
torch.log(
torch.sigmoid(positive_scores.unsqueeze(1) - negative_scores),
),
- ) # (1)
+ )
if self._output_prefix is not None:
inputs[self._output_prefix] = loss.cpu().item()
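
`SamplesSoftmaxLoss` above picks one of two einsum paths depending on whether negatives are shared across the batch (2-D) or sampled per user (3-D). A compact shape check with random stand-in tensors:

```python
import torch

b, n, d = 4, 6, 8                     # batch size, num negatives, embedding dim
queries = torch.randn(b, d)
shared_negs = torch.randn(n, d)       # one negative pool for the whole batch
per_user_negs = torch.randn(b, n, d)  # a separate pool per user

scores_shared = torch.einsum("bd,nd->bn", queries, shared_negs)
scores_per_user = torch.einsum("bd,bnd->bn", queries, per_user_negs)
assert scores_shared.shape == scores_per_user.shape == (b, n)
```
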
diff --git a/src/irec/metric/base.py b/src/irec/metric/base.py
index da68965..87c0190 100644
--- a/src/irec/metric/base.py
+++ b/src/irec/metric/base.py
@@ -12,7 +12,7 @@ def reduce(self):
raise NotImplementedError
-class StaticMetric(BaseMetric, config_name='dummy'):
+class StaticMetric(BaseMetric, config_name="dummy"):
def __init__(self, name, value):
self._name = name
self._value = value
@@ -23,16 +23,14 @@ def __call__(self, inputs):
return inputs
-class CompositeMetric(BaseMetric, config_name='composite'):
+class CompositeMetric(BaseMetric, config_name="composite"):
def __init__(self, metrics):
self._metrics = metrics
@classmethod
def create_from_config(cls, config):
return cls(
- metrics=[
- BaseMetric.create_from_config(cfg) for cfg in config['metrics']
- ],
+ metrics=[BaseMetric.create_from_config(cfg) for cfg in config["metrics"]],
)
def __call__(self, inputs):
@@ -41,7 +39,7 @@ def __call__(self, inputs):
return inputs
-class NDCGMetric(BaseMetric, config_name='ndcg'):
+class NDCGMetric(BaseMetric, config_name="ndcg"):
def __init__(self, k):
self._k = k
@@ -50,9 +48,8 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
:,
: self._k,
].float() # (batch_size, top_k_indices)
- labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size)
+ labels = inputs["{}.ids".format(labels_prefix)].float() # (batch_size)
-
assert labels.shape[0] == predictions.shape[0]
hits = torch.eq(
@@ -61,13 +58,15 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
).float() # (batch_size, top_k_indices)
discount_factor = 1 / torch.log2(
torch.arange(1, self._k + 1, 1).float() + 1.0,
- ).to(hits.device) # (k)
- dcg = torch.einsum('bk,k->b', hits, discount_factor) # (batch_size)
+ ).to(
+ hits.device
+ ) # (k)
+ dcg = torch.einsum("bk,k->b", hits, discount_factor) # (batch_size)
return dcg.cpu().tolist()
-class RecallMetric(BaseMetric, config_name='recall'):
+class RecallMetric(BaseMetric, config_name="recall"):
def __init__(self, k):
self._k = k
@@ -76,7 +75,7 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
:,
: self._k,
].float() # (batch_size, top_k_indices)
- labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size)
+ labels = inputs["{}.ids".format(labels_prefix)].float() # (batch_size)
assert labels.shape[0] == predictions.shape[0]
@@ -89,35 +88,34 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
return recall.cpu().tolist()
-class CoverageMetric(StatefullMetric, config_name='coverage'):
+class CoverageMetric(StatefullMetric, config_name="coverage"):
def __init__(self, k, num_items):
self._k = k
self._num_items = num_items
@classmethod
def create_from_config(cls, config, **kwargs):
- return cls(k=config['k'], num_items=kwargs['num_items'])
+ return cls(k=config["k"], num_items=kwargs["num_items"])
def __call__(self, inputs, pred_prefix, labels_prefix):
predictions = inputs[pred_prefix][
:,
: self._k,
].float() # (batch_size, top_k_indices)
- return (
- predictions.view(-1).long().cpu().detach().tolist()
- ) # (batch_size * k)
+ return predictions.view(-1).long().cpu().detach().tolist() # (batch_size * k)
def reduce(self, values):
return len(set(values)) / self._num_items
-class MCLSRNDCGMetric(BaseMetric, config_name='mclsr-ndcg'):
+
+class MCLSRNDCGMetric(BaseMetric, config_name="mclsr-ndcg"):
def __init__(self, k):
self._k = k
def __call__(self, inputs, pred_prefix, labels_prefix):
- predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k)
- labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,)
- labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,)
+ predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k)
+ labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,)
+ labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,)
assert predictions.shape[0] == labels_lengths.shape[0]
@@ -129,30 +127,30 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
user_labels = labels_flat[offset : offset + num_user_labels]
offset += num_user_labels
- hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False
-
+ hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False
+
positions = torch.arange(2, self._k + 2, device=predictions.device)
weights = 1 / torch.log2(positions.float())
dcg = (hits_mask.float() * weights).sum()
-
+
num_ideal_hits = min(self._k, num_user_labels)
idcg_weights = weights[:num_ideal_hits]
idcg = idcg_weights.sum()
-
+
ndcg = dcg / idcg if idcg > 0 else torch.tensor(0.0)
dcg_scores.append(ndcg.cpu().item())
-
+
return dcg_scores
-class MCLSRRecallMetric(BaseMetric, config_name='mclsr-recall'):
+class MCLSRRecallMetric(BaseMetric, config_name="mclsr-recall"):
def __init__(self, k):
self._k = k
def __call__(self, inputs, pred_prefix, labels_prefix):
- predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k)
- labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,)
- labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,)
+ predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k)
+ labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,)
+ labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,)
assert predictions.shape[0] == labels_lengths.shape[0]
@@ -163,22 +161,25 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
num_user_labels = labels_lengths[i]
user_labels = labels_flat[offset : offset + num_user_labels]
offset += num_user_labels
-
+
hits = torch.isin(user_predictions, user_labels).sum().float()
-
- recall = hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0)
+
+ recall = (
+ hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0)
+ )
recall_scores.append(recall.cpu().item())
-
+
return recall_scores
-class MCLSRHitRateMetric(BaseMetric, config_name='mclsr-hit'):
+
+class MCLSRHitRateMetric(BaseMetric, config_name="mclsr-hit"):
def __init__(self, k):
self._k = k
def __call__(self, inputs, pred_prefix, labels_prefix):
- predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k)
- labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,)
- labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,)
+ predictions = inputs[pred_prefix][:, : self._k] # (batch_size, k)
+ labels_flat = inputs[f"{labels_prefix}.ids"] # (total_labels,)
+ labels_lengths = inputs[f"{labels_prefix}.length"] # (batch_size,)
assert predictions.shape[0] == labels_lengths.shape[0]
@@ -194,9 +195,9 @@ def __call__(self, inputs, pred_prefix, labels_prefix):
user_labels = labels_flat[offset : offset + num_user_labels]
offset += num_user_labels
-
+
is_hit = torch.isin(user_predictions, user_labels).any()
-
+
hit_scores.append(float(is_hit))
-
- return hit_scores
\ No newline at end of file
+
+ return hit_scores
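
As a sanity check on `MCLSRNDCGMetric`, a worked k = 3 example that mirrors its hits/weights computation (item ids are arbitrary):

```python
import torch

k = 3
predictions = torch.tensor([10, 20, 30])  # top-k predictions for one user
labels = torch.tensor([20, 99])           # two relevant items

hits = torch.isin(predictions, labels).float()            # [0., 1., 0.]
weights = 1 / torch.log2(torch.arange(2, k + 2).float())  # [1.0, 0.631, 0.5]
dcg = (hits * weights).sum()                              # ~0.631
idcg = weights[: min(k, len(labels))].sum()               # ~1.631
ndcg = dcg / idcg                                         # ~0.387
```
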
diff --git a/src/irec/models/__init__.py b/src/irec/models/__init__.py
index 351a048..761e479 100644
--- a/src/irec/models/__init__.py
+++ b/src/irec/models/__init__.py
@@ -1,13 +1,9 @@
from .base import BaseModel, SequentialTorchModel
from .mclsr import MCLSRModel
from .sasrec import SasRecModel, SasRecInBatchModel
-from .sasrec_ce import SasRecCeModel
__all__ = [
'BaseModel',
'MCLSRModel',
'SasRecModel',
- 'SasRecInBatchModel',
- 'SasRecCeModel',
- 'SasRecRealModel',
]
diff --git a/src/irec/models/base.py b/src/irec/models/base.py
index 2f059f7..22555f4 100644
--- a/src/irec/models/base.py
+++ b/src/irec/models/base.py
@@ -21,8 +21,8 @@ class TorchModel(nn.Module, BaseModel):
@torch.no_grad()
def _init_weights(self, initializer_range):
for key, value in self.named_parameters():
- if 'weight' in key:
- if 'norm' in key:
+ if "weight" in key:
+ if "norm" in key:
nn.init.ones_(value.data)
else:
nn.init.trunc_normal_(
@@ -31,10 +31,10 @@ def _init_weights(self, initializer_range):
a=-2 * initializer_range,
b=2 * initializer_range,
)
- elif 'bias' in key:
+ elif "bias" in key:
nn.init.zeros_(value.data)
else:
- raise ValueError(f'Unknown transformer weight: {key}')
+ raise ValueError(f"Unknown transformer weight: {key}")
@staticmethod
def _get_last_embedding(embeddings, mask):
@@ -54,12 +54,12 @@ def _get_last_embedding(embeddings, mask):
) # (batch_size, 1, emb_dim)
last_embeddings = last_embeddings[last_masks] # (batch_size, emb_dim)
if not torch.allclose(embeddings[mask][-1], last_embeddings[-1]):
- logger.debug(f'Embeddings: {embeddings}')
+ logger.debug(f"Embeddings: {embeddings}")
logger.debug(
- f'Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}',
+ f"Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}",
)
- logger.debug(f'Last embedding from mask: {embeddings[mask][-1]}')
- logger.debug(f'Last embedding from gather: {last_embeddings[-1]}')
+ logger.debug(f"Last embedding from mask: {embeddings[mask][-1]}")
+ logger.debug(f"Last embedding from gather: {last_embeddings[-1]}")
assert False
return last_embeddings
@@ -74,7 +74,7 @@ def __init__(
num_layers,
dim_feedforward,
dropout=0.0,
- activation='relu',
+ activation="relu",
layer_norm_eps=1e-5,
is_causal=True,
):
@@ -85,8 +85,7 @@ def __init__(
self._embedding_dim = embedding_dim
self._item_embeddings = nn.Embedding(
- num_embeddings=num_items
- + 2, # add zero embedding + mask embedding
+ num_embeddings=num_items + 2, # add zero embedding + mask embedding
embedding_dim=embedding_dim,
)
self._position_embeddings = nn.Embedding(
@@ -135,9 +134,7 @@ def _apply_sequential_encoder(self, events, lengths, add_cls_token=False):
.tile([batch_size, 1])
.long()
) # (batch_size, seq_len)
- positions_mask = (
- positions < lengths[:, None]
- ) # (batch_size, max_seq_len)
+ positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len)
positions = positions[positions_mask] # (all_batch_events)
position_embeddings = self._position_embeddings(
@@ -215,7 +212,9 @@ def _add_cls_token(items, lengths, cls_token_id=0):
dim=0,
index=torch.cat(
[torch.LongTensor([0]).to(DEVICE), lengths + 1],
- ).cumsum(dim=0)[:-1],
+ ).cumsum(
+ dim=0
+ )[:-1],
) # (num_new_items)
new_items[old_items_mask] = items
new_length = lengths + 1
diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py
index fc8c6ef..2de98f2 100644
--- a/src/irec/models/mclsr.py
+++ b/src/irec/models/mclsr.py
@@ -5,8 +5,10 @@
from irec.utils import create_masked_tensor
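+# Assumption: training batches have mostly static shapes, so letting cuDNN
+# benchmark kernels once and cache the fastest is a net win; with highly
+# variable shapes this flag can slow training down instead.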
+torch.backends.cudnn.benchmark = True
-class MCLSRModel(TorchModel, config_name='mclsr'):
+
+class MCLSRModel(TorchModel, config_name="mclsr"):
def __init__(
self,
sequence_prefix,
@@ -50,8 +52,7 @@ def __init__(
self._item_graph = item_graph
self._item_embeddings = nn.Embedding(
- num_embeddings=num_items
- + 2, # add zero embedding + mask embedding
+ num_embeddings=num_items + 2, # add zero embedding + mask embedding
embedding_dim=embedding_dim,
)
self._position_embeddings = nn.Embedding(
@@ -61,8 +62,7 @@ def __init__(
)
self._user_embeddings = nn.Embedding(
- num_embeddings=num_users
- + 2, # add zero embedding + mask embedding
+ num_embeddings=num_users + 2, # add zero embedding + mask embedding
embedding_dim=embedding_dim,
)
@@ -155,23 +155,23 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- sequence_prefix=config['sequence_prefix'],
- user_prefix=config['user_prefix'],
- labels_prefix=config['labels_prefix'],
- negatives_prefix=config.get('negatives_prefix', 'negatives'),
- candidate_prefix=config['candidate_prefix'],
- num_users=kwargs['num_users'],
- num_items=kwargs['num_items'],
- max_sequence_length=kwargs['max_sequence_length'],
- embedding_dim=config['embedding_dim'],
- num_graph_layers=config['num_graph_layers'],
- common_graph=kwargs['graph'],
- user_graph=kwargs['user_graph'],
- item_graph=kwargs['item_graph'],
- dropout=config.get('dropout', 0.0),
- layer_norm_eps=config.get('layer_norm_eps', 1e-5),
- graph_dropout=config.get('graph_dropout', 0.0),
- initializer_range=config.get('initializer_range', 0.02),
+ sequence_prefix=config["sequence_prefix"],
+ user_prefix=config["user_prefix"],
+ labels_prefix=config["labels_prefix"],
+ negatives_prefix=config.get("negatives_prefix", "negatives"),
+ candidate_prefix=config["candidate_prefix"],
+ num_users=kwargs["num_users"],
+ num_items=kwargs["num_items"],
+ max_sequence_length=kwargs["max_sequence_length"],
+ embedding_dim=config["embedding_dim"],
+ num_graph_layers=config["num_graph_layers"],
+ common_graph=kwargs["graph"],
+ user_graph=kwargs["user_graph"],
+ item_graph=kwargs["item_graph"],
+ dropout=config.get("dropout", 0.0),
+ layer_norm_eps=config.get("layer_norm_eps", 1e-5),
+ graph_dropout=config.get("graph_dropout", 0.0),
+ initializer_range=config.get("initializer_range", 0.02),
)
def _apply_graph_encoder(self, embeddings, graph, use_mean=False):
@@ -200,12 +200,12 @@ def _apply_graph_encoder(self, embeddings, graph, use_mean=False):
def forward(self, inputs):
all_sample_events = inputs[
- '{}.ids'.format(self._sequence_prefix)
+ "{}.ids".format(self._sequence_prefix)
] # (all_batch_events)
all_sample_lengths = inputs[
- '{}.length'.format(self._sequence_prefix)
+ "{}.length".format(self._sequence_prefix)
] # (batch_size)
- user_ids = inputs['{}.ids'.format(self._user_prefix)] # (batch_size)
+ user_ids = inputs["{}.ids".format(self._user_prefix)] # (batch_size)
embeddings = self._item_embeddings(
all_sample_events,
@@ -243,11 +243,11 @@ def forward(self, inputs):
lengths=all_sample_lengths,
) # (batch_size, seq_len, embedding_dim)
assert torch.allclose(position_embeddings[~mask], embeddings[~mask])
-
+
positioned_embeddings = (
embeddings + position_embeddings
) # (batch_size, seq_len, embedding_dim)
-
+
positioned_embeddings = self._layernorm(
positioned_embeddings,
) # (batch_size, seq_len, embedding_dim)
@@ -258,7 +258,7 @@ def forward(self, inputs):
# formula 2
sequential_attention_matrix = self._current_interest_learning_encoder(
- positioned_embeddings, # E_u,p
+ positioned_embeddings, # E_u,p
).squeeze() # (batch_size, seq_len)
sequential_attention_matrix[~mask] = -torch.inf
@@ -269,134 +269,205 @@ def forward(self, inputs):
# formula 3
sequential_representation = torch.einsum(
- 'bs,bsd->bd',
- sequential_attention_matrix, # A^s
+ "bs,bsd->bd",
+ sequential_attention_matrix, # A^s
embeddings,
) # (batch_size, embedding_dim)
if self.training:
# general interest
# formula 4
- all_init_embeddings = torch.cat([self._user_embeddings.weight,
- self._item_embeddings.weight],
- dim=0)
- all_graph_embeddings = self._apply_graph_encoder(embeddings=all_init_embeddings,
- graph=self._graph)
+ all_init_embeddings = torch.cat(
+ [self._user_embeddings.weight, self._item_embeddings.weight], dim=0
+ )
+ all_graph_embeddings = self._apply_graph_encoder(
+ embeddings=all_init_embeddings, graph=self._graph
+ )
common_graph_user_embs_all, common_graph_item_embs_all = torch.split(
all_graph_embeddings, [self._num_users + 2, self._num_items + 2]
)
common_graph_user_embs_batch = common_graph_user_embs_all[user_ids]
common_graph_item_embs_batch, _ = create_masked_tensor(
- data=common_graph_item_embs_all[all_sample_events],
- lengths=all_sample_lengths
+ data=common_graph_item_embs_all[all_sample_events],
+ lengths=all_sample_lengths,
)
# formula 5: A_c = softmax(tanh(W_3 * h_u,uv) * (E_u,uv)^T)
- graph_attention_matrix = torch.einsum('bd,bsd->bs',
- self._general_interest_learning_encoder
- (common_graph_user_embs_batch),
- common_graph_item_embs_batch)
+ graph_attention_matrix = torch.einsum(
+ "bd,bsd->bs",
+ self._general_interest_learning_encoder(common_graph_user_embs_batch),
+ common_graph_item_embs_batch,
+ )
graph_attention_matrix[~mask] = -torch.inf
graph_attention_matrix = torch.softmax(graph_attention_matrix, dim=1)
# formula 6: I_c = A_c * E_u,uv
- original_graph_representation = torch.einsum('bs,bsd->bd',
- graph_attention_matrix,
- common_graph_item_embs_batch)
+ original_graph_representation = torch.einsum(
+ "bs,bsd->bd", graph_attention_matrix, common_graph_item_embs_batch
+ )
original_sequential_representation = sequential_representation
# formula 13: I_comb = alpha * I_s + (1 - alpha) * I_c
# L_P (Downstream Loss)
- combined_representation = (self._alpha * original_sequential_representation +
- (1 - self._alpha) * original_graph_representation)
- labels = inputs['{}.ids'.format(self._labels_prefix)]
+ combined_representation = (
+ self._alpha * original_sequential_representation
+ + (1 - self._alpha) * original_graph_representation
+ )
+ labels = inputs["{}.ids".format(self._labels_prefix)]
labels_embeddings = self._item_embeddings(labels)
-
+
# formula 7
# L_IL (Interest-level CL)
- sequential_representation_proj = self._sequential_projector(original_sequential_representation)
- graph_representation_proj = self._graph_projector(original_graph_representation)
+ sequential_representation_proj = self._sequential_projector(
+ original_sequential_representation
+ )
+ graph_representation_proj = self._graph_projector(
+ original_graph_representation
+ )
# formula 9: H_u,uu = GraphEncoder(H_u, G_uu)
# L_UC (User-level CL)
- user_graph_user_embs_all = self._apply_graph_encoder(embeddings=self._user_embeddings.weight,
- graph=self._user_graph)
+ user_graph_user_embs_all = self._apply_graph_encoder(
+ embeddings=self._user_embeddings.weight, graph=self._user_graph
+ )
user_graph_user_embs_batch = user_graph_user_embs_all[user_ids]
# formula 10
            # T_f,uu = MLP(H_u,uu) and T_f,uv = MLP(H_u,uv)
- user_graph_user_embeddings_proj = self._user_projection(user_graph_user_embs_batch)
- common_graph_user_embeddings_proj = self._user_projection(common_graph_user_embs_batch)
+ user_graph_user_embeddings_proj = self._user_projection(
+ user_graph_user_embs_batch
+ )
+ common_graph_user_embeddings_proj = self._user_projection(
+ common_graph_user_embs_batch
+ )
# item level CL
common_graph_items_flat = common_graph_item_embs_batch[mask]
-
- item_graph_items_all = self._apply_graph_encoder(embeddings=self._item_embeddings.weight,
- graph=self._item_graph)
+
+ item_graph_items_all = self._apply_graph_encoder(
+ embeddings=self._item_embeddings.weight, graph=self._item_graph
+ )
item_graph_items_flat = item_graph_items_all[all_sample_events]
- unique_item_ids, inverse_indices = torch.unique(all_sample_events,
- return_inverse=True)
+ unique_item_ids, inverse_indices = torch.unique(
+ all_sample_events, return_inverse=True
+ )
+
+ # --- OPTIMIZED AGGREGATION: scatter_mean ---
+ # We use scatter_mean to aggregate features of unique items within a batch
+ # for Item-level Feature Contrastive Learning.
+ #
+ # Performance Note:
+ # We prioritize the 'torch-scatter' library because it provides highly optimized
+ # C++/CUDA kernels that perform in-place aggregation.
+ #
+ # The 'except ImportError' fallback is provided for environment compatibility,
+ # but it is NOT recommended for large-scale datasets like Amazon Books.
+ # The fallback implementation uses 'expand_as', which creates massive temporary
+ # tensors in GPU memory, potentially leading to Out-Of-Memory (OOM) errors
+ # when processing millions of interaction events.
try:
from torch_scatter import scatter_mean
except ImportError:
- # print("Warning: torch_scatter not found. Using a slower fallback function.")
+
def scatter_mean(src, index, dim=0, dim_size=None):
out_size = dim_size if dim_size is not None else index.max() + 1
- out = torch.zeros((out_size, src.size(1)), dtype=src.dtype, device=src.device)
- counts = torch.bincount(index, minlength=out_size).unsqueeze(-1).clamp(min=1)
- return out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts
-
+ out = torch.zeros(
+ (out_size, src.size(1)), dtype=src.dtype, device=src.device
+ )
+ counts = (
+ torch.bincount(index, minlength=out_size)
+ .unsqueeze(-1)
+ .clamp(min=1)
+ )
+ # WARNING: .expand_as() below is a memory bottleneck for large tensors
+ return (
+ out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src)
+ / counts
+ )
+
num_unique_items = unique_item_ids.shape[0]
- unique_common_graph_items = scatter_mean(common_graph_items_flat,
- inverse_indices, dim=0,
- dim_size=num_unique_items)
+ unique_common_graph_items = scatter_mean(
+ common_graph_items_flat,
+ inverse_indices,
+ dim=0,
+ dim_size=num_unique_items,
+ )
- unique_item_graph_items = scatter_mean(item_graph_items_flat,
- inverse_indices, dim=0,
- dim_size=num_unique_items)
+ unique_item_graph_items = scatter_mean(
+ item_graph_items_flat, inverse_indices, dim=0, dim_size=num_unique_items
+ )
# projection for Item-level Feature CL
- unique_common_graph_items_proj = self._item_projection(unique_common_graph_items)
- unique_item_graph_items_proj = self._item_projection(unique_item_graph_items)
+ unique_common_graph_items_proj = self._item_projection(
+ unique_common_graph_items
+ )
+ unique_item_graph_items_proj = self._item_projection(
+ unique_item_graph_items
+ )
- negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives)
- negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim)
+
+ raw_negative_ids = inputs["{}.ids".format(self._negatives_prefix)]
+ num_negatives = raw_negative_ids.shape[0] // batch_size
+ negative_ids = raw_negative_ids.view(
+ batch_size, num_negatives
+ ) # (Batch, NumNegs)
+ negative_embeddings = self._item_embeddings(
+ negative_ids
+ ) # (Batch, NumNegs, Dim)
# import code; code.interact(local=locals())
return {
# L_P (formula 14)
- 'combined_representation': combined_representation,
- 'label_representation': labels_embeddings,
-
- 'negative_representation': negative_embeddings,
-
+ "combined_representation": combined_representation,
+ "label_representation": labels_embeddings,
+ "negative_representation": negative_embeddings,
+ # --- ID PASS-THROUGH FOR LOGQ & MASKING ---
+ # We pass raw item and user indices to enable advanced loss operations:
+ # 1. False Negative Masking: Allows SamplesSoftmaxLoss to identify and
+ # neutralize cases where a target item accidentally appears in the
+ # negative sampling pool.
+ # 2. Per-item LogQ Correction: Enables mapping item IDs to global frequency
+ # stats (item_counts.pkl) to remove popularity bias (Sampling Bias).
+ "positive_ids": labels, # Item target indices
+ "negative_ids": negative_ids, # Sampled negative item indices
+            # Useful for a potential user-level LogQ correction to handle
+            # highly active users.
+ "user_ids": user_ids,
# for L_IL (formula 8)
-
- 'sequential_representation': sequential_representation_proj,
- 'graph_representation': graph_representation_proj,
-
+ "sequential_representation": sequential_representation_proj,
+ "graph_representation": graph_representation_proj,
# for L_UC (formula 11)
- 'user_graph_user_embeddings': user_graph_user_embeddings_proj,
- 'common_graph_user_embeddings': common_graph_user_embeddings_proj,
-
+ "user_graph_user_embeddings": user_graph_user_embeddings_proj,
+ "common_graph_user_embeddings": common_graph_user_embeddings_proj,
# for L_IC
- 'item_graph_item_embeddings': unique_item_graph_items_proj,
- 'common_graph_item_embeddings': unique_common_graph_items_proj,
+ "item_graph_item_embeddings": unique_item_graph_items_proj,
+ "common_graph_item_embeddings": unique_common_graph_items_proj,
}
else: # eval mode
# formula 16: R(u,N) = Top-N((I_s)^T * h_o)
- if '{}.ids'.format(self._candidate_prefix) in inputs:
+ if "{}.ids".format(self._candidate_prefix) in inputs:
candidate_events = inputs[
- '{}.ids'.format(self._candidate_prefix)
+ "{}.ids".format(self._candidate_prefix)
] # (all_batch_candidates)
candidate_lengths = inputs[
- '{}.length'.format(self._candidate_prefix)
-
+ "{}.length".format(self._candidate_prefix)
] # (batch_size)
candidate_embeddings = self._item_embeddings(
@@ -409,23 +480,22 @@ def scatter_mean(src, index, dim=0, dim_size=None):
) # (batch_size, num_candidates, embedding_dim)
candidate_scores = torch.einsum(
- 'bd,bnd->bn',
- sequential_representation, # I_s
- candidate_embeddings, # h_o (and h_k)
+ "bd,bnd->bn",
+ sequential_representation, # I_s
+ candidate_embeddings, # h_o (and h_k)
) # (batch_size, num_candidates)
else:
candidate_embeddings = (
self._item_embeddings.weight
) # (num_items, embedding_dim)
candidate_scores = torch.einsum(
- 'bd,nd->bn',
- sequential_representation, # I_s
- candidate_embeddings, # all h_v
+ "bd,nd->bn",
+ sequential_representation, # I_s
+ candidate_embeddings, # all h_v
) # (batch_size, num_items)
candidate_scores[:, 0] = -torch.inf
candidate_scores[:, self._num_items + 1 :] = -torch.inf
-
values, indices = torch.topk(
candidate_scores,
k=50,
@@ -433,4 +503,4 @@ def scatter_mean(src, index, dim=0, dim_size=None):
largest=True,
            ) # (batch_size, 50), (batch_size, 50)
- return indices
\ No newline at end of file
+ return indices
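
The ID pass-through block above anticipates false-negative masking on the loss side; that wiring is not implemented in this patch, but a minimal sketch of the masking step looks like this (tensor names are illustrative):

```python
import torch

positive_ids = torch.tensor([5, 9])                  # (batch_size,)
negative_ids = torch.tensor([[3, 5, 7], [1, 2, 9]])  # (batch_size, num_negatives)
neg_scores = torch.randn(2, 3)                       # scores for those negatives

# A sampled negative equal to the user's own positive is a false negative;
# -inf drops it from the softmax instead of punishing the target item.
collisions = negative_ids == positive_ids.unsqueeze(-1)
neg_scores = neg_scores.masked_fill(collisions, float("-inf"))
```
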
diff --git a/src/irec/models/sasrec.py b/src/irec/models/sasrec.py
index e97019a..52384d2 100644
--- a/src/irec/models/sasrec.py
+++ b/src/irec/models/sasrec.py
@@ -4,22 +4,22 @@
import torch
-class SasRecModel(SequentialTorchModel, config_name='sasrec'):
+class SasRecModel(SequentialTorchModel, config_name="sasrec"):
def __init__(
- self,
- sequence_prefix,
- positive_prefix,
- num_items,
- max_sequence_length,
- embedding_dim,
- num_heads,
- num_layers,
- dim_feedforward,
- dropout=0.0,
- activation='relu',
- layer_norm_eps=1e-9,
- initializer_range=0.02
+ self,
+ sequence_prefix,
+ positive_prefix,
+ num_items,
+ max_sequence_length,
+ embedding_dim,
+ num_heads,
+ num_layers,
+ dim_feedforward,
+ dropout=0.0,
+ activation="relu",
+ layer_norm_eps=1e-9,
+ initializer_range=0.02,
):
super().__init__(
num_items=num_items,
@@ -31,7 +31,7 @@ def __init__(
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
- is_causal=True
+ is_causal=True,
)
self._sequence_prefix = sequence_prefix
self._positive_prefix = positive_prefix
@@ -41,103 +41,100 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- sequence_prefix=config['sequence_prefix'],
- positive_prefix=config['positive_prefix'],
- num_items=kwargs['num_items'],
- max_sequence_length=kwargs['max_sequence_length'],
- embedding_dim=config['embedding_dim'],
- num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)),
- num_layers=config['num_layers'],
- dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']),
- dropout=config.get('dropout', 0.0),
- initializer_range=config.get('initializer_range', 0.02)
+ sequence_prefix=config["sequence_prefix"],
+ positive_prefix=config["positive_prefix"],
+ num_items=kwargs["num_items"],
+ max_sequence_length=kwargs["max_sequence_length"],
+ embedding_dim=config["embedding_dim"],
+ num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)),
+ num_layers=config["num_layers"],
+ dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]),
+ dropout=config.get("dropout", 0.0),
+ initializer_range=config.get("initializer_range", 0.02),
)
def forward(self, inputs):
- all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events)
- all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size)
+ all_sample_events = inputs[
+ "{}.ids".format(self._sequence_prefix)
+ ] # (all_batch_events)
+ all_sample_lengths = inputs[
+ "{}.length".format(self._sequence_prefix)
+ ] # (batch_size)
embeddings, mask = self._apply_sequential_encoder(
all_sample_events, all_sample_lengths
) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len)
if self.training: # training mode
- all_positive_sample_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events)
+ all_positive_sample_events = inputs[
+ "{}.ids".format(self._positive_prefix)
+ ] # (all_batch_events)
- all_sample_embeddings = embeddings[mask] # (all_batch_events, embedding_dim)
+ all_sample_embeddings = embeddings[
+ mask
+ ] # (all_batch_events, embedding_dim)
all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim)
# a -- all_batch_events, n -- num_items, d -- embedding_dim
all_scores = torch.einsum(
- 'ad,nd->an',
- all_sample_embeddings,
- all_embeddings
+ "ad,nd->an", all_sample_embeddings, all_embeddings
) # (all_batch_events, num_items)
positive_scores = torch.gather(
- input=all_scores,
- dim=1,
- index=all_positive_sample_events[..., None]
- )[:, 0] # (all_batch_items)
+ input=all_scores, dim=1, index=all_positive_sample_events[..., None]
+ )[
+ :, 0
+ ] # (all_batch_items)
negative_scores = torch.gather(
input=all_scores,
dim=1,
- index=torch.randint(low=0, high=all_scores.shape[1], size=all_positive_sample_events.shape, device=all_positive_sample_events.device)[..., None]
- )[:, 0] # (all_batch_items)
-
- # sample_ids, _ = create_masked_tensor(
- # data=all_sample_events,
- # lengths=all_sample_lengths
- # ) # (batch_size, seq_len)
-
- # sample_ids = torch.repeat_interleave(sample_ids, all_sample_lengths, dim=0) # (all_batch_events, seq_len)
-
- # negative_scores = torch.scatter(
- # input=all_scores,
- # dim=1,
- # index=sample_ids,
- # src=torch.ones_like(sample_ids) * (-torch.inf)
- # ) # (all_batch_events, num_items)
+ index=torch.randint(
+ low=0,
+ high=all_scores.shape[1],
+ size=all_positive_sample_events.shape,
+ device=all_positive_sample_events.device,
+ )[..., None],
+ )[
+ :, 0
+ ] # (all_batch_items)
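+            # Negatives are sampled uniformly over all columns of all_scores, so a
+            # draw can occasionally hit the positive item itself; the scatter-based
+            # variant removed just above masked in-sequence items instead.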
return {
- 'positive_scores': positive_scores,
- 'negative_scores': negative_scores
+ "positive_scores": positive_scores,
+ "negative_scores": negative_scores,
}
else: # eval mode
- last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim)
+ last_embeddings = self._get_last_embedding(
+ embeddings, mask
+ ) # (batch_size, embedding_dim)
# b - batch_size, n - num_candidates, d - embedding_dim
candidate_scores = torch.einsum(
- 'bd,nd->bn',
- last_embeddings,
- self._item_embeddings.weight
+ "bd,nd->bn", last_embeddings, self._item_embeddings.weight
) # (batch_size, num_items + 2)
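+            # Note: indices 0 and > num_items are not masked here (cf. the masking
+            # in SasRecInBatchModel below), so service rows may enter the top-k.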
_, indices = torch.topk(
- candidate_scores,
- k=50, dim=-1, largest=True
+ candidate_scores, k=50, dim=-1, largest=True
        ) # (batch_size, 50)
return indices
-class SasRecInBatchModel(SasRecModel, config_name='sasrec_in_batch'):
-
+class SasRecInBatchModel(SasRecModel, config_name="sasrec_in_batch"):
def __init__(
- self,
- sequence_prefix,
- positive_prefix,
- num_items,
- max_sequence_length,
- embedding_dim,
- num_heads,
- num_layers,
- dim_feedforward,
- dropout=0.0,
- activation='relu',
- layer_norm_eps=1e-9,
- initializer_range=0.02
+ self,
+ sequence_prefix,
+ positive_prefix,
+ num_items,
+ max_sequence_length,
+ embedding_dim,
+ num_heads,
+ num_layers,
+ dim_feedforward,
+ dropout=0.0,
+ activation="relu",
+ layer_norm_eps=1e-9,
+ initializer_range=0.02,
):
super().__init__(
sequence_prefix=sequence_prefix,
@@ -151,27 +148,31 @@ def __init__(
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
- initializer_range=initializer_range
+ initializer_range=initializer_range,
)
@classmethod
def create_from_config(cls, config, **kwargs):
return cls(
- sequence_prefix=config['sequence_prefix'],
- positive_prefix=config['positive_prefix'],
- num_items=kwargs['num_items'],
- max_sequence_length=kwargs['max_sequence_length'],
- embedding_dim=config['embedding_dim'],
- num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)),
- num_layers=config['num_layers'],
- dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']),
- dropout=config.get('dropout', 0.0),
- initializer_range=config.get('initializer_range', 0.02)
+ sequence_prefix=config["sequence_prefix"],
+ positive_prefix=config["positive_prefix"],
+ num_items=kwargs["num_items"],
+ max_sequence_length=kwargs["max_sequence_length"],
+ embedding_dim=config["embedding_dim"],
+ num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)),
+ num_layers=config["num_layers"],
+ dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]),
+ dropout=config.get("dropout", 0.0),
+ initializer_range=config.get("initializer_range", 0.02),
)
def forward(self, inputs):
- all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events)
- all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size)
+ all_sample_events = inputs[
+ "{}.ids".format(self._sequence_prefix)
+ ] # (all_batch_events)
+ all_sample_lengths = inputs[
+ "{}.length".format(self._sequence_prefix)
+ ] # (batch_size)
embeddings, mask = self._apply_sequential_encoder(
all_sample_events, all_sample_lengths
@@ -179,10 +180,14 @@ def forward(self, inputs):
if self.training: # training mode
# queries
- in_batch_queries_embeddings = embeddings[mask] # (all_batch_events, embedding_dim)
+ in_batch_queries_embeddings = embeddings[
+ mask
+ ] # (all_batch_events, embedding_dim)
# positives
- in_batch_positive_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events)
+ in_batch_positive_events = inputs[
+ "{}.ids".format(self._positive_prefix)
+ ] # (all_batch_events)
in_batch_positive_embeddings = self._item_embeddings(
in_batch_positive_events
) # (all_batch_events, embedding_dim)
@@ -197,25 +202,33 @@ def forward(self, inputs):
) # (batch_size, embedding_dim)
return {
- 'query_embeddings': in_batch_queries_embeddings,
- 'positive_embeddings': in_batch_positive_embeddings,
- 'negative_embeddings': in_batch_negative_embeddings
+ "query_embeddings": in_batch_queries_embeddings,
+ "positive_embeddings": in_batch_positive_embeddings,
+ "negative_embeddings": in_batch_negative_embeddings,
+            # --- ID pass-through for downstream loss operations ---
+            # The raw item IDs are passed on to the loss function to enable:
+            # 1. False-negative masking: identifying and neutralizing cases where the
+            #    positive item of one user appears as a negative sample for another
+            #    user in the same batch (critical for in-batch-negative stability).
+            # 2. Per-item logQ correction: mapping item IDs to their global sampling
+            #    frequencies so that log Q(item), the popularity-based sampling
+            #    probability, can be subtracted from the logits.
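+            # Illustrative loss-side sketch (assumed usage, not part of this model;
+            # `item_freq` is a hypothetical tensor of empirical item frequencies):
+            #   collision = negative_ids[None, :] == positive_ids[:, None]
+            #   logits = logits - torch.log(item_freq[negative_ids])   # logQ correction
+            #   logits = logits.masked_fill(collision, float("-inf"))  # false negatives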
+ "positive_ids": in_batch_positive_events,
+ "negative_ids": in_batch_negative_ids,
}
else: # eval mode
- last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim)
+ last_embeddings = self._get_last_embedding(
+ embeddings, mask
+ ) # (batch_size, embedding_dim)
# b - batch_size, n - num_candidates, d - embedding_dim
candidate_scores = torch.einsum(
- 'bd,nd->bn',
- last_embeddings,
- self._item_embeddings.weight
+ "bd,nd->bn", last_embeddings, self._item_embeddings.weight
) # (batch_size, num_items + 2)
candidate_scores[:, 0] = -torch.inf
- candidate_scores[:, self._num_items + 1:] = -torch.inf
+ candidate_scores[:, self._num_items + 1 :] = -torch.inf
_, indices = torch.topk(
- candidate_scores,
- k=50, dim=-1, largest=True
+ candidate_scores, k=50, dim=-1, largest=True
        ) # (batch_size, 50)
- return indices
\ No newline at end of file
+ return indices
diff --git a/src/irec/models/sasrec_ce.py b/src/irec/models/sasrec_ce.py
deleted file mode 100644
index c07b9ff..0000000
--- a/src/irec/models/sasrec_ce.py
+++ /dev/null
@@ -1,108 +0,0 @@
-from .base import SequentialTorchModel
-
-import torch
-import torch.nn as nn
-
-
-class SasRecCeModel(SequentialTorchModel, config_name='sasrec_ce'):
- def __init__(
- self,
- sequence_prefix,
- positive_prefix,
- num_items,
- max_sequence_length,
- embedding_dim,
- num_heads,
- num_layers,
- dim_feedforward,
- dropout=0.0,
- activation='relu',
- layer_norm_eps=1e-9,
- initializer_range=0.02,
- ):
- super().__init__(
- num_items=num_items,
- max_sequence_length=max_sequence_length,
- embedding_dim=embedding_dim,
- num_heads=num_heads,
- num_layers=num_layers,
- dim_feedforward=dim_feedforward,
- dropout=dropout,
- activation=activation,
- layer_norm_eps=layer_norm_eps,
- is_causal=True,
- )
- self._sequence_prefix = sequence_prefix
- self._positive_prefix = positive_prefix
-
- self._output_projection = nn.Linear(
- in_features=embedding_dim,
- out_features=embedding_dim,
- )
-
- self._init_weights(initializer_range)
-
- @classmethod
- def create_from_config(cls, config, **kwargs):
- return cls(
- sequence_prefix=config['sequence_prefix'],
- positive_prefix=config['positive_prefix'],
- num_items=kwargs['num_items'],
- max_sequence_length=kwargs['max_sequence_length'],
- embedding_dim=config['embedding_dim'],
- num_heads=config.get(
- 'num_heads',
- int(config['embedding_dim'] // 64),
- ),
- num_layers=config['num_layers'],
- dim_feedforward=config.get(
- 'dim_feedforward',
- 4 * config['embedding_dim'],
- ),
- dropout=config.get('dropout', 0.0),
- initializer_range=config.get('initializer_range', 0.02),
- )
-
- def forward(self, inputs):
- all_sample_events = inputs[
- '{}.ids'.format(self._sequence_prefix)
- ] # (all_batch_events)
- all_sample_lengths = inputs[
- '{}.length'.format(self._sequence_prefix)
- ] # (batch_size)
-
- embeddings, mask = self._apply_sequential_encoder(
- all_sample_events,
- all_sample_lengths,
- ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len)
-
- embeddings = self._output_projection(
- embeddings,
- ) # (batch_size, seq_len, embedding_dim)
- embeddings = torch.nn.functional.gelu(
- embeddings,
- ) # (batch_size, seq_len, embedding_dim)
- embeddings = torch.einsum(
- 'bsd,nd->bsn',
- embeddings,
- self._item_embeddings.weight,
- ) # (batch_size, seq_len, num_items + 2)
-
- if self.training: # training mode
- return {'logits': embeddings[mask]}
- else: # eval mode
- candidate_scores = self._get_last_embedding(
- embeddings,
- mask,
- ) # (batch_size, num_items + 2)
- candidate_scores[:, 0] = -torch.inf
- candidate_scores[:, self._num_items + 1 :] = -torch.inf
-
- _, indices = torch.topk(
- candidate_scores,
- k=20,
- dim=-1,
- largest=True,
- ) # (batch_size, 20)
-
- return indices
diff --git a/src/irec/optimizer/base.py b/src/irec/optimizer/base.py
index 0fc1f70..51d9464 100644
--- a/src/irec/optimizer/base.py
+++ b/src/irec/optimizer/base.py
@@ -5,14 +5,14 @@
import torch
OPTIMIZERS = {
- 'sgd': torch.optim.SGD,
- 'adam': torch.optim.Adam,
- 'adamw': torch.optim.AdamW,
+ "sgd": torch.optim.SGD,
+ "adam": torch.optim.Adam,
+ "adamw": torch.optim.AdamW,
}
SCHEDULERS = {
- 'step': torch.optim.lr_scheduler.StepLR,
- 'cyclic': torch.optim.lr_scheduler.CyclicLR,
+ "step": torch.optim.lr_scheduler.StepLR,
+ "cyclic": torch.optim.lr_scheduler.CyclicLR,
}
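+# Illustrative config resolved through these registries (keys beyond "type" are
+# standard torch.optim / lr_scheduler keyword arguments):
+#   {"optimizer": {"type": "adamw", "lr": 1e-3, "weight_decay": 0.01},
+#    "scheduler": {"type": "step", "step_size": 10, "gamma": 0.5}}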
@@ -20,7 +20,7 @@ class BaseOptimizer(metaclass=MetaParent):
pass
-class BasicOptimizer(BaseOptimizer, config_name='basic'):
+class BasicOptimizer(BaseOptimizer, config_name="basic"):
def __init__(
self,
model,
@@ -35,15 +35,15 @@ def __init__(
@classmethod
def create_from_config(cls, config, **kwargs):
- optimizer_cfg = copy.deepcopy(config['optimizer'])
- optimizer = OPTIMIZERS[optimizer_cfg.pop('type')](
- kwargs['model'].parameters(),
+ optimizer_cfg = copy.deepcopy(config["optimizer"])
+ optimizer = OPTIMIZERS[optimizer_cfg.pop("type")](
+ kwargs["model"].parameters(),
**optimizer_cfg,
)
- if 'scheduler' in config:
- scheduler_cfg = copy.deepcopy(config['scheduler'])
- scheduler = SCHEDULERS[scheduler_cfg.pop('type')](
+ if "scheduler" in config:
+ scheduler_cfg = copy.deepcopy(config["scheduler"])
+ scheduler = SCHEDULERS[scheduler_cfg.pop("type")](
optimizer,
**scheduler_cfg,
)
@@ -51,10 +51,10 @@ def create_from_config(cls, config, **kwargs):
scheduler = None
return cls(
- model=kwargs['model'],
+ model=kwargs["model"],
optimizer=optimizer,
scheduler=scheduler,
- clip_grad_threshold=config.get('clip_grad_threshold', None),
+ clip_grad_threshold=config.get("clip_grad_threshold", None),
)
def step(self, loss):
@@ -72,7 +72,7 @@ def step(self, loss):
self._scheduler.step()
def state_dict(self):
- state_dict = {'optimizer': self._optimizer.state_dict()}
+ state_dict = {"optimizer": self._optimizer.state_dict()}
if self._scheduler is not None:
- state_dict.update({'scheduler': self._scheduler.state_dict()})
+ state_dict.update({"scheduler": self._scheduler.state_dict()})
return state_dict
diff --git a/src/irec/train.py b/src/irec/train.py
index 7844439..8251822 100644
--- a/src/irec/train.py
+++ b/src/irec/train.py
@@ -43,20 +43,20 @@ def train(
best_epoch = 0
best_checkpoint = None
- logger.debug('Start training...')
+ logger.debug("Start training...")
while (epoch_cnt is None or epoch_num < epoch_cnt) and (
step_cnt is None or step_num < step_cnt
):
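+        # Early stopping: leave the loop once epochs_threshold epochs pass
+        # without the tracked best metric improving.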
if best_epoch + epochs_threshold < epoch_num:
logger.debug(
- 'There is no progress during {} epochs. Finish training'.format(
+ "There is no progress during {} epochs. Finish training".format(
epochs_threshold,
),
)
break
- logger.debug(f'Start epoch {epoch_num}')
+ logger.debug(f"Start epoch {epoch_num}")
for step, batch in enumerate(dataloader):
batch_ = copy.deepcopy(batch)
@@ -87,7 +87,7 @@ def train(
best_epoch = epoch_num
epoch_num += 1
- logger.debug('Training procedure has been finished!')
+ logger.debug("Training procedure has been finished!")
return best_checkpoint
@@ -95,71 +95,72 @@ def main():
fix_random_seed(seed_val)
config = parse_args()
- if config.get('use_wandb', False):
+ if config.get("use_wandb", False):
wandb.init(
- project='irec',
- name=config['experiment_name'],
+ project="irec",
+ name=config["experiment_name"],
sync_tensorboard=True,
)
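+        # sync_tensorboard=True mirrors scalars written via TensorBoard into the
+        # same W&B run, so no separate wandb.log calls are needed.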
- tensorboard_writer = irec.utils.tensorboards.TensorboardWriter(config['experiment_name'])
+ tensorboard_writer = irec.utils.tensorboards.TensorboardWriter(
+ config["experiment_name"]
+ )
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER = tensorboard_writer
log_dir = tensorboard_writer.log_dir
- config_save_path = os.path.join(log_dir, 'config.json')
- with open(config_save_path, 'w') as f:
+ config_save_path = os.path.join(log_dir, "config.json")
+ with open(config_save_path, "w") as f:
json.dump(config, f, indent=2)
-
- logger.debug('Training config: \n{}'.format(json.dumps(config, indent=2)))
- logger.debug('Current DEVICE: {}'.format(DEVICE))
- logger.info(f"Experiment config saved to: {config_save_path}")
+ logger.debug("Training config: \n{}".format(json.dumps(config, indent=2)))
+ logger.debug("Current DEVICE: {}".format(DEVICE))
+ logger.info(f"Experiment config saved to: {config_save_path}")
- dataset = BaseDataset.create_from_config(config['dataset'])
+ dataset = BaseDataset.create_from_config(config["dataset"])
train_sampler, validation_sampler, test_sampler = dataset.get_samplers()
train_dataloader = BaseDataloader.create_from_config(
- config['dataloader']['train'],
+ config["dataloader"]["train"],
dataset=train_sampler,
**dataset.meta,
)
validation_dataloader = BaseDataloader.create_from_config(
- config['dataloader']['validation'],
+ config["dataloader"]["validation"],
dataset=validation_sampler,
**dataset.meta,
)
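+    # Note: the test-time loader reuses the "validation" dataloader config;
+    # only the underlying sampler (test_sampler) differs.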
eval_dataloader = BaseDataloader.create_from_config(
- config['dataloader']['validation'],
+ config["dataloader"]["validation"],
dataset=test_sampler,
**dataset.meta,
)
- model = BaseModel.create_from_config(config['model'], **dataset.meta).to(
+ model = BaseModel.create_from_config(config["model"], **dataset.meta).to(
DEVICE,
)
- if 'checkpoint' in config:
+ if "checkpoint" in config:
ensure_checkpoints_dir()
checkpoint_path = os.path.join(
- './checkpoints',
+ "./checkpoints",
f'{config["checkpoint"]}.pth',
)
- logger.debug('Loading checkpoint from {}'.format(checkpoint_path))
+ logger.debug("Loading checkpoint from {}".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
logger.debug(checkpoint.keys())
model.load_state_dict(checkpoint)
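+        # Assumes the file holds a bare state_dict; pass map_location=DEVICE to
+        # torch.load if the checkpoint was saved on a different device.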
- loss_function = BaseLoss.create_from_config(config['loss'])
+ loss_function = BaseLoss.create_from_config(config["loss"])
optimizer = BaseOptimizer.create_from_config(
- config['optimizer'],
+ config["optimizer"],
model=model,
)
callback = BaseCallback.create_from_config(
- config['callback'],
+ config["callback"],
model=model,
train_dataloader=train_dataloader,
validation_dataloader=validation_dataloader,
@@ -170,7 +171,7 @@ def main():
# TODO add verbose option for all callbacks, multiple optimizer options (???)
# TODO create pre/post callbacks
- logger.debug('Everything is ready for training process!')
+ logger.debug("Everything is ready for training process!")
# Train process
_ = train(
@@ -179,22 +180,22 @@ def main():
optimizer=optimizer,
loss_function=loss_function,
callback=callback,
- epoch_cnt=config.get('train_epochs_num'),
- step_cnt=config.get('train_steps_num'),
- best_metric=config.get('best_metric'),
+ epoch_cnt=config.get("train_epochs_num"),
+ step_cnt=config.get("train_steps_num"),
+ best_metric=config.get("best_metric"),
)
- logger.debug('Saving model...')
+ logger.debug("Saving model...")
ensure_checkpoints_dir()
- checkpoint_path = './checkpoints/{}_final_state.pth'.format(
- config['experiment_name'],
+ checkpoint_path = "./checkpoints/{}_final_state.pth".format(
+ config["experiment_name"],
)
torch.save(model.state_dict(), checkpoint_path)
- logger.debug('Saved model as {}'.format(checkpoint_path))
+ logger.debug("Saved model as {}".format(checkpoint_path))
- if config.get('use_wandb', False):
+ if config.get("use_wandb", False):
wandb.finish()
-if __name__ == '__main__':
+if __name__ == "__main__":
main()