diff --git a/scripts/plum-lsvd/callbacks.py b/scripts/plum-lsvd/callbacks.py
new file mode 100644
index 0000000..43ec460
--- /dev/null
+++ b/scripts/plum-lsvd/callbacks.py
@@ -0,0 +1,64 @@
+import torch
+
+import irec.callbacks as cb
+from irec.runners import TrainingRunner, TrainingRunnerContext
+
+class InitCodebooks(cb.TrainingCallback):
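+    """Initializes each codebook before training: encodes a random batch, subtracts the contribution
+    of the previous codebooks, and copies the resulting residuals into the codebook.
+    Assumes the dataloader batch size is at least the codebook size."""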
+ def __init__(self, dataloader):
+ super().__init__()
+ self._dataloader = dataloader
+
+ @torch.no_grad()
+ def before_run(self, runner: TrainingRunner):
+ for i in range(len(runner.model.codebooks)):
+ X = next(iter(self._dataloader))['embedding']
+ idx = torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])]
+ remainder = runner.model.encoder(X[idx])
+
+ for j in range(i):
+ codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j])
+ codebook_vectors = runner.model.codebooks[j][codebook_indices]
+ remainder = remainder - codebook_vectors
+
+ runner.model.codebooks[i].data = remainder.detach()
+
+
+class FixDeadCentroids(cb.TrainingCallback):
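+    """After every training step, counts codebook usage over the dataloader and re-initializes
+    unused ("dead") centroids with residuals from a random batch (assumes that batch has at
+    least as many rows as there are dead centroids)."""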
+ def __init__(self, dataloader):
+ super().__init__()
+ self._dataloader = dataloader
+
+ def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+ for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
+ context.metrics[f'num_dead/{i}'] = num_fixed
+
+ @torch.no_grad()
+ def fix_dead_codebooks(self, runner: TrainingRunner):
+ num_fixed = []
+ for codebook_idx, codebook in enumerate(runner.model.codebooks):
+ centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
+ random_batch = next(iter(self._dataloader))['embedding']
+
+ for batch in self._dataloader:
+ remainder = runner.model.encoder(batch['embedding'])
+ for l in range(codebook_idx):
+ ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
+ remainder = remainder - runner.model.codebooks[l][ind]
+
+ indices = runner.model.get_codebook_indices(remainder, codebook)
+ centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))
+
+ dead_mask = (centroid_counts == 0)
+ num_dead = int(dead_mask.sum().item())
+ num_fixed.append(num_dead)
+ if num_dead == 0:
+ continue
+
+ remainder = runner.model.encoder(random_batch)
+ for l in range(codebook_idx):
+ ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
+ remainder = remainder - runner.model.codebooks[l][ind]
+ remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead]
+ codebook[dead_mask] = remainder.detach()
+
+ return num_fixed
diff --git a/scripts/plum-lsvd/cooc_data.py b/scripts/plum-lsvd/cooc_data.py
new file mode 100644
index 0000000..7cea906
--- /dev/null
+++ b/scripts/plum-lsvd/cooc_data.py
@@ -0,0 +1,117 @@
+import json
+from collections import defaultdict, Counter
+from data import InteractionsDatasetParquet
+
+
+class CoocMappingDataset:
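+    """Holds training sessions together with a window-based co-occurrence mapping
+    (item_id -> Counter of neighboring item_ids within the window)."""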
+ def __init__(
+ self,
+ train_sampler,
+ num_items,
+ cooccur_counter_mapping=None
+ ):
+ self._train_sampler = train_sampler
+ self._num_items = num_items
+ self._cooccur_counter_mapping = cooccur_counter_mapping
+
+ @classmethod
+ def create(cls, inter_json_path, window_size):
+ max_item_id = 0
+ train_dataset = []
+
+ with open(inter_json_path, 'r') as f:
+ user_interactions = json.load(f)
+
+ for user_id_str, item_ids in user_interactions.items():
+ user_id = int(user_id_str)
+ if item_ids:
+ max_item_id = max(max_item_id, max(item_ids))
+            if len(item_ids) >= 5:
+                train_dataset.append({
+                    'user_ids': [user_id],
+                    'item_ids': item_ids[:-2],
+                })
+            else:
+                print(f'Core-5 dataset is used, but user {user_id} has only {len(item_ids)} items')
+
+        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
+        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')
+
+        train_sampler = train_dataset
+
+ return cls(
+ train_sampler=train_sampler,
+ num_items=max_item_id + 1,
+ cooccur_counter_mapping=cooccur_counter_mapping
+ )
+
+
+ @classmethod
+ def create_from_split_part(
+ cls,
+ train_inter_parquet_path,
+ window_size,
+ ):
+
+ max_item_id = 0
+ train_dataset = []
+
+ train_interactions = InteractionsDatasetParquet(train_inter_parquet_path)
+
+ actions_num = 0
+ for session in train_interactions:
+ user_id, item_ids = int(session['user_id']), session['item_ids']
+            if len(item_ids) > 0:
+ max_item_id = max(max_item_id, max(item_ids))
+ actions_num += len(item_ids)
+ train_dataset.append({
+ 'user_ids': [user_id],
+ 'item_ids': item_ids,
+ })
+
+        print(f'Train: {len(train_dataset)} users')
+        print(f'Max item ID: {max_item_id}')
+        print(f"Actions num: {actions_num}")
+
+        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
+            train_dataset,
+            window_size=window_size
+        )
+
+        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')
+
+        return cls(
+            train_sampler=train_dataset,
+            num_items=max_item_id + 1,
+            cooccur_counter_mapping=cooccur_counter_mapping
+        )
+
+ @staticmethod
+ def build_cooccur_counter_mapping(train_dataset, window_size):
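+        """For every item, counts how often each other item appears within +-window_size positions
+        in the same session. Illustrative example: for a session [a, b, c] with window_size=1,
+        'a' co-occurs once with 'b', while 'b' co-occurs once with 'a' and once with 'c'."""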
+ cooccur_counts = defaultdict(Counter)
+ for session in train_dataset:
+ items = session['item_ids']
+ for i in range(len(items)):
+ item_i = items[i]
+ for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
+ if i != j:
+ cooccur_counts[item_i][items[j]] += 1
+ max_hist_len = max(len(counter) for counter in cooccur_counts.values()) if cooccur_counts else 0
+ print(f"Max cooccurrence history length is {max_hist_len}")
+ return cooccur_counts
+
+ @property
+ def cooccur_counter_mapping(self):
+ return self._cooccur_counter_mapping
\ No newline at end of file
diff --git a/scripts/plum-lsvd/data.py b/scripts/plum-lsvd/data.py
new file mode 100644
index 0000000..5a780fb
--- /dev/null
+++ b/scripts/plum-lsvd/data.py
@@ -0,0 +1,87 @@
+import pickle
+
+import numpy as np
+import polars as pl
+
+from irec.data.base import BaseDataset
+from irec.data.transforms import Transform
+
+
+class InteractionsDatasetParquet(BaseDataset):
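+    """User interaction sessions stored in a parquet file with 'uid' and 'item_ids' columns."""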
+ def __init__(self, data_path, max_items=None):
+ self.df = pl.read_parquet(data_path)
+ assert 'uid' in self.df.columns, "Missing 'uid' column"
+ assert 'item_ids' in self.df.columns, "Missing 'item_ids' column"
+ print(f"Dataset loaded: {len(self.df)} users")
+
+ if max_items is not None:
+ self.df = self.df.with_columns(
+ pl.col("item_ids").list.slice(-max_items).alias("item_ids")
+ )
+
+ def __getitem__(self, idx):
+ row = self.df.row(idx, named=True)
+ return {
+ 'user_id': row['uid'],
+ 'item_ids': np.array(row['item_ids'], dtype=np.uint32),
+ }
+
+ def __len__(self):
+ return len(self.df)
+
+ def __iter__(self):
+ for idx in range(len(self)):
+ yield self[idx]
+
+
+class EmbeddingDatasetParquet(BaseDataset):
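+    """Item embeddings stored in a parquet file with 'item_id' and 'embedding' columns."""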
+ def __init__(self, data_path):
+ self.df = pl.read_parquet(data_path)
+ self.item_ids = np.array(self.df['item_id'], dtype=np.int64)
+ self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32)
+ print(f"embedding dim: {self.embeddings[0].shape}")
+
+ def __getitem__(self, idx):
+ index = self.item_ids[idx]
+ tensor_emb = self.embeddings[idx]
+ return {
+ 'item_id': index,
+ 'embedding': tensor_emb,
+ 'embedding_dim': len(tensor_emb)
+ }
+
+ def __len__(self):
+ return len(self.embeddings)
+
+
+class EmbeddingDataset(BaseDataset):
+ def __init__(self, data_path):
+ self.data_path = data_path
+ with open(data_path, 'rb') as f:
+ self.data = pickle.load(f)
+
+ self.item_ids = np.array(self.data['item_id'], dtype=np.int64)
+ self.embeddings = np.array(self.data['embedding'], dtype=np.float32)
+
+ def __getitem__(self, idx):
+ index = self.item_ids[idx]
+ tensor_emb = self.embeddings[idx]
+ return {
+ 'item_id': index,
+ 'embedding': tensor_emb,
+ 'embedding_dim': len(tensor_emb)
+ }
+
+ def __len__(self):
+ return len(self.embeddings)
+
+
+class ProcessEmbeddings(Transform):
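+    """Reshapes the selected batch keys into (-1, embedding_dim) matrices."""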
+ def __init__(self, embedding_dim, keys):
+ self.embedding_dim = embedding_dim
+ self.keys = keys
+
+ def __call__(self, batch):
+ for key in self.keys:
+ batch[key] = batch[key].reshape(-1, self.embedding_dim)
+ return batch
\ No newline at end of file
diff --git a/scripts/plum-lsvd/letter_base_gap/create_base_gap_mapping_from_all.py b/scripts/plum-lsvd/letter_base_gap/create_base_gap_mapping_from_all.py
new file mode 100644
index 0000000..734bfdf
--- /dev/null
+++ b/scripts/plum-lsvd/letter_base_gap/create_base_gap_mapping_from_all.py
@@ -0,0 +1,50 @@
+import json
+import pandas as pd
+from pathlib import Path
+
+
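+# Restrict the full-catalog semantic-ID mapping to the items that appear in the train interactions.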
+ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless.json"
+TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/only_base_with_gap_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless_from_all.json"
+
+
+with open(ALL_MAPPING_PATH, 'r') as f:
+ all_mapping = json.load(f)
+print(f"Loaded {len(all_mapping)} items from all_mapping")
+
+train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH)
+
+train_item_ids = set()
+for item_ids_array in train_interactions['item_ids']:
+ train_item_ids.update(item_ids_array)
+
+print(f"Found {len(train_item_ids)} unique train items")
+
+train_mapping = {}
+missing_count = 0
+
+for item_id in train_item_ids:
+ item_id_str = str(item_id)
+ if item_id_str in all_mapping:
+ train_mapping[item_id_str] = all_mapping[item_id_str]
+ else:
+ missing_count += 1
+
+if missing_count > 0:
+ print(f"{missing_count} items from train not found in all_mapping")
+
+print(f"Created train_mapping with {len(train_mapping)} items")
+
+Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True)
+with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f:
+ json.dump(train_mapping, f, indent=2)
+print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}")
+
+print(f"all_mapping size: {len(all_mapping)}")
+print(f"train_mapping size: {len(train_mapping)}")
+print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}")
+
+sample_matches = 0
+for item_id_str in list(train_mapping.keys())[:100]:
+ if all_mapping[item_id_str] == train_mapping[item_id_str]:
+ sample_matches += 1
+
+print(f"Verified: {sample_matches}/100 sampled items have identical codes")
diff --git a/scripts/plum-lsvd/letter_base_gap/infer_letter.py b/scripts/plum-lsvd/letter_base_gap/infer_letter.py
new file mode 100644
index 0000000..1e3d45a
--- /dev/null
+++ b/scripts/plum-lsvd/letter_base_gap/infer_letter.py
@@ -0,0 +1,155 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+from letter import LetterRQVAE
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from collections import Counter
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+INPUT_DIM = 64
+HIDDEN_DIM = 64
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+RQVAE_LOSS_WEIGHT = 1.0
+CF_LOSS_WEIGHT = 0.01
+
+EXPERIMENT_NAME = f'all_items_letter_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}_cf_{CF_LOSS_WEIGHT}'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/items_metadata_remapped.parquet"
+IREC_PATH = '../../../'
+
+MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd/letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_best_0.0067.pth"
+RESULTS_PATH = os.path.join(IREC_PATH, 'results-lsvd-2/base_gap/letter')
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+ model = LetterRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=RQVAE_LOSS_WEIGHT,
+ cf_loss_weight=CF_LOSS_WEIGHT,
+ cf_embeddings=None
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ callbacks = [
+ cb.LoadModel(MODEL_PATH),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'cf_loss': model_outputs['cf_loss']
+ }, name='valid'),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/cf_loss': cb.MeanAccumulator(),
+ },
+ ),
+
+ cb.Logger().every_num_steps(len(dataloader)),
+
+ cb.InferenceSaver(
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
+ format='json'
+ )
+ ]
+
+    logger.debug('Everything is ready for the inference process!')
+
+ runner = InferenceRunner(
+ model=model,
+ dataset=dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+ import json
+ from collections import defaultdict
+ import numpy as np
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+ mappings = json.load(f)
+
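+    # Resolve collisions: items sharing the same semantic prefix receive distinct random extra tokens,
+    # then every code position is shifted into its own CODEBOOK_SIZE-wide id range.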
+ inter = {}
+ sem_2_ids = defaultdict(list)
+ collision_stats = []
+ for mapping in mappings:
+ item_id = mapping['item_id']
+ clusters = mapping['clusters']
+ inter[int(item_id)] = clusters
+ sem_2_ids[tuple(clusters)].append(int(item_id))
+
+ for semantics, items in sem_2_ids.items():
+ assert len(items) <= CODEBOOK_SIZE, str(len(items))
+ collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+ for item_id, collision_solver in zip(items, collision_solvers):
+ inter[item_id].append(collision_solver)
+ collision_stats.append(collision_solver)
+ for i in range(len(inter[item_id])):
+ inter[item_id][i] += CODEBOOK_SIZE * i
+
+ if collision_stats:
+ max_col_tok = max(collision_stats)
+ avg_col_tok = np.mean(collision_stats)
+ collision_distribution = Counter(collision_stats)
+
+ print(f"Max collision token: {max_col_tok}")
+ print(f"Avg collision token: {avg_col_tok:.2f}")
+ print(f"Total items with collisions: {len(collision_stats)}")
+ print(f"Collision solver distribution: {dict(collision_distribution)}")
+ else:
+ print("No collisions detected")
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+ json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-lsvd/letter_base_gap/letter.py b/scripts/plum-lsvd/letter_base_gap/letter.py
new file mode 100644
index 0000000..786772f
--- /dev/null
+++ b/scripts/plum-lsvd/letter_base_gap/letter.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class LetterRQVAE(nn.Module):
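+    """Residual-quantized VAE with an optional contrastive CF-alignment loss between the quantized
+    item representation and pretrained collaborative-filtering item embeddings (SASRec in the
+    training script)."""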
+ def __init__(
+ self,
+ input_dim,
+ num_codebooks,
+ codebook_size,
+ embedding_dim,
+ beta=0.25,
+ quant_loss_weight=1.0,
+ cf_loss_weight=1.0,
+ cf_embeddings=None
+ ):
+ super().__init__()
+ self.register_buffer('beta', torch.tensor(beta))
+
+ self.input_dim = input_dim
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.embedding_dim = embedding_dim
+ self.quant_loss_weight = quant_loss_weight
+
+ self.cf_loss_weight = cf_loss_weight
+
+ self.cf_embeddings = None
+ if cf_embeddings is not None:
+ self.cf_embeddings = torch.tensor(cf_embeddings, dtype=torch.float32)
+            print(f"cf_embeddings requires_grad: {self.cf_embeddings.requires_grad}")
+        else:
+            print("cf_embeddings is None: CF loss is disabled (inference-only mode)")
+
+ self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
+ self.decoder = self.make_encoding_tower(embedding_dim, input_dim)
+
+ self.codebooks = torch.nn.ParameterList()
+ for _ in range(num_codebooks):
+            # values are overwritten by the InitCodebooks callback before training
+            cb = nn.Parameter(torch.empty(codebook_size, embedding_dim))
+            self.codebooks.append(cb)
+
+ @staticmethod
+ def make_encoding_tower(d1, d2, bias=False):
+ return torch.nn.Sequential(
+ nn.Linear(d1, d1),
+ nn.ReLU(),
+ nn.Linear(d1, d2),
+ nn.ReLU(),
+ nn.Linear(d2, d2, bias=bias)
+ )
+
+ @staticmethod
+ def get_codebook_indices(remainder, codebook):
+ dist = torch.cdist(remainder, codebook)
+ return dist.argmin(dim=-1)
+
+ def forward(self, inputs):
+ latent_vector = self.encoder(inputs['embedding'])
+ item_ids = inputs['item_id']
+
+ latent_restored = 0
+ rqvae_loss = 0
+ clusters = []
+ remainder = latent_vector
+ for codebook in self.codebooks:
+ codebook_indices = self.get_codebook_indices(remainder, codebook)
+ clusters.append(codebook_indices)
+
+ quantized = codebook[codebook_indices]
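+            # Straight-through estimator: the forward pass uses the quantized vector,
+            # while gradients flow back to the un-quantized remainder.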
+ codebook_vectors = remainder + (quantized - remainder).detach()
+
+ rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
+ rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())
+
+ latent_restored += codebook_vectors
+ remainder = remainder - codebook_vectors
+
+ embeddings_restored = self.decoder(latent_restored)
+ recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])
+
+ if self.cf_embeddings is not None:
+ cf_embedding_in_batch = self.cf_embeddings[item_ids]
+ cf_loss = self.CF_loss(latent_restored, cf_embedding_in_batch)
+ else:
+            cf_loss = torch.as_tensor(0.0, device=latent_vector.device)
+
+ loss = (recon_loss + self.quant_loss_weight * rqvae_loss + self.cf_loss_weight * cf_loss).mean()
+
+ clusters_counts = []
+ for cluster in clusters:
+ clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size))
+
+ return loss, {
+ 'loss': loss.item(),
+ 'recon_loss': recon_loss.mean().item(),
+ 'rqvae_loss': rqvae_loss.mean().item(),
+ 'cf_loss': cf_loss.item(),
+
+ 'clusters_counts': clusters_counts,
+ 'clusters': torch.stack(clusters).T,
+ 'embedding_hat': embeddings_restored,
+ }
+
+ def CF_loss(self, quantized_rep, encoded_rep):
+ batch_size = quantized_rep.size(0)
+ labels = torch.arange(batch_size, dtype=torch.long, device=quantized_rep.device)
+ similarities = quantized_rep @ encoded_rep.T
+ cf_loss = F.cross_entropy(similarities, labels)
+ return cf_loss
diff --git a/scripts/plum-lsvd/letter_base_gap/train_letter.py b/scripts/plum-lsvd/letter_base_gap/train_letter.py
new file mode 100644
index 0000000..08b884b
--- /dev/null
+++ b/scripts/plum-lsvd/letter_base_gap/train_letter.py
@@ -0,0 +1,165 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from letter import LetterRQVAE
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+INPUT_DIM = 64
+HIDDEN_DIM = 64
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+RQVAE_LOSS_WEIGHT = 1.0
+CF_LOSS_WEIGHT = 0.01
+
+EXPERIMENT_NAME = f'letter_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}_cf_{CF_LOSS_WEIGHT}'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/only_base_with_gap_items_metadata_remapped.parquet"
+IREC_PATH = '../../../'
+
+SASREC_MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd-transformer/sasrec_vk_lsvd_base_gap_msl_20_e_300_best_0.0223.pth"
+
+print(EMBEDDINGS_PATH)
+print(SASREC_MODEL_PATH)
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH,
+ )
+
+    train_dataloader = DataLoader(  # TODO: the call happens in the main thread, needs to be fixed
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ )
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ cf_embeddings = torch.load(
+ SASREC_MODEL_PATH,
+ map_location=DEVICE
+ )['_orig_mod._item_embeddings.weight']
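+    # The '_orig_mod.' key prefix suggests the SASRec checkpoint was saved from a torch.compile-wrapped model.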
+
+ print(f"cf embeds shape: {cf_embeddings.shape}")
+
+ model = LetterRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=RQVAE_LOSS_WEIGHT,
+ cf_loss_weight=CF_LOSS_WEIGHT,
+ cf_embeddings=cf_embeddings
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'cf_loss': model_outputs['cf_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/cf_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'cf_loss': model_outputs['cf_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/cf_loss': cb.MeanAccumulator(),
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints-lsvd', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-lsvd/models.py b/scripts/plum-lsvd/models.py
new file mode 100644
index 0000000..d475712
--- /dev/null
+++ b/scripts/plum-lsvd/models.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class PlumRQVAE(nn.Module):
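+    """Residual-quantized VAE with an optional contrastive loss that pulls an item's quantized
+    representation toward the quantized representation of a sampled co-occurring item."""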
+ def __init__(
+ self,
+ input_dim,
+ num_codebooks,
+ codebook_size,
+ embedding_dim,
+ beta=0.25,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+        temperature=1.0,  # a zero default would divide by zero in contrastive_loss
+ ):
+ super().__init__()
+ self.register_buffer('beta', torch.tensor(beta))
+ self.temperature = temperature
+
+ self.input_dim = input_dim
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.embedding_dim = embedding_dim
+ self.quant_loss_weight = quant_loss_weight
+
+ self.contrastive_loss_weight = contrastive_loss_weight
+
+ self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
+ self.decoder = self.make_encoding_tower(embedding_dim, input_dim)
+
+ self.codebooks = torch.nn.ParameterList()
+ for _ in range(num_codebooks):
+            # values are overwritten by the InitCodebooks callback before training
+            cb = nn.Parameter(torch.empty(codebook_size, embedding_dim))
+            self.codebooks.append(cb)
+
+ @staticmethod
+ def make_encoding_tower(d1, d2, bias=False):
+ return torch.nn.Sequential(
+ nn.Linear(d1, d1),
+ nn.ReLU(),
+ nn.Linear(d1, d2),
+ nn.ReLU(),
+ nn.Linear(d2, d2, bias=bias)
+ )
+
+ @staticmethod
+ def get_codebook_indices(remainder, codebook):
+ dist = torch.cdist(remainder, codebook)
+ return dist.argmin(dim=-1)
+
+ def _quantize_representation(self, latent_vector):
+ latent_restored = 0
+ remainder = latent_vector
+
+ for codebook in self.codebooks:
+ codebook_indices = self.get_codebook_indices(remainder, codebook)
+ quantized = codebook[codebook_indices]
+ codebook_vectors = remainder + (quantized - remainder).detach()
+ latent_restored += codebook_vectors
+ remainder = remainder - codebook_vectors
+
+ return latent_restored
+
+ def contrastive_loss(self, p_i, p_i_star):
+ N_b = p_i.size(0)
+
+        p_i = F.normalize(p_i, p=2, dim=-1)  # TODO: try without normalization
+ p_i_star = F.normalize(p_i_star, p=2, dim=-1)
+
+ similarities = torch.matmul(p_i, p_i_star.T) / self.temperature
+
+ labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)
+
+ loss = F.cross_entropy(similarities, labels)
+
+ return loss
+
+ def forward(self, inputs):
+ latent_vector = self.encoder(inputs['embedding'])
+ item_ids = inputs['item_id']
+
+ latent_restored = 0
+ rqvae_loss = 0
+ clusters = []
+ remainder = latent_vector
+
+ for codebook in self.codebooks:
+ codebook_indices = self.get_codebook_indices(remainder, codebook)
+ clusters.append(codebook_indices)
+
+ quantized = codebook[codebook_indices]
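+            # Straight-through estimator: the forward pass uses the quantized vector,
+            # while gradients flow back to the un-quantized remainder.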
+ codebook_vectors = remainder + (quantized - remainder).detach()
+
+ rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
+ rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())
+
+ latent_restored += codebook_vectors
+ remainder = remainder - codebook_vectors
+
+ embeddings_restored = self.decoder(latent_restored)
+ recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])
+
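+        # Contrastive (PLUM) term: quantize the sampled co-occurring item's embedding with the
+        # same codebooks and pull the two quantized representations together (in-batch negatives).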
+ if 'cooccurrence_embedding' in inputs:
+ cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device))
+ cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
+ con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
+ else:
+ con_loss = torch.as_tensor(0.0, device=latent_vector.device)
+
+ loss = (
+ recon_loss
+ + self.quant_loss_weight * rqvae_loss
+ + self.contrastive_loss_weight * con_loss
+ ).mean()
+
+ clusters_counts = []
+ for cluster in clusters:
+ clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size))
+
+ return loss, {
+ 'loss': loss.item(),
+ 'recon_loss': recon_loss.mean().item(),
+ 'rqvae_loss': rqvae_loss.mean().item(),
+ 'con_loss': con_loss.item(),
+
+ 'clusters_counts': clusters_counts,
+ 'clusters': torch.stack(clusters).T,
+ 'embedding_hat': embeddings_restored,
+ }
\ No newline at end of file
diff --git a/scripts/plum-lsvd/plum_base_gap/create_base_gap_mapping_from_all.py b/scripts/plum-lsvd/plum_base_gap/create_base_gap_mapping_from_all.py
new file mode 100644
index 0000000..112c579
--- /dev/null
+++ b/scripts/plum-lsvd/plum_base_gap/create_base_gap_mapping_from_all.py
@@ -0,0 +1,52 @@
+import json
+import pandas as pd
+from pathlib import Path
+
+
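+# Restrict the full-catalog semantic-ID mapping to the items that appear in the train interactions.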
+ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless.json"
+TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/only_base_with_gap_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless_from_all.json"
+
+
+with open(ALL_MAPPING_PATH, 'r') as f:
+ all_mapping = json.load(f)
+print(f"Loaded {len(all_mapping)} items from all_mapping")
+
+train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH)
+
+train_item_ids = set()
+for item_ids_array in train_interactions['item_ids']:
+ train_item_ids.update(item_ids_array)
+
+print(f"Found {len(train_item_ids)} unique train items")
+
+train_mapping = {}
+missing_count = 0
+
+for item_id in train_item_ids:
+ item_id_str = str(item_id)
+ if item_id_str in all_mapping:
+ train_mapping[item_id_str] = all_mapping[item_id_str]
+ else:
+ missing_count += 1
+
+if missing_count > 0:
+ print(f"{missing_count} items from train not found in all_mapping")
+
+print(f"Created train_mapping with {len(train_mapping)} items")
+
+Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True)
+with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f:
+ json.dump(train_mapping, f, indent=2)
+print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}")
+
+print(f"all_mapping size: {len(all_mapping)}")
+print(f"train_mapping size: {len(train_mapping)}")
+print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}")
+
+sample_matches = 0
+for item_id_str in list(train_mapping.keys())[:100]:
+ if all_mapping[item_id_str] == train_mapping[item_id_str]:
+ sample_matches += 1
+
+print(f"Verified: {sample_matches}/100 sampled items have identical codes")
diff --git a/scripts/plum-lsvd/plum_base_gap/infer_plum_on_all_items.py b/scripts/plum-lsvd/plum_base_gap/infer_plum_on_all_items.py
new file mode 100644
index 0000000..d1473cc
--- /dev/null
+++ b/scripts/plum-lsvd/plum_base_gap/infer_plum_on_all_items.py
@@ -0,0 +1,148 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+INPUT_DIM = 64
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+K = 3000
+
+RQVAE_LOSS_WEIGHT = 1.0
+CON_LOSS_WEIGHT = 0.01
+
+EXPERIMENT_NAME = f'all_items_plum_vk-lsvd-15ts_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_e{NUM_EPOCHS}_con_{CON_LOSS_WEIGHT}_rqvae_{RQVAE_LOSS_WEIGHT}'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/items_metadata_remapped.parquet"
+IREC_PATH = '../../../'
+
+MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd/plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_best_0.008.pth"
+RESULTS_PATH = os.path.join(IREC_PATH, 'results-lsvd-2/base_gap')
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=RQVAE_LOSS_WEIGHT,
+ contrastive_loss_weight=CON_LOSS_WEIGHT,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ callbacks = [
+ cb.LoadModel(MODEL_PATH),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator(),
+ },
+ ),
+
+ cb.Logger().every_num_steps(len(dataloader)),
+
+ cb.InferenceSaver(
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
+ format='json'
+ )
+ ]
+
+    logger.debug('Everything is ready for the inference process!')
+
+ runner = InferenceRunner(
+ model=model,
+ dataset=dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+ import json
+ from collections import defaultdict
+ import numpy as np
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+ mappings = json.load(f)
+
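+    # Resolve collisions: items sharing the same semantic prefix receive distinct random extra tokens,
+    # then every code position is shifted into its own CODEBOOK_SIZE-wide id range.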
+ inter = {}
+ sem_2_ids = defaultdict(list)
+ for mapping in mappings:
+ item_id = mapping['item_id']
+ clusters = mapping['clusters']
+ inter[int(item_id)] = clusters
+ sem_2_ids[tuple(clusters)].append(int(item_id))
+
+ for semantics, items in sem_2_ids.items():
+ assert len(items) <= CODEBOOK_SIZE, str(len(items))
+ collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+ for item_id, collision_solver in zip(items, collision_solvers):
+ inter[item_id].append(collision_solver)
+ for i in range(len(inter[item_id])):
+ inter[item_id][i] += CODEBOOK_SIZE * i
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+ json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-lsvd/plum_base_gap/train_plum_base_gap.py b/scripts/plum-lsvd/plum_base_gap/train_plum_base_gap.py
new file mode 100644
index 0000000..3960337
--- /dev/null
+++ b/scripts/plum-lsvd/plum_base_gap/train_plum_base_gap.py
@@ -0,0 +1,186 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddingsVectorized
+from cooc_data import CoocMappingDataset
+
+# EXPERIMENT WITH TRUNCATED HISTORY
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+INPUT_DIM = 64
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+K = 3000
+
+RQVAE_LOSS_WEIGHT = 1.0
+CON_LOSS_WEIGHT = 0.01
+
+EXPERIMENT_NAME = f'plum_vk-lsvd-15ts_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_e{NUM_EPOCHS}_con_{CON_LOSS_WEIGHT}_rqvae_{RQVAE_LOSS_WEIGHT}'
+INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/only_base_with_gap_items_metadata_remapped.parquet"
+IREC_PATH = '../../../'  # three levels up to the repo root, as in the sibling scripts
+
+print(INTER_TRAIN_PATH)
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH,
+ )
+
+ data = CoocMappingDataset.create_from_split_part(
+ train_inter_parquet_path=INTER_TRAIN_PATH,
+ window_size=WINDOW_SIZE
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE)
+ all_item_ids.append(item_id)
+
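+    # The transform attaches a 'cooccurrence_embedding' to every batch item, sampled from the item's
+    # co-occurrence neighbors with probability proportional to the co-occurrence counts (top-K neighbors only).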
+ add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized(
+ cooccur_counts=data.cooccur_counter_mapping,
+ item_id_to_embedding=item_id_to_embedding,
+ all_item_ids=all_item_ids,
+ device=DEVICE,
+ max_neighbors=K,
+ seed=42
+ )
+
+    train_dataloader = DataLoader(  # TODO: the call happens in the main thread, needs to be fixed
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform
+ ).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform)
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=RQVAE_LOSS_WEIGHT,
+ contrastive_loss_weight=CON_LOSS_WEIGHT,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints-lsvd', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-lsvd/rqvae_base_gap/create_base_gap_mapping_from_all.py b/scripts/plum-lsvd/rqvae_base_gap/create_base_gap_mapping_from_all.py
new file mode 100644
index 0000000..a86f0c2
--- /dev/null
+++ b/scripts/plum-lsvd/rqvae_base_gap/create_base_gap_mapping_from_all.py
@@ -0,0 +1,52 @@
+import json
+import pandas as pd
+from pathlib import Path
+
+
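+# Restrict the full-catalog semantic-ID mapping to the items that appear in the train interactions.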
+ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless.json"
+TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/only_base_with_gap_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless_from_all.json"
+
+
+with open(ALL_MAPPING_PATH, 'r') as f:
+ all_mapping = json.load(f)
+print(f"Loaded {len(all_mapping)} items from all_mapping")
+
+train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH)
+
+train_item_ids = set()
+for item_ids_array in train_interactions['item_ids']:
+ train_item_ids.update(item_ids_array)
+
+print(f"Found {len(train_item_ids)} unique train items")
+
+train_mapping = {}
+missing_count = 0
+
+for item_id in train_item_ids:
+ item_id_str = str(item_id)
+ if item_id_str in all_mapping:
+ train_mapping[item_id_str] = all_mapping[item_id_str]
+ else:
+ missing_count += 1
+
+if missing_count > 0:
+ print(f"{missing_count} items from train not found in all_mapping")
+
+print(f"Created train_mapping with {len(train_mapping)} items")
+
+Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True)
+with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f:
+ json.dump(train_mapping, f, indent=2)
+print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}")
+
+print(f"all_mapping size: {len(all_mapping)}")
+print(f"train_mapping size: {len(train_mapping)}")
+print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}")
+
+sample_matches = 0
+for item_id_str in list(train_mapping.keys())[:100]:
+ if all_mapping[item_id_str] == train_mapping[item_id_str]:
+ sample_matches += 1
+
+print(f"Verified: {sample_matches}/100 sampled items have identical codes")
diff --git a/scripts/plum-lsvd/rqvae_base_gap/infer_rqvae.py b/scripts/plum-lsvd/rqvae_base_gap/infer_rqvae.py
new file mode 100644
index 0000000..3bd030d
--- /dev/null
+++ b/scripts/plum-lsvd/rqvae_base_gap/infer_rqvae.py
@@ -0,0 +1,158 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from collections import Counter
+from models import PlumRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+INPUT_DIM = 64
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+RQVAE_LOSS_WEIGHT = 1.0
+
+EXPERIMENT_NAME = f'all_items_rqvae_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/items_metadata_remapped.parquet"
+IREC_PATH = '../../../'
+
+MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd/rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_best_0.0073.pth"
+RESULTS_PATH = os.path.join(IREC_PATH, 'results-lsvd-2/base_gap')
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=RQVAE_LOSS_WEIGHT
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ callbacks = [
+ cb.LoadModel(MODEL_PATH),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator(),
+ },
+ ),
+
+ cb.Logger().every_num_steps(len(dataloader)),
+
+ cb.InferenceSaver(
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
+ format='json'
+ )
+ ]
+
+    logger.debug('Everything is ready for the inference process!')
+
+ runner = InferenceRunner(
+ model=model,
+ dataset=dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+ import json
+ from collections import defaultdict
+ import numpy as np
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+ mappings = json.load(f)
+
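+    # Resolve collisions: items sharing the same semantic prefix receive distinct random extra tokens,
+    # then every code position is shifted into its own CODEBOOK_SIZE-wide id range.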
+ inter = {}
+ sem_2_ids = defaultdict(list)
+ collision_stats = []
+ for mapping in mappings:
+ item_id = mapping['item_id']
+ clusters = mapping['clusters']
+ inter[int(item_id)] = clusters
+ sem_2_ids[tuple(clusters)].append(int(item_id))
+
+ for semantics, items in sem_2_ids.items():
+ assert len(items) <= CODEBOOK_SIZE, str(len(items))
+ collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+ for item_id, collision_solver in zip(items, collision_solvers):
+ inter[item_id].append(collision_solver)
+ collision_stats.append(collision_solver)
+ for i in range(len(inter[item_id])):
+ inter[item_id][i] += CODEBOOK_SIZE * i
+
+ if collision_stats:
+ max_col_tok = max(collision_stats)
+ avg_col_tok = np.mean(collision_stats)
+ collision_distribution = Counter(collision_stats)
+
+ print(f"Max collision token: {max_col_tok}")
+ print(f"Avg collision token: {avg_col_tok:.2f}")
+ print(f"Total items with collisions: {len(collision_stats)}")
+ print(f"Collision solver distribution: {dict(collision_distribution)}")
+ else:
+ print("No collisions detected")
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+ json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-lsvd/rqvae_base_gap/train_rqvae.py b/scripts/plum-lsvd/rqvae_base_gap/train_rqvae.py
new file mode 100644
index 0000000..50c0bce
--- /dev/null
+++ b/scripts/plum-lsvd/rqvae_base_gap/train_rqvae.py
@@ -0,0 +1,161 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+INPUT_DIM = 64
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+RQVAE_LOSS_WEIGHT = 1.0
+
+EXPERIMENT_NAME = f'rqvae_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/only_base_with_gap_items_metadata_remapped.parquet"
+IREC_PATH = '../../../'
+
+print(EMBEDDINGS_PATH)
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH,
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE)
+ all_item_ids.append(item_id)
+
+    train_dataloader = DataLoader(  # TODO: the call happens in the main thread, needs to be fixed
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ )
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=RQVAE_LOSS_WEIGHT
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints-lsvd', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-lsvd/transform_test.py b/scripts/plum-lsvd/transform_test.py
new file mode 100644
index 0000000..b977be8
--- /dev/null
+++ b/scripts/plum-lsvd/transform_test.py
@@ -0,0 +1,145 @@
+import torch
+import numpy as np
+import pytest
+from transforms import AddWeightedCooccurrenceEmbeddingsVectorized
+
+def test_add_weighted_cooccurrence_embeddings():
+ torch.manual_seed(42)
+ np.random.seed(42)
+    device = torch.device('cpu')  # test on CPU for simplicity
+
+ print("\n" + "="*80)
+ print("TEST 1: Normal case with cooccurrences")
+ print("="*80)
+
+    # Co-occurrence graph:
+    # 0 -> {1: 10, 2: 5} (1 is the more frequent neighbor)
+    # 1 -> {0: 10}        (only 0)
+    # 2 -> {}             (no neighbors)
+ cooccur_counts = {
+ 0: {1: 10, 2: 5},
+ 1: {0: 10},
+ 2: {},
+ }
+
+    # Embeddings (one-hot for clarity)
+ # 0: [1, 0, 0]
+ # 1: [0, 1, 0]
+ # 2: [0, 0, 1]
+ item_embeddings = {
+ 0: torch.tensor([1.0, 0.0, 0.0]),
+ 1: torch.tensor([0.0, 1.0, 0.0]),
+ 2: torch.tensor([0.0, 0.0, 1.0]),
+ }
+
+ all_item_ids = [0, 1, 2]
+
+    # Initialization
+ transform = AddWeightedCooccurrenceEmbeddingsVectorized(
+ cooccur_counts=cooccur_counts,
+ item_id_to_embedding=item_embeddings,
+ all_item_ids=all_item_ids,
+ device=device,
+        max_neighbors=2,  # limit to 2 to test truncation
+ seed=42
+ )
+
+ # ----------------------------------------------------------------
+    # Check 1: Item 1 must always pick neighbor 0 (probability 1.0)
+ # ----------------------------------------------------------------
+ batch_1 = {'item_id': torch.tensor([1, 1, 1], device=device)}
+ res_1 = transform(batch_1)
+    # Expect the embedding of item 0: [1.0, 0.0, 0.0]
+ expected_emb_0 = item_embeddings[0].to(device)
+
+ print("Checking deterministic neighbor (1 -> 0)...")
+ for emb in res_1['cooccurrence_embedding']:
+ assert torch.allclose(emb, expected_emb_0), \
+ f"Item 1 should strictly link to 0. Got {emb}"
+ print("✅ Deterministic neighbor check passed")
+
+ # ----------------------------------------------------------------
+    # Check 2: Item 2 (no neighbors) must produce a valid random embedding
+ # ----------------------------------------------------------------
+ batch_2 = {'item_id': torch.tensor([2] * 100, device=device)}
+ res_2 = transform(batch_2)
+ embs_2 = res_2['cooccurrence_embedding']
+
+ print("Checking fallback for item without neighbors...")
+    # Check that there are no NaNs
+ assert not torch.isnan(embs_2).any(), "NaN found in fallback embeddings"
+
+    # Check that real embeddings from the dictionary are returned
+ valid_embs_set = {tuple(e.tolist()) for e in item_embeddings.values()}
+    for emb in embs_2[:10]:  # check the first 10
+ assert tuple(emb.tolist()) in valid_embs_set, f"Invalid embedding generated: {emb}"
+ print("✅ Fallback check passed")
+
+ # ----------------------------------------------------------------
+    # Check 3: Probability distribution (Item 0 -> 1 (66%) vs 2 (33%))
+ # ----------------------------------------------------------------
+    # Run a large batch to gather statistics
+ batch_0 = {'item_id': torch.tensor([0] * 1000, device=device)}
+ res_0 = transform(batch_0)
+ embs_0 = res_0['cooccurrence_embedding']
+
+    # Count how often embedding 1 (neighbor with weight 10) and embedding 2 (neighbor with weight 5) were sampled
+ # Emb 1 = [0, 1, 0], Emb 2 = [0, 0, 1]
+ count_1 = (embs_0[:, 1] == 1.0).sum().item()
+ count_2 = (embs_0[:, 2] == 1.0).sum().item()
+
+ ratio = count_1 / (count_1 + count_2)
+ expected_ratio = 10 / 15 # ~0.666
+
+ print(f"Checking distribution for Item 0. Expected ~{expected_ratio:.2f}, Got {ratio:.2f}")
+ assert abs(ratio - expected_ratio) < 0.05, \
+ f"Distribution mismatch! Expected {expected_ratio:.2f}, got {ratio:.2f}"
+ print("✅ Distribution check passed")
+
+ print("\n" + "="*80)
+ print("TEST 2: Edge Cases")
+ print("="*80)
+
+ # ----------------------------------------------------------------
+    # Check 4: Empty batch
+ # ----------------------------------------------------------------
+ batch_empty = {'item_id': torch.tensor([], dtype=torch.long, device=device)}
+ res_empty = transform(batch_empty)
+ assert res_empty['cooccurrence_embedding'].shape[0] == 0
+ print("✅ Empty batch passed")
+
+ # ----------------------------------------------------------------
+    # Check 5: Item ID outside all_item_ids (e.g. a padding index or a new item)
+    # In the current searchsorted-based implementation the indices are clamped.
+    # Just check that the code does not crash.
+ # ----------------------------------------------------------------
+ unknown_id = 999
+ batch_unknown = {'item_id': torch.tensor([unknown_id], device=device)}
+
+ try:
+ res_unknown = transform(batch_unknown)
+ print("✅ Unknown item ID handled (no crash)")
+        # Ideally we should also check what was returned here (most likely the neighbor of the last item)
+ except Exception as e:
+ pytest.fail(f"Crashed on unknown item ID: {e}")
+
+ # ----------------------------------------------------------------
+ # Check 6: reproducibility (seed)
+ # ----------------------------------------------------------------
+ batch_seed = {'item_id': torch.tensor([0, 2, 0, 1] * 10, device=device)}
+ transform_a = AddWeightedCooccurrenceEmbeddingsVectorized(
+ cooccur_counts, item_embeddings, all_item_ids, device, seed=42
+ )
+ res_a = transform_a(batch_seed)['cooccurrence_embedding']
+ transform_b = AddWeightedCooccurrenceEmbeddingsVectorized(
+ cooccur_counts, item_embeddings, all_item_ids, device, seed=42
+ )
+ res_b = transform_b(batch_seed)['cooccurrence_embedding']
+
+ assert torch.allclose(res_a, res_b), "Results differ with same seed!"
+ print("✅ Seeding reproducibility passed")
+
+ print("\n✅ ALL TESTS PASSED!")
+
+if __name__ == "__main__":
+ test_add_weighted_cooccurrence_embeddings()
diff --git a/scripts/plum-lsvd/transforms.py b/scripts/plum-lsvd/transforms.py
new file mode 100644
index 0000000..bc8f812
--- /dev/null
+++ b/scripts/plum-lsvd/transforms.py
@@ -0,0 +1,111 @@
+from typing import Dict, List
+
+import torch
+
+class AddWeightedCooccurrenceEmbeddingsVectorized:
+ def __init__(
+ self,
+ cooccur_counts: Dict[int, Dict[int, int]],
+ item_id_to_embedding: Dict[int, torch.Tensor],
+ all_item_ids: List[int],
+ device: torch.device,
+ max_neighbors: int = 128,
+ seed: int = 42,
+ ):
+ self.device = device
+ self.max_neighbors = max_neighbors
+ torch.manual_seed(seed)
+
+ self.all_item_ids = torch.tensor(sorted(all_item_ids), dtype=torch.long, device=device)
+ self.num_items = len(self.all_item_ids)
+
+ # Embedding matrix
+ emb_dim = list(item_id_to_embedding.values())[0].shape[0]
+ self.embedding_matrix = torch.zeros((self.num_items, emb_dim), dtype=torch.float32, device=device)
+
+ id_to_idx = {iid.item(): i for i, iid in enumerate(self.all_item_ids.cpu())}
+
+ for iid, emb in item_id_to_embedding.items():
+ if iid in id_to_idx:
+ self.embedding_matrix[id_to_idx[iid]] = emb.to(device)
+
+
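+ # Fixed-size (num_items, max_neighbors) lookup tables: every row starts as a
+ # uniform distribution over random neighbors and is overwritten below for
+ # items that actually have co-occurrence statistics.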
+ self.neighbors = torch.randint(0, self.num_items, (self.num_items, max_neighbors), device=device)
+ self.probs = torch.full((self.num_items, max_neighbors), 1.0 / max_neighbors, device=device)
+
+ neighbors_cpu = self.neighbors.cpu()
+ probs_cpu = self.probs.cpu()
+
+ for item_id, neighbors_dict in cooccur_counts.items():
+ if item_id not in id_to_idx: continue
+ idx = id_to_idx[item_id]
+
+ if not neighbors_dict: continue
+
+ top_k = sorted(neighbors_dict.items(), key=lambda x: x[1], reverse=True)[:max_neighbors]
+ ids, counts = zip(*top_k)
+
+ valid_pairs = []
+ for n_id, c in zip(ids, counts):
+ n_idx = id_to_idx.get(n_id, -1)
+ if n_idx != -1:
+ valid_pairs.append((n_idx, c))
+
+ if not valid_pairs: continue
+
+ final_indices, final_counts = zip(*valid_pairs)
+ k_len = len(final_indices)
+
+ count_tensor = torch.tensor(final_counts, dtype=torch.float32)
+ prob_tensor = count_tensor / count_tensor.sum()
+
+ neighbors_cpu[idx, :k_len] = torch.tensor(final_indices, dtype=torch.long)
+ probs_cpu[idx, :k_len] = prob_tensor
+
+ if k_len < max_neighbors:
+ probs_cpu[idx, k_len:] = 0.0
+
+ self.neighbors = neighbors_cpu.to(device)
+ self.probs = probs_cpu.to(device)
+
+ def __call__(self, batch):
+ item_ids = batch['item_id'].to(self.device)
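+ # Binary-search each item id in the sorted id tensor; ids that are absent are
+ # flagged via found_mask and replaced with random embeddings further below.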
+
+ indices = torch.searchsorted(self.all_item_ids, item_ids)
+ indices = indices.clamp(max=self.num_items - 1)
+
+ found_mask = (self.all_item_ids[indices] == item_ids)
+
+ if not found_mask.all():
+ missing_count = (~found_mask).sum().item()
+ missing_examples = item_ids[~found_mask][:5].tolist()
+ print(f"[WARNING] Batch contains {missing_count} unknown items! Examples: {missing_examples}")
+ print(f" Assigning RANDOM embeddings for unknown items.")
+
+ # temporarily remap unknown items to index 0 so the gathers below stay valid
+ safe_indices = indices.clone()
+ safe_indices[~found_mask] = 0
+
+ batch_probs = self.probs[safe_indices] # (B, max_neighbors)
+ neighbor_local_indices = torch.multinomial(batch_probs, num_samples=1).squeeze(1) # (B)
+ selected_neighbor_indices = self.neighbors[safe_indices, neighbor_local_indices] # (B)
+ cooc_embeddings = self.embedding_matrix[selected_neighbor_indices]
+
+ if not found_mask.all():
+ # random item indices for items without co-occurrence history
+ random_indices = torch.randint(0, self.num_items, (item_ids.shape[0],), device=self.device)
+ random_embeddings = self.embedding_matrix[random_indices]
+
+ # replace embeddings of unknown items with the random fallbacks
+ cooc_embeddings = torch.where(
+ found_mask.unsqueeze(1),
+ cooc_embeddings,
+ random_embeddings
+ )
+
+ batch['cooccurrence_embedding'] = cooc_embeddings
+ return batch
diff --git a/scripts/sasrec-lsvd/base_gap/train_bg-v-t.py b/scripts/sasrec-lsvd/base_gap/train_bg-v-t.py
new file mode 100644
index 0000000..2861b2c
--- /dev/null
+++ b/scripts/sasrec-lsvd/base_gap/train_bg-v-t.py
@@ -0,0 +1,202 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.models import AutoCast
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import ArrowBatchDataset
+from models import SasRecModel
+
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+IREC_PATH = '../../../'
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer')
+
+EXPERIMENT_NAME = 'sasrec_vk_lsvd_base_gap_msl_20_e_300'
+
+NUM_EPOCHS = 10
+MAX_SEQ_LEN = 20
+# TRAIN_BATCH_SIZE = 256
+# VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 64
+NUM_HEADS = 2
+NUM_LAYERS = 2
+FEEDFORWARD_DIM = 256
+DROPOUT = 0.3
+LR = 1e-4
+
+NUM_ITEMS = 66000
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ train_dataloader = DataLoader(
+ ArrowBatchDataset(
+ TRAIN_BATCHES_DIR,
+ device='cpu',
+ preload=None
+ ),
+ batch_size=1,
+ shuffle=True,
+ num_workers=16,
+ prefetch_factor=16,
+ pin_memory=True,
+ persistent_workers=True,
+ collate_fn=Collate()
+ ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+ valid_dataloader = ArrowBatchDataset(
+ VALID_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ eval_dataloader = ArrowBatchDataset(
+ EVAL_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ model = SasRecModel(
+ num_items=NUM_ITEMS,
+ max_sequence_length=MAX_SEQ_LEN,
+ embedding_dim=EMBEDDING_DIM,
+ num_heads=NUM_HEADS,
+ num_layers=NUM_LAYERS,
+ dim_feedforward=FEEDFORWARD_DIM,
+ activation='relu',
+ topk_k=20,
+ dropout=DROPOUT,
+ layer_norm_eps=1e-8,
+ initializer_range=0.02
+ )
+ model = torch.compile(model, mode="default", fullgraph=False)
+ model = model.to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=LR,
+ )
+
+ EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS)
+
+ callbacks = [
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ }, name='train'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=EPOCH_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='validation'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'validation/loss': cb.MeanAccumulator(),
+ 'validation/recall@5': cb.MeanAccumulator(),
+ 'validation/recall@10': cb.MeanAccumulator(),
+ 'validation/recall@20': cb.MeanAccumulator(),
+ 'validation/ndcg@5': cb.MeanAccumulator(),
+ 'validation/ndcg@10': cb.MeanAccumulator(),
+ 'validation/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS),
+
+ cb.Validation(
+ dataset=eval_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='eval'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'eval/loss': cb.MeanAccumulator(),
+ 'eval/recall@5': cb.MeanAccumulator(),
+ 'eval/recall@10': cb.MeanAccumulator(),
+ 'eval/recall@20': cb.MeanAccumulator(),
+ 'eval/ndcg@5': cb.MeanAccumulator(),
+ 'eval/ndcg@10': cb.MeanAccumulator(),
+ 'eval/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS * 4),
+
+ cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+ cb.EarlyStopping(
+ metric='validation/ndcg@20',
+ patience=40 * 4,
+ minimize=False,
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+ ).every_num_steps(EPOCH_NUM_STEPS)
+
+ # cb.Profiler(
+ # wait=10,
+ # warmup=10,
+ # active=10,
+ # logdir=TENSORBOARD_LOGDIR
+ # ),
+ # cb.StopAfterNumSteps(40)
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/sasrec-lsvd/base_gap/varka_bg-v-t.py b/scripts/sasrec-lsvd/base_gap/varka_bg-v-t.py
new file mode 100644
index 0000000..8f131ad
--- /dev/null
+++ b/scripts/sasrec-lsvd/base_gap/varka_bg-v-t.py
@@ -0,0 +1,104 @@
+from collections import defaultdict
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+from irec.data.transforms import Collate
+from irec.data.dataloader import DataLoader
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import Dataset
+
+# PATHS
+IREC_PATH = '../../../'
+
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet"
+
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/eval_batches/')
+
+NUM_EPOCHS = 300
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+ length_groups[length][key] = value
+
+ for length, fields in length_groups.items():
+ arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+ table = pa.table(arrow_dict)
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+def main():
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='sasrec',
+ min_sample_len=2,
+ is_extended=False,
+ max_train_events=MAX_SEQ_LEN
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ).map(Collate()).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ).map(Collate())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ).map(Collate())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/sasrec-lsvd/data.py b/scripts/sasrec-lsvd/data.py
new file mode 100644
index 0000000..4ae0e18
--- /dev/null
+++ b/scripts/sasrec-lsvd/data.py
@@ -0,0 +1,359 @@
+from collections import defaultdict
+import json
+from loguru import logger
+import numpy as np
+from pathlib import Path
+
+import copy
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+import polars as pl
+from irec.data.base import BaseDataset
+
+
+class InteractionsDatasetParquet(BaseDataset):
+ def __init__(self, data_path, max_items=None):
+ self.df = pl.read_parquet(data_path)
+ assert 'uid' in self.df.columns, "Missing 'uid' column"
+ assert 'item_ids' in self.df.columns, "Missing 'item_ids' column"
+ print(f"Dataset loaded: {len(self.df)} users")
+
+ if max_items is not None:
+ self.df = self.df.with_columns(
+ pl.col("item_ids").list.slice(-max_items).alias("item_ids")
+ )
+
+ def __getitem__(self, idx):
+ row = self.df.row(idx, named=True)
+ return {
+ 'user_id': row['uid'],
+ 'item_ids': np.array(row['item_ids'], dtype=np.uint32),
+ }
+
+ def __len__(self):
+ return len(self.df)
+
+ def __iter__(self):
+ for idx in range(len(self)):
+ yield self[idx]
+
+
+class Dataset:
+ def __init__(
+ self,
+ train_sampler,
+ validation_sampler,
+ test_sampler,
+ num_items,
+ max_sequence_length
+ ):
+ self._train_sampler = train_sampler
+ self._validation_sampler = validation_sampler
+ self._test_sampler = test_sampler
+ self._num_items = num_items
+ self._max_sequence_length = max_sequence_length
+
+ @classmethod
+ def create_timestamp_based_parquet(
+ cls,
+ train_parquet_path,
+ validation_parquet_path,
+ test_parquet_path,
+ max_sequence_length,
+ sampler_type,
+ min_sample_len=2,
+ is_extended=False,
+ max_train_events=20
+ ):
+ """
+ Loads data from parquet files with a timestamp-based split.
+
+ Expected parquet structure:
+ - uid: int (user id)
+ - item_ids: list[int] (list of item ids)
+
+ Same as create_timestamp_based, but for the parquet format.
+ """
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+
+ print(f"started to load datasets from parquet with max train length {max_train_events}")
+
+ # Load the parquet files
+ train_df = pl.read_parquet(train_parquet_path)
+ validation_df = pl.read_parquet(validation_parquet_path)
+ test_df = pl.read_parquet(test_parquet_path)
+
+ # Check that the required columns are present
+ for df, name in [(train_df, "train"), (validation_df, "validation"), (test_df, "test")]:
+ assert 'uid' in df.columns, f"Missing 'uid' column in {name}"
+ assert 'item_ids' in df.columns, f"Missing 'item_ids' column in {name}"
+
+ # Build dictionaries for fast lookup
+ train_data = {str(row['uid']): row['item_ids'] for row in train_df.iter_rows(named=True)}
+ validation_data = {str(row['uid']): row['item_ids'] for row in validation_df.iter_rows(named=True)}
+ test_data = {str(row['uid']): row['item_ids'] for row in test_df.iter_rows(named=True)}
+
+ all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
+ print(f"all users count: {len(all_users)}")
+
+ us_count = 0
+ for user_id_str in all_users:
+ if us_count % 100 == 0:
+ print(f"user id {us_count}/{len(all_users)}: {user_id_str}")
+
+ user_id = int(user_id_str)
+
+ # Get the sequences for each split
+ train_items = list(train_data.get(user_id_str, []))
+ validation_items = list(validation_data.get(user_id_str, []))
+ test_items = list(test_data.get(user_id_str, []))
+
+ # Truncate train to the last max_train_events events
+ train_items = train_items[-max_train_events:]
+
+ full_sequence = train_items + validation_items + test_items
+ if full_sequence:
+ max_item_id = max(max_item_id, max(full_sequence))
+
+ if us_count % 100 == 0:
+ print(f"full sequence len: {len(full_sequence)}")
+
+ us_count += 1
+ if len(train_items) < 2:
+ print(f'Core-2 train dataset is used, user {user_id} has only {len(full_sequence)} items')
+ continue
+
+ if len(validation_items) < 1 or len(test_items) < 1:
+ print(f'Test or validation is empty, user {user_id} has only {len(full_sequence)} items')
+ continue
+
+
+ if is_extended:
+ for prefix_length in range(min_sample_len, len(train_items) + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(train_items[:prefix_length]),
+ })
+ else:
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(train_items),
+ })
+
+ # validation: expand each validation item into a separate sample
+ # Example: Train=[1,2], Valid=[3,4]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+
+ current_history = train_items
+ valid_small_history = 0
+ for item in validation_items:
+ current_history.append(item)
+
+ if len(current_history) >= min_sample_len:
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]),
+ })
+ else:
+ valid_small_history += 1
+
+
+ # expand each test item into a separate sample
+ # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+
+ test_small_history = 0
+ for item in test_items:
+ current_history.append(item)
+
+ if len(current_history) >= min_sample_len:
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]),
+ })
+ else:
+ test_small_history += 1
+
+
+ print(f"Train dataset size: {len(train_dataset)}")
+ print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}")
+ print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}")
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
+ def get_datasets(self):
+ return self._train_sampler, self._validation_sampler, self._test_sampler
+
+ @property
+ def num_items(self):
+ return self._num_items
+
+ @property
+ def max_sequence_length(self):
+ return self._max_sequence_length
+
+
+class TrainDataset(BaseDataset):
+ def __init__(self, dataset, prediction_type, max_sequence_length):
+ self._dataset = dataset
+ self._prediction_type = prediction_type
+ self._max_sequence_length = max_sequence_length
+
+ self._transforms = {
+ 'sasrec': self._all_items_transform,
+ 'tiger': self._last_item_transform
+ }
+
+ def _all_items_transform(self, sample):
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:]
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array(next_item_sequence, dtype=np.int64),
+ 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64)
+ }
+
+ def _last_item_transform(self, sample):
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ last_item = sample['item.ids'][-self._max_sequence_length:][-1]
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array([last_item], dtype=np.int64),
+ 'labels.length': np.array([1], dtype=np.int64),
+ }
+
+ def __getitem__(self, index):
+ return self._transforms[self._prediction_type](self._dataset[index])
+
+ def __len__(self):
+ return len(self._dataset)
+
+
+class EvalDataset(BaseDataset):
+ def __init__(self, dataset, max_sequence_length):
+ self._dataset = dataset
+ self._max_sequence_length = max_sequence_length
+
+ @property
+ def dataset(self):
+ return self._dataset
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def __getitem__(self, index):
+ sample = self._dataset[index]
+
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ next_item = sample['item.ids'][-self._max_sequence_length:][-1]
+
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array([next_item], dtype=np.int64),
+ 'labels.length': np.array([1], dtype=np.int64),
+ 'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64),
+ 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64),
+ }
+
+
+class ArrowBatchDataset(BaseDataset):
+ def __init__(self, batch_dir, device='cuda', preload=False):
+ self.batch_dir = Path(batch_dir)
+ self.device = device
+
+ all_files = list(self.batch_dir.glob('batch_*_len_*.arrow'))
+
+ batch_files_map = defaultdict(list)
+ for f in all_files:
+ batch_id = int(f.stem.split('_')[1])
+ batch_files_map[batch_id].append(f)
+
+ for batch_id in batch_files_map:
+ batch_files_map[batch_id].sort()
+
+ self.batch_indices = sorted(batch_files_map.keys())
+
+ if preload:
+ print(f"Preloading {len(self.batch_indices)} batches...")
+ self.cached_batches = []
+
+ for idx in range(len(self.batch_indices)):
+ batch = self._load_batch(batch_files_map[self.batch_indices[idx]])
+ self.cached_batches.append(batch)
+ else:
+ self.cached_batches = None
+ self.batch_files_map = batch_files_map
+
+ def _load_batch(self, arrow_files):
+ batch = {}
+
+ for arrow_file in arrow_files:
+ table = feather.read_table(arrow_file)
+ metadata = table.schema.metadata or {}
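+ # If the writer stored '<column>_shape' / '<column>_dtype' strings in the schema
+ # metadata, use them to restore the original dtype and multi-dimensional shape;
+ # otherwise the column is converted as-is.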
+
+ for col_name in table.column_names:
+ col = table.column(col_name)
+
+ shape_key = f'{col_name}_shape'
+ dtype_key = f'{col_name}_dtype'
+
+ if shape_key.encode() in metadata:
+ shape = eval(metadata[shape_key.encode()].decode())
+ dtype = np.dtype(metadata[dtype_key.encode()].decode())
+
+ # Check the column type
+ if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+ arr = np.array(col.to_pylist(), dtype=dtype)
+ else:
+ arr = col.to_numpy().reshape(shape).astype(dtype)
+ else:
+ if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+ arr = np.array(col.to_pylist())
+ else:
+ arr = col.to_numpy()
+
+ batch[col_name] = torch.from_numpy(arr.copy()).to(self.device)
+
+ return batch
+
+ def __len__(self):
+ return len(self.batch_indices)
+
+ def __getitem__(self, idx):
+ if self.cached_batches is not None:
+ return self.cached_batches[idx]
+ else:
+ batch_id = self.batch_indices[idx]
+ arrow_files = self.batch_files_map[batch_id]
+ return self._load_batch(arrow_files)
diff --git a/scripts/sasrec-lsvd/models.py b/scripts/sasrec-lsvd/models.py
new file mode 100644
index 0000000..0e1ea00
--- /dev/null
+++ b/scripts/sasrec-lsvd/models.py
@@ -0,0 +1,244 @@
+import torch
+import torch._dynamo
+import torch.nn as nn
+
+from irec.models import TorchModel, create_masked_tensor
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self,
+ embedding_dim,
+ layers,
+ dim_feedforward,
+ num_heads,
+ dropout,
+ activation,
+ causal,
+ prenorm=False
+ ):
+ super().__init__()
+ self.causal = causal
+ layer = nn.TransformerEncoderLayer(
+ d_model=embedding_dim,
+ nhead=num_heads,
+ dim_feedforward=dim_feedforward,
+ dropout=dropout,
+ activation=activation,
+ layer_norm_eps=1e-5,
+ batch_first=True,
+ norm_first=prenorm
+ )
+ self.encoder = nn.TransformerEncoder(layer, num_layers=layers)
+ self._init_weights(initializer_range=0.02)
+
+ @torch.no_grad()
+ def _init_weights(self, initializer_range: float) -> None:
+ for key, value in self.named_parameters():
+ if 'weight' in key:
+ if 'norm' in key:
+ nn.init.ones_(value.data)
+ else:
+ nn.init.trunc_normal_(
+ value.data,
+ std=initializer_range,
+ a=-2 * initializer_range,
+ b=2 * initializer_range
+ )
+ elif 'bias' in key:
+ nn.init.zeros_(value.data)
+ else:
+ raise ValueError(f'Unknown transformer weight: {key}')
+
+ def forward(self, embeddings, lengths, max_seqlen):
+ padded_embeddings, mask = create_masked_tensor(
+ data=embeddings,
+ lengths=lengths
+ )
+
+ if self.causal:
+ causal_mask = nn.Transformer.generate_square_subsequent_mask(
+ sz=mask.shape[-1],
+ device=mask.device,
+ dtype=mask.dtype
+ )
+ else:
+ causal_mask = None
+
+ padded_embeddings = self.encoder(
+ padded_embeddings,
+ mask=causal_mask,
+ src_key_padding_mask=~mask,
+ is_causal=self.causal
+ )
+ return padded_embeddings[mask]
+
+
+class SasRecModel(TorchModel):
+ def __init__(
+ self,
+ num_items,
+ max_sequence_length,
+ embedding_dim,
+ num_heads,
+ num_layers,
+ dim_feedforward,
+ activation,
+ topk_k,
+ dropout=0.0,
+ layer_norm_eps=1e-9,
+ initializer_range=0.02
+ ):
+ super().__init__()
+ self._num_items = num_items
+ self._num_heads = num_heads
+ self._embedding_dim = embedding_dim
+
+ self._item_embeddings = nn.Embedding(
+ num_embeddings=num_items,
+ embedding_dim=embedding_dim
+ )
+ self._position_embeddings = nn.Embedding(
+ num_embeddings=max_sequence_length,
+ embedding_dim=embedding_dim
+ )
+
+ self._topk_k = topk_k
+
+ self._encoder = TransformerEncoder(
+ embedding_dim=embedding_dim,
+ dim_feedforward=dim_feedforward,
+ layers=num_layers,
+ num_heads=num_heads,
+ dropout=dropout,
+ activation=activation,
+ causal=True,
+ )
+
+ self._init_weights(initializer_range)
+
+ def forward(self, inputs):
+ with torch._dynamo.config.patch(suppress_errors=True):
+ all_sample_events = inputs['item.ids'] # (total_batch_items)
+ all_sample_lengths = inputs['item.length'] # (batch_size)
+ max_seqlen = int(all_sample_lengths.max().item())
+
+ embeddings = self._item_embeddings(all_sample_events)
+
+ end_indices = all_sample_lengths.cumsum(dim=0) # (batch_size)
+ start_indices = end_indices - all_sample_lengths # (batch_size)
+
+ sample_indices = torch.arange(
+ all_sample_lengths.shape[0],
+ device=all_sample_lengths.device
+ ).repeat_interleave(all_sample_lengths) # (total_batch_items)
+
+ positions = torch.arange(
+ all_sample_events.shape[0],
+ device=all_sample_events.device
+ ) - start_indices[sample_indices] # (total_batch_items)
+
+ position_embeddings = self._position_embeddings(positions) # (total_batch_items, embedding_dim)
+
+ embeddings = embeddings + position_embeddings # (total_batch_items, embedding_dim)
+
+ all_sample_embeddings = self._encoder(embeddings=embeddings, lengths=all_sample_lengths, max_seqlen=max_seqlen) # (total_batch_items, embedding_dim)
+
+ all_positive_sample_events = inputs['labels.ids'] # (total_batch_items)
+
+ if not self.training:
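+ # At evaluation time only the representation of the last position of each
+ # sequence is scored against the full item catalogue.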
+ offsets = torch.cumsum(all_sample_lengths, dim=-1)
+ all_sample_embeddings = all_sample_embeddings[offsets - 1]
+
+ all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim)
+
+ # a -- total_batch_items, n -- num_items, d -- embedding_dim
+ all_scores = torch.einsum(
+ 'ad,nd->an',
+ all_sample_embeddings,
+ all_embeddings
+ ) # (total_batch_items, num_items)
+
+ positive_scores = torch.gather(
+ input=all_scores,
+ dim=1,
+ index=all_positive_sample_events[..., None]
+ )[:, 0] # (total_batch_items)
+
+ # Compute loss
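+ # One negative item id is sampled uniformly from the catalogue for each
+ # position; positive and negative scores feed a binary cross-entropy loss.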
+ negative_scores = torch.gather(
+ input=all_scores,
+ dim=1,
+ index=torch.randint(
+ low=0,
+ high=all_scores.shape[1],
+ size=all_positive_sample_events.shape,
+ device=all_positive_sample_events.device
+ )[..., None]
+ )[:, 0] # (total_batch_items)
+
+ with torch.autocast(device_type='cuda', enabled=False):
+ loss = self._compute_loss(
+ positive_scores.float(),
+ negative_scores.float()
+ )
+
+ metrics = {
+ 'loss': loss.detach()
+ }
+
+ if not self.training:
+ batch_size = all_sample_lengths.shape[0]
+ num_items = all_embeddings.shape[0]
+
+ padded_items, _ = create_masked_tensor(
+ data=all_sample_events,
+ lengths=all_sample_lengths,
+ ) # (batch_size, max_seq_len)
+
+ visited_mask = torch.zeros(
+ batch_size, num_items,
+ dtype=torch.bool,
+ device=all_sample_events.device
+ )
+
+ visited_mask.scatter_(
+ dim=1,
+ index=padded_items.long(),
+ value=True
+ )
+
+ all_scores = all_scores.masked_fill(visited_mask, float('-inf'))
+
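+ # positive_position is the zero-based rank of the target among non-visited
+ # items (number of items scoring strictly higher); recall@k checks rank < k,
+ # and the NDCG term applies the discount 1 / (log2(rank + 1) + 1).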
+ positive_position = (all_scores > positive_scores[:, None]).float().sum(dim=-1) # (batch_size or total_batch_items)
+ dcg_score = 1. / (torch.log2(positive_position + 1) + 1.)
+
+ for k in [5, 10, 20]:
+ metrics[f'recall@{k}'] = (positive_position < k).float()
+ metrics[f'ndcg@{k}'] = torch.where(
+ positive_position < k,
+ dcg_score,
+ torch.zeros_like(dcg_score)
+ ).float()
+
+ return loss, metrics
+
+ def _compute_loss(self, positive_scores, negative_scores):
+ assert positive_scores.shape[0] == negative_scores.shape[0]
+
+ loss = torch.nn.functional.binary_cross_entropy_with_logits(
+ positive_scores, torch.ones_like(positive_scores)
+ ) + torch.nn.functional.binary_cross_entropy_with_logits(
+ negative_scores, torch.zeros_like(negative_scores)
+ )
+
+ return loss
diff --git a/scripts/tiger-lsvd/data.py b/scripts/tiger-lsvd/data.py
new file mode 100644
index 0000000..4ae0e18
--- /dev/null
+++ b/scripts/tiger-lsvd/data.py
@@ -0,0 +1,359 @@
+from collections import defaultdict
+import json
+from loguru import logger
+import numpy as np
+from pathlib import Path
+
+import copy
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+import polars as pl
+from irec.data.base import BaseDataset
+
+
+class InteractionsDatasetParquet(BaseDataset):
+ def __init__(self, data_path, max_items=None):
+ self.df = pl.read_parquet(data_path)
+ assert 'uid' in self.df.columns, "Missing 'uid' column"
+ assert 'item_ids' in self.df.columns, "Missing 'item_ids' column"
+ print(f"Dataset loaded: {len(self.df)} users")
+
+ if max_items is not None:
+ self.df = self.df.with_columns(
+ pl.col("item_ids").list.slice(-max_items).alias("item_ids")
+ )
+
+ def __getitem__(self, idx):
+ row = self.df.row(idx, named=True)
+ return {
+ 'user_id': row['uid'],
+ 'item_ids': np.array(row['item_ids'], dtype=np.uint32),
+ }
+
+ def __len__(self):
+ return len(self.df)
+
+ def __iter__(self):
+ for idx in range(len(self)):
+ yield self[idx]
+
+
+class Dataset:
+ def __init__(
+ self,
+ train_sampler,
+ validation_sampler,
+ test_sampler,
+ num_items,
+ max_sequence_length
+ ):
+ self._train_sampler = train_sampler
+ self._validation_sampler = validation_sampler
+ self._test_sampler = test_sampler
+ self._num_items = num_items
+ self._max_sequence_length = max_sequence_length
+
+ @classmethod
+ def create_timestamp_based_parquet(
+ cls,
+ train_parquet_path,
+ validation_parquet_path,
+ test_parquet_path,
+ max_sequence_length,
+ sampler_type,
+ min_sample_len=2,
+ is_extended=False,
+ max_train_events=20
+ ):
+ """
+ Loads data from parquet files with a timestamp-based split.
+
+ Expected parquet structure:
+ - uid: int (user id)
+ - item_ids: list[int] (list of item ids)
+
+ Same as create_timestamp_based, but for the parquet format.
+ """
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+
+ print(f"started to load datasets from parquet with max train length {max_train_events}")
+
+ # Load the parquet files
+ train_df = pl.read_parquet(train_parquet_path)
+ validation_df = pl.read_parquet(validation_parquet_path)
+ test_df = pl.read_parquet(test_parquet_path)
+
+ # Check that the required columns are present
+ for df, name in [(train_df, "train"), (validation_df, "validation"), (test_df, "test")]:
+ assert 'uid' in df.columns, f"Missing 'uid' column in {name}"
+ assert 'item_ids' in df.columns, f"Missing 'item_ids' column in {name}"
+
+ # Build dictionaries for fast lookup
+ train_data = {str(row['uid']): row['item_ids'] for row in train_df.iter_rows(named=True)}
+ validation_data = {str(row['uid']): row['item_ids'] for row in validation_df.iter_rows(named=True)}
+ test_data = {str(row['uid']): row['item_ids'] for row in test_df.iter_rows(named=True)}
+
+ all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
+ print(f"all users count: {len(all_users)}")
+
+ us_count = 0
+ for user_id_str in all_users:
+ if us_count % 100 == 0:
+ print(f"user id {us_count}/{len(all_users)}: {user_id_str}")
+
+ user_id = int(user_id_str)
+
+ # Get the sequences for each split
+ train_items = list(train_data.get(user_id_str, []))
+ validation_items = list(validation_data.get(user_id_str, []))
+ test_items = list(test_data.get(user_id_str, []))
+
+ # Truncate train to the last max_train_events events
+ train_items = train_items[-max_train_events:]
+
+ full_sequence = train_items + validation_items + test_items
+ if full_sequence:
+ max_item_id = max(max_item_id, max(full_sequence))
+
+ if us_count % 100 == 0:
+ print(f"full sequence len: {len(full_sequence)}")
+
+ us_count += 1
+ if len(train_items) < 2:
+ print(f'Core-2 train dataset is used, user {user_id} has only {len(full_sequence)} items')
+ continue
+
+ if len(validation_items) < 1 or len(test_items) < 1:
+ print(f'Test or validation is empty, user {user_id} has only {len(full_sequence)} items')
+ continue
+
+
+ if is_extended:
+ for prefix_length in range(min_sample_len, len(train_items) + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(train_items[:prefix_length]),
+ })
+ else:
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(train_items),
+ })
+
+ # validation: expand each validation item into a separate sample
+ # Example: Train=[1,2], Valid=[3,4]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+
+ current_history = train_items
+ valid_small_history = 0
+ for item in validation_items:
+ current_history.append(item)
+
+ if len(current_history) >= min_sample_len:
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]),
+ })
+ else:
+ valid_small_history += 1
+
+
+ # expand each test item into a separate sample
+ # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+
+ test_small_history = 0
+ for item in test_items:
+ current_history.append(item)
+
+ if len(current_history) >= min_sample_len:
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]),
+ })
+ else:
+ test_small_history += 1
+
+
+ print(f"Train dataset size: {len(train_dataset)}")
+ print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}")
+ print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}")
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
+ def get_datasets(self):
+ return self._train_sampler, self._validation_sampler, self._test_sampler
+
+ @property
+ def num_items(self):
+ return self._num_items
+
+ @property
+ def max_sequence_length(self):
+ return self._max_sequence_length
+
+
+class TrainDataset(BaseDataset):
+ def __init__(self, dataset, prediction_type, max_sequence_length):
+ self._dataset = dataset
+ self._prediction_type = prediction_type
+ self._max_sequence_length = max_sequence_length
+
+ self._transforms = {
+ 'sasrec': self._all_items_transform,
+ 'tiger': self._last_item_transform
+ }
+
+ def _all_items_transform(self, sample):
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:]
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array(next_item_sequence, dtype=np.int64),
+ 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64)
+ }
+
+ def _last_item_transform(self, sample):
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ last_item = sample['item.ids'][-self._max_sequence_length:][-1]
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array([last_item], dtype=np.int64),
+ 'labels.length': np.array([1], dtype=np.int64),
+ }
+
+ def __getitem__(self, index):
+ return self._transforms[self._prediction_type](self._dataset[index])
+
+ def __len__(self):
+ return len(self._dataset)
+
+
+class EvalDataset(BaseDataset):
+ def __init__(self, dataset, max_sequence_length):
+ self._dataset = dataset
+ self._max_sequence_length = max_sequence_length
+
+ @property
+ def dataset(self):
+ return self._dataset
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def __getitem__(self, index):
+ sample = self._dataset[index]
+
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ next_item = sample['item.ids'][-self._max_sequence_length:][-1]
+
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array([next_item], dtype=np.int64),
+ 'labels.length': np.array([1], dtype=np.int64),
+ 'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64),
+ 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64),
+ }
+
+
+class ArrowBatchDataset(BaseDataset):
+ def __init__(self, batch_dir, device='cuda', preload=False):
+ self.batch_dir = Path(batch_dir)
+ self.device = device
+
+ all_files = list(self.batch_dir.glob('batch_*_len_*.arrow'))
+
+ batch_files_map = defaultdict(list)
+ for f in all_files:
+ batch_id = int(f.stem.split('_')[1])
+ batch_files_map[batch_id].append(f)
+
+ for batch_id in batch_files_map:
+ batch_files_map[batch_id].sort()
+
+ self.batch_indices = sorted(batch_files_map.keys())
+
+ if preload:
+ print(f"Preloading {len(self.batch_indices)} batches...")
+ self.cached_batches = []
+
+ for idx in range(len(self.batch_indices)):
+ batch = self._load_batch(batch_files_map[self.batch_indices[idx]])
+ self.cached_batches.append(batch)
+ else:
+ self.cached_batches = None
+ self.batch_files_map = batch_files_map
+
+ def _load_batch(self, arrow_files):
+ batch = {}
+
+ for arrow_file in arrow_files:
+ table = feather.read_table(arrow_file)
+ metadata = table.schema.metadata or {}
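+ # If the writer stored '<column>_shape' / '<column>_dtype' strings in the schema
+ # metadata, use them to restore the original dtype and multi-dimensional shape;
+ # otherwise the column is converted as-is.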
+
+ for col_name in table.column_names:
+ col = table.column(col_name)
+
+ shape_key = f'{col_name}_shape'
+ dtype_key = f'{col_name}_dtype'
+
+ if shape_key.encode() in metadata:
+ shape = eval(metadata[shape_key.encode()].decode())
+ dtype = np.dtype(metadata[dtype_key.encode()].decode())
+
+ # Check the column type
+ if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+ arr = np.array(col.to_pylist(), dtype=dtype)
+ else:
+ arr = col.to_numpy().reshape(shape).astype(dtype)
+ else:
+ if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+ arr = np.array(col.to_pylist())
+ else:
+ arr = col.to_numpy()
+
+ batch[col_name] = torch.from_numpy(arr.copy()).to(self.device)
+
+ return batch
+
+ def __len__(self):
+ return len(self.batch_indices)
+
+ def __getitem__(self, idx):
+ if self.cached_batches is not None:
+ return self.cached_batches[idx]
+ else:
+ batch_id = self.batch_indices[idx]
+ arrow_files = self.batch_files_map[batch_id]
+ return self._load_batch(arrow_files)
diff --git a/scripts/tiger-lsvd/letter_base_gap/lsvd_train_letter_base_gap.py b/scripts/tiger-lsvd/letter_base_gap/lsvd_train_letter_base_gap.py
new file mode 100644
index 0000000..475d1fe
--- /dev/null
+++ b/scripts/tiger-lsvd/letter_base_gap/lsvd_train_letter_base_gap.py
@@ -0,0 +1,226 @@
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+IREC_PATH = '../../../'
+TRAIN_PART_SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/only_base_with_gap_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless_from_all.json"
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer')
+
+EXPERIMENT_NAME = 'tiger_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 100
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 512
+NUM_POSITIONS = 80
+NUM_USER_HASH = 8000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 256
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ with open(TRAIN_PART_SEMANTIC_MAPPING_PATH, 'r') as f:
+ train_part_mapping = json.load(f)
+
+ train_dataloader = DataLoader(
+ ArrowBatchDataset(
+ TRAIN_BATCHES_DIR,
+ device='cpu',
+ preload=True
+ ),
+ batch_size=1,
+ shuffle=True,
+ num_workers=0,
+ pin_memory=True,
+ collate_fn=Collate()
+ ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+ valid_dataloader = ArrowBatchDataset(
+ VALID_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ eval_dataloader = ArrowBatchDataset(
+ EVAL_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ model = TigerModel(
+ embedding_dim=EMBEDDING_DIM,
+ codebook_size=CODEBOOK_SIZE,
+ sem_id_len=NUM_CODEBOOKS,
+ user_ids_count=NUM_USER_HASH,
+ num_positions=NUM_POSITIONS,
+ num_heads=NUM_HEADS,
+ num_encoder_layers=NUM_LAYERS,
+ num_decoder_layers=NUM_LAYERS,
+ dim_feedforward=FEEDFORWARD_DIM,
+ num_beams=NUM_BEAMS,
+ num_return_sequences=TOP_K,
+ activation='relu',
+ d_kv=KV_DIM,
+ dropout=DROPOUT,
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ logits_processor=CorrectItemsLogitsProcessor(NUM_CODEBOOKS, CODEBOOK_SIZE, train_part_mapping, NUM_BEAMS),
+ use_microbatching=USE_MICROBATCHING,
+ microbatch_size=MICROBATCH_SIZE
+ ).to(DEVICE)
+
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=LR,
+ )
+
+ EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS)
+
+ callbacks = [
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ }, name='train'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=EPOCH_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _:{
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='validation'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'validation/loss': cb.MeanAccumulator(),
+ 'validation/recall@5': cb.MeanAccumulator(),
+ 'validation/recall@10': cb.MeanAccumulator(),
+ 'validation/recall@20': cb.MeanAccumulator(),
+ 'validation/ndcg@5': cb.MeanAccumulator(),
+ 'validation/ndcg@10': cb.MeanAccumulator(),
+ 'validation/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS),
+
+ cb.Validation(
+ dataset=eval_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='eval'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'eval/loss': cb.MeanAccumulator(),
+ 'eval/recall@5': cb.MeanAccumulator(),
+ 'eval/recall@10': cb.MeanAccumulator(),
+ 'eval/recall@20': cb.MeanAccumulator(),
+ 'eval/ndcg@5': cb.MeanAccumulator(),
+ 'eval/ndcg@10': cb.MeanAccumulator(),
+ 'eval/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS * 4),
+
+ cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+ cb.EarlyStopping(
+ metric='validation/ndcg@20',
+ patience=40,
+ minimize=False,
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+ ).every_num_steps(EPOCH_NUM_STEPS)
+
+ # cb.Profiler(
+ # wait=10,
+ # warmup=10,
+ # active=10,
+ # logdir=TENSORBOARD_LOGDIR
+ # ),
+ # cb.StopAfterNumSteps(40)
+
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/letter_base_gap/lsvd_varka_letter_base_gap_on_sem_all_items.py b/scripts/tiger-lsvd/letter_base_gap/lsvd_varka_letter_base_gap_on_sem_all_items.py
new file mode 100644
index 0000000..1a5c68f
--- /dev/null
+++ b/scripts/tiger-lsvd/letter_base_gap/lsvd_varka_letter_base_gap_on_sem_all_items.py
@@ -0,0 +1,305 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import Dataset
+
+print("tiger arrow varka")
+
+# PATHS
+
+IREC_PATH = '../../../'
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet"
+
+SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless.json"
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/eval_batches/')
+
+# OTHER SETTINGS
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 8000
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
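+# Token id layout: ids [0, CODEBOOK_SIZE * NUM_CODEBOOKS) are semantic-id codes,
+# the next NUM_USER_HASH ids are hashed-user tokens, and the top 10 ids are
+# reserved for utility tokens (PAD / EOS / decoder start defined below).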
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
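+# TigerProcessing assembles the encoder/decoder sequences: it appends the
+# hashed-user token to the padded semantic-id inputs (extending the attention
+# mask accordingly) and prepends the decoder start token to the target ids.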
+class TigerProcessing(Transform):
+ def __call__(self, batch):
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # padded positions get the PAD token id (TODO: verify this is needed)
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
+ max_item_id = max(int(k) for k in mapping.keys())
+ print(f"mapping size: {len(mapping)}, min item id: {min(int(k) for k in mapping.keys())}, max item id: {max_item_id}")
+ data = []
+ for i in range(max_item_id + 1):
+ if str(i) in mapping:
+ data.append(mapping[str(i)])
+ else:
+ data.append([-1] * NUM_CODEBOOKS)
+
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ missing_count = (max_item_id + 1) - len(mapping)
+ print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ semantic_ids = self._mapping_tensor[ids].flatten()
+
+ assert (semantic_ids != -1).all(), \
+ f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+ batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+ # 1D array - store as is
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+ # 2D array - store as a list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+ # >2D array - flatten; the original shape is kept in the metadata
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ arrow_dict = {}
+ for k, v in fields.items():
+ # pa.array handles both flat arrays and lists of lists (2D)
+ arrow_dict[k] = pa.array(v)
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+
+def main():
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+ print("варка может начать умирать")
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True,
+ max_train_events=MAX_SEQ_LEN
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+ print("варка не умерла")
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/models.py b/scripts/tiger-lsvd/models.py
new file mode 100644
index 0000000..03c9b74
--- /dev/null
+++ b/scripts/tiger-lsvd/models.py
@@ -0,0 +1,227 @@
+import torch
+from transformers import T5ForConditionalGeneration, T5Config, LogitsProcessor
+
+from irec.models import TorchModel
+
+
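+# Constrained decoding helper: at each generation step only token ids from the codebook
+# being generated next are allowed, and among those only codes whose prefix matches an
+# item present in `mapping`, so beams can never produce a semantic id that does not
+# correspond to a real item.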
+class CorrectItemsLogitsProcessor(LogitsProcessor):
+ def __init__(self, num_codebooks, codebook_size, mapping, num_beams):
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.num_beams = num_beams
+
+ semantic_ids = []
+ for codes in mapping.values():
+ assert len(codes) == num_codebooks, 'All semantic ids must have the same length'
+ semantic_ids.append(codes)
+
+ print(f"semantic ids count (allowed for generation): {len(semantic_ids)}")
+
+ self.index_semantic_ids = torch.tensor(semantic_ids, dtype=torch.long, device='cuda') # (num_items, semantic_ids)
+
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ if (input_ids == -1).any():
+ raise AssertionError(f"LogitsProcessor received -1 in input_ids. \nIndices: {torch.nonzero(input_ids == -1)}")
+
+ batch_size = input_ids.shape[0]//self.num_beams
+
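+ # Which codebook the next token belongs to: take the codebook index of the last
+ # generated token (clamped so the out-of-range decoder-start token maps to the last
+ # codebook), advance by one and wrap around after the final codebook.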
+ next_sid_codebook_num = (torch.minimum((input_ids[:, -1].max() // self.codebook_size), torch.as_tensor(self.num_codebooks - 1)).item() + 1) % self.num_codebooks
+ a = torch.tile(self.index_semantic_ids[None, None, :, next_sid_codebook_num], dims=[batch_size, self.num_beams, 1]) # (batch_size, num_beams, num_items)
+ a = a.reshape(a.shape[0] * a.shape[1], a.shape[2]) # (batch_size * num_beams, num_items)
+
+ if next_sid_codebook_num != 0:
+ b = torch.tile(self.index_semantic_ids[None, None, :, :next_sid_codebook_num], dims=[batch_size, self.num_beams, 1, 1]) # (batch_size, num_beams, num_items, sid_len)
+ b = b.reshape(b.shape[0] * b.shape[1], b.shape[2], b.shape[3]) # (batch_size * num_beams, num_items, sid_len)
+
+ current_prefixes = input_ids[:, -next_sid_codebook_num:] # (batch_size * num_beams, sid_len)
+ possible_next_items_mask = (
+ torch.eq(current_prefixes[:, None, :], b).long().sum(dim=-1) == next_sid_codebook_num
+ ) # (batch_size * num_beams, num_items)
+ a[~possible_next_items_mask] = (next_sid_codebook_num + 1) * self.codebook_size
+
+ scores_mask = torch.zeros_like(scores).bool() # (batch_size * num_beams, num_items)
+ scores_mask = torch.scatter_add(
+ input=scores_mask,
+ dim=-1,
+ index=a,
+ src=torch.ones_like(a).bool()
+ )
+
+ scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf
+ scores[:, (next_sid_codebook_num + 1) * self.codebook_size:] = -torch.inf
+ scores[~(scores_mask.bool())] = -torch.inf
+
+ return scores
+
+
+
+class TigerModel(TorchModel):
+ def __init__(
+ self,
+ embedding_dim,
+ codebook_size,
+ sem_id_len,
+ num_positions,
+ user_ids_count,
+ num_heads,
+ num_encoder_layers,
+ num_decoder_layers,
+ dim_feedforward,
+ num_beams=100,
+ num_return_sequences=20,
+ d_kv=64,
+ layer_norm_eps=1e-6,
+ activation='relu',
+ dropout=0.1,
+ initializer_range=0.02,
+ logits_processor=None,
+ use_microbatching=False,
+ microbatch_size=128
+ ):
+ super().__init__()
+ self._embedding_dim = embedding_dim
+ self._codebook_size = codebook_size
+ self._num_positions = num_positions
+ self._num_heads = num_heads
+ self._num_encoder_layers = num_encoder_layers
+ self._num_decoder_layers = num_decoder_layers
+ self._dim_feedforward = dim_feedforward
+ self._num_beams = num_beams
+ self._num_return_sequences = num_return_sequences
+ self._d_kv = d_kv
+ self._layer_norm_eps = layer_norm_eps
+ self._activation = activation
+ self._dropout = dropout
+ self._sem_id_len = sem_id_len
+ self.user_ids_count = user_ids_count
+ self.logits_processor = logits_processor
+ self._use_microbatching = use_microbatching
+ self._microbatch_size = microbatch_size
+
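+ # Unified vocabulary layout: ids [0, codebook_size * sem_id_len) are codebook tokens,
+ # the next user_ids_count ids are hashed user tokens, and the last 10 ids are reserved
+ # for special tokens (pad, eos, decoder start).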
+ unified_vocab_size = codebook_size * self._sem_id_len + self.user_ids_count + 10 # 10 for utilities
+ self.config = T5Config(
+ vocab_size=unified_vocab_size,
+ d_model=self._embedding_dim,
+ d_kv=self._d_kv,
+ d_ff=self._dim_feedforward,
+ num_layers=self._num_encoder_layers,
+ num_decoder_layers=self._num_decoder_layers,
+ num_heads=self._num_heads,
+ dropout_rate=self._dropout,
+ is_encoder_decoder=True,
+ use_cache=False,
+ pad_token_id=unified_vocab_size - 1,
+ eos_token_id=unified_vocab_size - 2,
+ decoder_start_token_id=unified_vocab_size - 3,
+ layer_norm_epsilon=self._layer_norm_eps,
+ feed_forward_proj=self._activation,
+ tie_word_embeddings=False
+ )
+ self.model = T5ForConditionalGeneration(config=self.config)
+ self._init_weights(initializer_range)
+
+ self.model = torch.compile(
+ self.model,
+ mode='reduce-overhead',
+ fullgraph=False,
+ dynamic=True
+ )
+
+ def forward(self, inputs):
+ input_semantic_ids = inputs['input.data']
+ attention_mask = inputs['input.mask']
+ target_semantic_ids = inputs['output.data']
+
+ assert (input_semantic_ids != -1).all(), \
+ f"Found -1 in inputs['input.data']. Check your DataLoader/Collator."
+ assert (target_semantic_ids != -1).all(), \
+ f"Found -1 in inputs['output.data']. Check your DataLoader/Collator."
+
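+ # Teacher forcing: the decoder input is the target shifted right (it starts with the
+ # decoder-start token), and the labels are the target shifted left by one position.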
+ decoder_input_ids = target_semantic_ids[:, :-1].contiguous()
+ labels = target_semantic_ids[:, 1:].contiguous()
+
+ model_output = self.model(
+ input_ids=input_semantic_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ labels=labels
+ )
+ loss = model_output['loss']
+
+ metrics = {'loss': loss.detach()}
+
+ if not self.training and not self._use_microbatching:
+ # visited_batch = inputs['visited.padded']
+
+ output = self.model.generate(
+ input_ids=input_semantic_ids,
+ attention_mask=attention_mask,
+ num_beams=self._num_beams,
+ num_return_sequences=self._num_return_sequences,
+ max_length=self._sem_id_len + 1,
+ decoder_start_token_id=self.config.decoder_start_token_id,
+ eos_token_id=self.config.eos_token_id,
+ pad_token_id=self.config.pad_token_id,
+ do_sample=False,
+ early_stopping=False,
+ logits_processor=[self.logits_processor] if self.logits_processor is not None else [], # TODO: try generating without masking
+ )
+
+ assert (output != -1).all(), "Model.generate returned -1 in raw output"
+
+ predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
+
+ all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1)) # (batch_size, top_k)
+ elif not self.training and self._use_microbatching:
+ # visited_batch = inputs['visited.padded']
+ batch_size = input_semantic_ids.shape[0]
+
+ inference_batch_size = self._microbatch_size # generate in chunks instead of the full batch
+
+ all_predictions = []
+ all_labels = []
+ # print(f"start to infer batch of shape {input_semantic_ids.shape} with new batch {inference_batch_size}")
+ for batch_idx in range(0, batch_size, inference_batch_size):
+ batch_end = min(batch_idx + inference_batch_size, batch_size)
+ batch_slice = slice(batch_idx, batch_end)
+
+ input_ids_batch = input_semantic_ids[batch_slice]
+ attention_mask_batch = attention_mask[batch_slice]
+ # visited_batch_subset = visited_batch[batch_slice]
+ labels_batch = labels[batch_slice]
+
+ with torch.inference_mode():
+ output = self.model.generate(
+ input_ids=input_ids_batch,
+ attention_mask=attention_mask_batch,
+ num_beams=self._num_beams,
+ num_return_sequences=self._num_return_sequences,
+ max_length=self._sem_id_len + 1,
+ decoder_start_token_id=self.config.decoder_start_token_id,
+ eos_token_id=self.config.eos_token_id,
+ pad_token_id=self.config.pad_token_id,
+ do_sample=False,
+ early_stopping=False,
+ logits_processor=[self.logits_processor] if self.logits_processor is not None else [],
+ )
+
+ predictions_batch = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
+ all_predictions.append(predictions_batch)
+ all_labels.append(labels_batch)
+ # print("end infer of batch")
+
+ predictions = torch.cat(all_predictions, dim=0) # (batch_size, num_return_sequences, sem_id_len)
+ labels_full = torch.cat(all_labels, dim=0) # (batch_size, sem_id_len)
+ all_hits = (torch.eq(predictions, labels_full[:, None]).sum(dim=-1)) # (batch_size, top_k)
+
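+ # all_hits[i, j] is the number of semantic-id positions of prediction j that match
+ # the target for sample i; a prediction counts as a hit only when every position
+ # matches (all_hits == sem_id_len), which recall@k and ndcg@k below rely on.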
+ if not self.training:
+ for k in [5, 10, 20]:
+ hits = (all_hits[:, :k] == self._sem_id_len).float() # (batch_size, k)
+ recall = hits.sum(dim=-1) # (batch_size)
+ discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device) # (k)
+
+ metrics[f'recall@{k}'] = recall.cpu().float()
+ metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float()
+
+ return loss, metrics
\ No newline at end of file
diff --git a/scripts/tiger-lsvd/plum_base_gap/lsvd_train_plum_base_gap.py b/scripts/tiger-lsvd/plum_base_gap/lsvd_train_plum_base_gap.py
new file mode 100644
index 0000000..4d3479d
--- /dev/null
+++ b/scripts/tiger-lsvd/plum_base_gap/lsvd_train_plum_base_gap.py
@@ -0,0 +1,226 @@
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+IREC_PATH = '../../../'
+TRAIN_PART_SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/plum/only_base_with_gap_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless_from_all.json"
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer')
+
+EXPERIMENT_NAME = 'TEST_tiger_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 200
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 512
+NUM_POSITIONS = 80
+NUM_USER_HASH = 8000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
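+# Beam-search generation is memory-hungry at validation time; with microbatching the
+# model generates predictions in chunks of MICROBATCH_SIZE instead of the whole batch.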
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 256
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ with open(TRAIN_PART_SEMANTIC_MAPPING_PATH, 'r') as f:
+ train_part_mapping = json.load(f)
+
+ train_dataloader = DataLoader(
+ ArrowBatchDataset(
+ TRAIN_BATCHES_DIR,
+ device='cpu',
+ preload=True
+ ),
+ batch_size=1,
+ shuffle=True,
+ num_workers=0,
+ pin_memory=True,
+ collate_fn=Collate()
+ ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+ valid_dataloader = ArrowBatchDataset(
+ VALID_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ eval_dataloader = ArrowBatchDataset(
+ EVAL_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ model = TigerModel(
+ embedding_dim=EMBEDDING_DIM,
+ codebook_size=CODEBOOK_SIZE,
+ sem_id_len=NUM_CODEBOOKS,
+ user_ids_count=NUM_USER_HASH,
+ num_positions=NUM_POSITIONS,
+ num_heads=NUM_HEADS,
+ num_encoder_layers=NUM_LAYERS,
+ num_decoder_layers=NUM_LAYERS,
+ dim_feedforward=FEEDFORWARD_DIM,
+ num_beams=NUM_BEAMS,
+ num_return_sequences=TOP_K,
+ activation='relu',
+ d_kv=KV_DIM,
+ dropout=DROPOUT,
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ logits_processor=CorrectItemsLogitsProcessor(NUM_CODEBOOKS, CODEBOOK_SIZE, train_part_mapping, NUM_BEAMS),
+ use_microbatching=USE_MICROBATCHING,
+ microbatch_size=MICROBATCH_SIZE
+ ).to(DEVICE)
+
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=LR,
+ )
+
+ EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS)
+
+ callbacks = [
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ }, name='train'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=EPOCH_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _:{
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='validation'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'validation/loss': cb.MeanAccumulator(),
+ 'validation/recall@5': cb.MeanAccumulator(),
+ 'validation/recall@10': cb.MeanAccumulator(),
+ 'validation/recall@20': cb.MeanAccumulator(),
+ 'validation/ndcg@5': cb.MeanAccumulator(),
+ 'validation/ndcg@10': cb.MeanAccumulator(),
+ 'validation/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS * 4),
+
+ cb.Validation(
+ dataset=eval_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='eval'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'eval/loss': cb.MeanAccumulator(),
+ 'eval/recall@5': cb.MeanAccumulator(),
+ 'eval/recall@10': cb.MeanAccumulator(),
+ 'eval/recall@20': cb.MeanAccumulator(),
+ 'eval/ndcg@5': cb.MeanAccumulator(),
+ 'eval/ndcg@10': cb.MeanAccumulator(),
+ 'eval/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS * 4),
+
+ cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+ cb.EarlyStopping(
+ metric='validation/ndcg@20',
+ patience=12,
+ minimize=False,
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+ ).every_num_steps(EPOCH_NUM_STEPS)
+
+ # cb.Profiler(
+ # wait=10,
+ # warmup=10,
+ # active=10,
+ # logdir=TENSORBOARD_LOGDIR
+ # ),
+ # cb.StopAfterNumSteps(40)
+
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/plum_base_gap/lsvd_varka_plum_base_gap_on_sem_all_items.py b/scripts/tiger-lsvd/plum_base_gap/lsvd_varka_plum_base_gap_on_sem_all_items.py
new file mode 100644
index 0000000..120d343
--- /dev/null
+++ b/scripts/tiger-lsvd/plum_base_gap/lsvd_varka_plum_base_gap_on_sem_all_items.py
@@ -0,0 +1,305 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import Dataset
+
+print("tiger arrow varka")
+
+# PATHS
+
+IREC_PATH = '../../../'
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet"
+
+SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless.json"
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/eval_batches/')
+
+# OTHER SETTINGS
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 8000
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+class TigerProcessing(Transform):
+ def __call__(self, batch):
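+ # Build the encoder input: right-aligned item semantic ids with PAD written into the
+ # masked positions, followed by a single hashed-user token; the decoder target gets
+ # the decoder-start token prepended below.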
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
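+ # Valid-position mask; for right alignment the flip pushes padding to the left so
+ # the most recent events end up at the end of each row.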
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
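+ # Dense lookup table: row i holds the NUM_CODEBOOKS semantic codes of item i,
+ # with -1 placeholders for items that are absent from the mapping.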
+ max_item_id = max(int(k) for k in mapping.keys())
+ print(f"mapping: {len(mapping)} items, id range [{min(int(k) for k in mapping.keys())}, {max_item_id}]")
+ data = []
+ for i in range(max_item_id + 1):
+ if str(i) in mapping:
+ data.append(mapping[str(i)])
+ else:
+ data.append([-1] * NUM_CODEBOOKS)
+
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ missing_count = (max_item_id + 1) - len(mapping)
+ print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ semantic_ids = self._mapping_tensor[ids].flatten()
+
+ assert (semantic_ids != -1).all(), \
+ f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+ batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+ # 1D array: store as-is
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+ # 2D array: store as a list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+ # >2D array: flatten; the original shape is kept in the metadata
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ # pa.array handles both flat arrays and lists of lists directly
+ arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+
+def main():
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+ print("варка может начать умирать")
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True,
+ max_train_events=MAX_SEQ_LEN
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+ print("варка не умерла")
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/rqvae_base_gap/lsvd_train_rqvae_base_gap.py b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_train_rqvae_base_gap.py
new file mode 100644
index 0000000..a1bccf5
--- /dev/null
+++ b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_train_rqvae_base_gap.py
@@ -0,0 +1,226 @@
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+IREC_PATH = '../../../'
+TRAIN_PART_SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/only_base_with_gap_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless_from_all.json"
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer')
+
+EXPERIMENT_NAME = 'tiger_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 100
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 512
+NUM_POSITIONS = 80
+NUM_USER_HASH = 8000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 256
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ with open(TRAIN_PART_SEMANTIC_MAPPING_PATH, 'r') as f:
+ train_part_mapping = json.load(f)
+
+ train_dataloader = DataLoader(
+ ArrowBatchDataset(
+ TRAIN_BATCHES_DIR,
+ device='cpu',
+ preload=True
+ ),
+ batch_size=1,
+ shuffle=True,
+ num_workers=0,
+ pin_memory=True,
+ collate_fn=Collate()
+ ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+ valid_dataloader = ArrowBatchDataset(
+ VALID_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ eval_dataloader = ArrowBatchDataset(
+ EVAL_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ model = TigerModel(
+ embedding_dim=EMBEDDING_DIM,
+ codebook_size=CODEBOOK_SIZE,
+ sem_id_len=NUM_CODEBOOKS,
+ user_ids_count=NUM_USER_HASH,
+ num_positions=NUM_POSITIONS,
+ num_heads=NUM_HEADS,
+ num_encoder_layers=NUM_LAYERS,
+ num_decoder_layers=NUM_LAYERS,
+ dim_feedforward=FEEDFORWARD_DIM,
+ num_beams=NUM_BEAMS,
+ num_return_sequences=TOP_K,
+ activation='relu',
+ d_kv=KV_DIM,
+ dropout=DROPOUT,
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ logits_processor=CorrectItemsLogitsProcessor(NUM_CODEBOOKS, CODEBOOK_SIZE, train_part_mapping, NUM_BEAMS),
+ use_microbatching=USE_MICROBATCHING,
+ microbatch_size=MICROBATCH_SIZE
+ ).to(DEVICE)
+
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=LR,
+ )
+
+ EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS)
+
+ callbacks = [
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ }, name='train'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=EPOCH_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _:{
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='validation'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'validation/loss': cb.MeanAccumulator(),
+ 'validation/recall@5': cb.MeanAccumulator(),
+ 'validation/recall@10': cb.MeanAccumulator(),
+ 'validation/recall@20': cb.MeanAccumulator(),
+ 'validation/ndcg@5': cb.MeanAccumulator(),
+ 'validation/ndcg@10': cb.MeanAccumulator(),
+ 'validation/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS),
+
+ cb.Validation(
+ dataset=eval_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='eval'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'eval/loss': cb.MeanAccumulator(),
+ 'eval/recall@5': cb.MeanAccumulator(),
+ 'eval/recall@10': cb.MeanAccumulator(),
+ 'eval/recall@20': cb.MeanAccumulator(),
+ 'eval/ndcg@5': cb.MeanAccumulator(),
+ 'eval/ndcg@10': cb.MeanAccumulator(),
+ 'eval/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS * 4),
+
+ cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+ cb.EarlyStopping(
+ metric='validation/ndcg@20',
+ patience=40 * 4,
+ minimize=False,
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+ ).every_num_steps(EPOCH_NUM_STEPS)
+
+ # cb.Profiler(
+ # wait=10,
+ # warmup=10,
+ # active=10,
+ # logdir=TENSORBOARD_LOGDIR
+ # ),
+ # cb.StopAfterNumSteps(40)
+
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/rqvae_base_gap/lsvd_varka_rqvae_base_gap_on_sem_all_items.py b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_varka_rqvae_base_gap_on_sem_all_items.py
new file mode 100644
index 0000000..3b06988
--- /dev/null
+++ b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_varka_rqvae_base_gap_on_sem_all_items.py
@@ -0,0 +1,305 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+from data import Dataset
+
+print("tiger arrow varka")
+
+# PATHS
+
+IREC_PATH = '../../../'
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet"
+
+SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless.json"
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/eval_batches/')
+
+# OTHER SETTINGS
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 8000
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+class TigerProcessing(Transform):
+ def __call__(self, batch):
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
+ max_item_id = max(int(k) for k in mapping.keys())
+ print(f"mapping: {len(mapping)} items, id range [{min(int(k) for k in mapping.keys())}, {max_item_id}]")
+ data = []
+ for i in range(max_item_id + 1):
+ if str(i) in mapping:
+ data.append(mapping[str(i)])
+ else:
+ data.append([-1] * NUM_CODEBOOKS)
+
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ missing_count = (max_item_id + 1) - len(mapping)
+ print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ semantic_ids = self._mapping_tensor[ids].flatten()
+
+ assert (semantic_ids != -1).all(), \
+ f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+ batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+ # 1D array: store as-is
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+ # 2D array: store as a list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+ # >2D array: flatten; the original shape is kept in the metadata
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ # pa.array handles both flat arrays and lists of lists directly
+ arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+
+def main():
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+ print("варка может начать умирать")
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True,
+ max_train_events=MAX_SEQ_LEN
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+ print("варка не умерла")
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sigir/lsvd_processing/LsvdDownload2.ipynb b/sigir/lsvd_processing/LsvdDownload2.ipynb
new file mode 100644
index 0000000..d121678
--- /dev/null
+++ b/sigir/lsvd_processing/LsvdDownload2.ipynb
@@ -0,0 +1,628 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "SbkKok0dfjjS"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import polars as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'1.8.2'"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pl.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"metadata/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n",
+ "\n",
+ "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"subsamples/ur0.01_ir0.01/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Разбиение сабсэмплов на базовую, гэп, вал и тест части\n",
+ "\n",
+ "Добавляется колонка original_order чтобы сохранять порядок внутри каждой из частей"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(22, 1, 1, 1, 25, 23)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "subsample_name = 'ur0.01_ir0.01'\n",
+ "content_embedding_size = 256\n",
+ "DATASET_PATH = \"/home/jovyan/IRec/sigir/lsvd_data_filtered/raw\"\n",
+ "\n",
+ "metadata_files = ['metadata/users_metadata.parquet',\n",
+ " 'metadata/items_metadata.parquet',\n",
+ " 'metadata/item_embeddings.npz']\n",
+ "\n",
+ "BASE_WEEKS = (0, 22)\n",
+ "GAP_WEEKS = (22, 23) #увеличить гэп\n",
+ "VAL_WEEKS = (23, 24)\n",
+ "\n",
+ "base_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n",
+ " for i in range(BASE_WEEKS[0], BASE_WEEKS[1])]\n",
+ "\n",
+ "gap_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n",
+ " for i in range(GAP_WEEKS[0], GAP_WEEKS[1])]\n",
+ "\n",
+ "val_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n",
+ " for i in range(VAL_WEEKS[0], VAL_WEEKS[1])]\n",
+ "\n",
+ "test_interactions_files = [f'subsamples/{subsample_name}/train/week_24.parquet']\n",
+ "\n",
+ "all_interactions_files = base_interactions_files + gap_interactions_files + val_interactions_files + test_interactions_files\n",
+ "\n",
+ "base_with_gap_interactions_files = base_interactions_files + gap_interactions_files\n",
+ "\n",
+ "len(base_interactions_files), len(gap_interactions_files), len(val_interactions_files), len(test_interactions_files), len(all_interactions_files), len(base_with_gap_interactions_files)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('subsamples/ur0.01_ir0.01/train/week_21.parquet',\n",
+ " ['subsamples/ur0.01_ir0.01/train/week_22.parquet'],\n",
+ " ['subsamples/ur0.01_ir0.01/train/week_23.parquet'],\n",
+ " ['subsamples/ur0.01_ir0.01/train/week_24.parquet'])"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "base_interactions_files[-1], gap_interactions_files, val_interactions_files, test_interactions_files"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_parquet_interactions(data_files, positive_event_timespent):\n",
+ " data_interactions = pl.concat([pl.scan_parquet(f'{DATASET_PATH}/{file}')\n",
+ " for file in data_files])\n",
+ " data_interactions = data_interactions.collect(streaming=True)\n",
+ " data_interactions = data_interactions.with_row_index(\"original_order\")\n",
+ " data_interactions = data_interactions.filter(pl.col('timespent') > positive_event_timespent)\n",
+ " return data_interactions\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "POSITIVE_EVENT_TIMESPENT = 15\n",
+ "base_interactions = get_parquet_interactions(base_interactions_files, POSITIVE_EVENT_TIMESPENT)\n",
+ "gap_interactions = get_parquet_interactions(gap_interactions_files, POSITIVE_EVENT_TIMESPENT)\n",
+ "val_interactions = get_parquet_interactions(val_interactions_files, POSITIVE_EVENT_TIMESPENT)\n",
+ "test_interactions = get_parquet_interactions(test_interactions_files, POSITIVE_EVENT_TIMESPENT)\n",
+ "all_data_interactions = get_parquet_interactions(all_interactions_files, POSITIVE_EVENT_TIMESPENT)\n",
+ "base_with_gap_interactions = get_parquet_interactions(base_with_gap_interactions_files, POSITIVE_EVENT_TIMESPENT)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Загрузка и фильтрация эмбеддингов"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_data_users = all_data_interactions.select('user_id').unique()\n",
+ "all_data_items = all_data_interactions.select('item_id').unique()\n",
+ "\n",
+ "item_ids = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['item_id']\n",
+ "item_embeddings = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['embedding']\n",
+ "\n",
+ "mask = np.isin(item_ids, all_data_items.to_numpy())\n",
+ "item_ids = item_ids[mask]\n",
+ "item_embeddings = item_embeddings[mask]\n",
+ "item_embeddings = item_embeddings[:, :content_embedding_size]\n",
+ "\n",
+ "users_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/users_metadata.parquet\")\n",
+ "items_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/items_metadata.parquet\")\n",
+ "\n",
+ "users_metadata = users_metadata.join(all_data_users, on='user_id')\n",
+ "items_metadata = items_metadata.join(all_data_items, on='item_id')\n",
+ "items_metadata = items_metadata.join(pl.DataFrame({'item_id': item_ids, \n",
+ " 'embedding': item_embeddings}), on='item_id')\n",
+ "\n",
+ "only_base_items_metadata = items_metadata.join(base_interactions.select('item_id').unique(), on='item_id')\n",
+ "only_base_with_gap_items_metadata = items_metadata.join(base_with_gap_interactions.select('item_id').unique(), on='item_id')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Сжатие айтем айди и ремапинг"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total users: 77403, Total items: 65659\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_data_items = all_data_interactions.select('item_id').unique()\n",
+ "all_data_users = all_data_interactions.select('user_id').unique()\n",
+ "\n",
+ "unique_items_sorted = all_data_items.sort('item_id').with_row_index('new_item_id')\n",
+ "global_item_mapping = dict(zip(unique_items_sorted['item_id'], unique_items_sorted['new_item_id']))\n",
+ "\n",
+ "print(f\"Total users: {all_data_users.shape[0]}, Total items: {len(global_item_mapping)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def remap_interactions(df, mapping):\n",
+ " return df.with_columns(\n",
+ " pl.col('item_id')\n",
+ " .map_elements(lambda x: mapping.get(x, None), return_dtype=pl.UInt32)\n",
+ " )\n",
+ "\n",
+ "base_interactions_remapped = remap_interactions(base_interactions, global_item_mapping)\n",
+ "gap_interactions_remapped = remap_interactions(gap_interactions, global_item_mapping)\n",
+ "test_interactions_remapped = remap_interactions(test_interactions, global_item_mapping)\n",
+ "val_interactions_remapped = remap_interactions(val_interactions, global_item_mapping)\n",
+ "all_data_interactions_remapped = remap_interactions(all_data_interactions, global_item_mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del base_interactions, gap_interactions, test_interactions, val_interactions, all_data_interactions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_with_gap_interactions_remapped = remap_interactions(base_with_gap_interactions, global_item_mapping)\n",
+ "del base_with_gap_interactions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "items_metadata_remapped = remap_interactions(items_metadata, global_item_mapping)\n",
+ "only_base_items_metadata_remapped = remap_interactions(only_base_items_metadata, global_item_mapping)\n",
+ "only_base_with_gap_items_metadata_remapped = remap_interactions(only_base_with_gap_items_metadata, global_item_mapping)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Группировка по юзер айди"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "interactions count: (1206762, 13)\n",
+ "users count: (75281, 3)\n",
+ "interactions count: (66879, 13)\n",
+ "users count: (28610, 3)\n",
+ "interactions count: (69887, 13)\n",
+ "users count: (28982, 3)\n",
+ "interactions count: (73637, 13)\n",
+ "users count: (30231, 3)\n",
+ "interactions count: (1417165, 13)\n",
+ "users count: (77403, 3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "def get_grouped_interactions(data_interactions):\n",
+ " print(f\"interactions count: {data_interactions.shape}\")\n",
+ " data_res = (\n",
+ " data_interactions\n",
+ " .select(['original_order', 'user_id', 'item_id'])\n",
+ " .group_by('user_id')\n",
+ " .agg(\n",
+ " pl.col('item_id')\n",
+ " .sort_by(pl.col('original_order'))\n",
+ " .alias('item_ids'),\n",
+ " pl.col('original_order').alias('timestamps')\n",
+ " )\n",
+ " .rename({'user_id': 'uid'})\n",
+ " )\n",
+ " print(f\"users count: {data_res.shape}\")\n",
+ " return data_res\n",
+ "base_interactions_grouped = get_grouped_interactions(base_interactions_remapped)\n",
+ "gap_interactions_grouped = get_grouped_interactions(gap_interactions_remapped)\n",
+ "test_interactions_grouped = get_grouped_interactions(test_interactions_remapped)\n",
+ "val_interactions_grouped = get_grouped_interactions(val_interactions_remapped)\n",
+ "all_data_interactions_grouped = get_grouped_interactions(all_data_interactions_remapped)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "
shape: (1, 3)| uid | item_ids | timestamps |
|---|
| u32 | list[u32] | list[u32] |
| 4651912 | [39251, 11631, … 18532] | [181275, 1237407, … 3130963] |
"
+ ],
+ "text/plain": [
+ "shape: (1, 3)\n",
+ "┌─────────┬─────────────────────────┬──────────────────────────────┐\n",
+ "│ uid ┆ item_ids ┆ timestamps │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[u32] ┆ list[u32] │\n",
+ "╞═════════╪═════════════════════════╪══════════════════════════════╡\n",
+ "│ 4651912 ┆ [39251, 11631, … 18532] ┆ [181275, 1237407, … 3130963] │\n",
+ "└─────────┴─────────────────────────┴──────────────────────────────┘"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "base_interactions_grouped.head(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del base_interactions_remapped, gap_interactions_remapped, test_interactions_remapped, val_interactions_remapped"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "interactions count: (1273641, 13)\n",
+ "users count: (76011, 3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "base_with_gap_interactions_grouped = get_grouped_interactions(base_with_gap_interactions_remapped)\n",
+ "del base_with_gap_interactions_remapped"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Сохранение"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Сохранён маппинг: /home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/global_item_mapping.json\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "OUTPUT_DIR = \"/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows\"\n",
+ "\n",
+ "mapping_output_path = f\"{OUTPUT_DIR}/global_item_mapping.json\"\n",
+ "\n",
+ "with open(mapping_output_path, 'w') as f:\n",
+ " json.dump({str(k): v for k, v in global_item_mapping.items()}, f, indent=2)\n",
+ "\n",
+ "print(f\"Сохранён маппинг: {mapping_output_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "размерность: (65659, 5)\n",
+ "Сохранен файл: items_metadata_remapped\n",
+ "размерность: (65659, 5)\n",
+ "Сохранен файл: items_metadata_old\n",
+ "размерность: (60552, 5)\n",
+ "Сохранен файл: only_base_items_metadata_remapped\n",
+ "размерность: (60552, 5)\n",
+ "Сохранен файл: only_base_items_metadata_old\n",
+ "размерность: (62362, 5)\n",
+ "Сохранен файл: only_base_with_gap_items_metadata_remapped\n",
+ "размерность: (62362, 5)\n",
+ "Сохранен файл: only_base_with_gap_items_metadata_old\n",
+ "размерность: (75281, 3)\n",
+ "Сохранен файл: base_interactions_grouped\n",
+ "размерность: (28610, 3)\n",
+ "Сохранен файл: gap_interactions_grouped\n",
+ "размерность: (28982, 3)\n",
+ "Сохранен файл: test_interactions_grouped\n",
+ "размерность: (30231, 3)\n",
+ "Сохранен файл: val_interactions_grouped\n",
+ "размерность: (76011, 3)\n",
+ "Сохранен файл: base_with_gap_interactions_grouped\n",
+ "размерность: (77403, 3)\n",
+ "Сохранен файл: all_data_interactions_grouped\n",
+ "размерность: (1417165, 13)\n",
+ "Сохранен файл: all_data_interactions_remapped\n"
+ ]
+ }
+ ],
+ "source": [
+ "def write_parquet(output_dir, data, file_name):\n",
+    "    print(f\"shape: {data.shape}\")\n",
+ " output_parquet_path = f\"{output_dir}/{file_name}.parquet\"\n",
+ " data.write_parquet(output_parquet_path)\n",
+    "    print(f\"Saved file: {file_name}\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, items_metadata_remapped, \"items_metadata_remapped\")\n",
+ "write_parquet(OUTPUT_DIR, items_metadata, \"items_metadata_old\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, only_base_items_metadata_remapped, \"only_base_items_metadata_remapped\")\n",
+ "write_parquet(OUTPUT_DIR, only_base_items_metadata, \"only_base_items_metadata_old\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, only_base_with_gap_items_metadata_remapped, \"only_base_with_gap_items_metadata_remapped\")\n",
+ "write_parquet(OUTPUT_DIR, only_base_with_gap_items_metadata, \"only_base_with_gap_items_metadata_old\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, base_interactions_grouped, \"base_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, gap_interactions_grouped, \"gap_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, test_interactions_grouped, \"test_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, val_interactions_grouped, \"val_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, base_with_gap_interactions_grouped, \"base_with_gap_interactions_grouped\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, all_data_interactions_grouped, \"all_data_interactions_grouped\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, all_data_interactions_remapped, \"all_data_interactions_remapped\")"
+ ]
+ },
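+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick reload check (not part of the original pipeline): a minimal sketch that reads the saved mapping and one of the parquet files back from OUTPUT_DIR and compares them against the in-memory objects (mapping_output_path, global_item_mapping and base_with_gap_interactions_grouped defined above). The int(k) cast assumes the original mapping keys are integer item ids."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import polars as pl\n",
+    "\n",
+    "# Sanity-check sketch: reload the artifacts written above and compare with the in-memory objects.\n",
+    "with open(mapping_output_path, 'r') as f:\n",
+    "    # Keys were stringified on save; cast back assuming integer item ids.\n",
+    "    reloaded_mapping = {int(k): v for k, v in json.load(f).items()}\n",
+    "assert len(reloaded_mapping) == len(global_item_mapping)\n",
+    "\n",
+    "# Reload one of the parquet files written by write_parquet and compare shapes.\n",
+    "reloaded_grouped = pl.read_parquet(f\"{OUTPUT_DIR}/base_with_gap_interactions_grouped.parquet\")\n",
+    "assert reloaded_grouped.shape == base_with_gap_interactions_grouped.shape"
+   ]
+  },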
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(64, 64)"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(list(items_metadata_remapped.head(1)['embedding'].item())), len(list(items_metadata.head(1)['embedding'].item()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "((65659, 5), (60552, 5), (62362, 5))"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_metadata_remapped.shape, only_base_items_metadata_remapped.shape, only_base_with_gap_items_metadata_remapped.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "shape: (1, 5)\n",
+ "┌─────────┬───────────┬──────────┬─────────────────────────┬─────────────────────────────────┐\n",
+ "│ item_id ┆ author_id ┆ duration ┆ train_interactions_rank ┆ embedding │\n",
+ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ u32 ┆ u8 ┆ u32 ┆ array[f32, 64] │\n",
+ "╞═════════╪═══════════╪══════════╪═════════════════════════╪═════════════════════════════════╡\n",
+ "│ 0 ┆ 651705 ┆ 71 ┆ 14592477 ┆ [-0.281006, -0.145508, … -0.05… │\n",
+ "└─────────┴───────────┴──────────┴─────────────────────────┴─────────────────────────────────┘"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_metadata_remapped.head(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────────────────────┬───────────────────────────────┐\n",
+ "│ uid ┆ item_ids ┆ timestamps │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[u32] ┆ list[u32] │\n",
+ "╞═════════╪═════════════════════════╪═══════════════════════════════╡\n",
+ "│ 1450406 ┆ [65092, 7570] ┆ [2309584, 3257670] │\n",
+ "│ 3736304 ┆ [39094, 42731, … 10818] ┆ [2691900, 2743200, … 3232861] │\n",
+ "│ 775866 ┆ [52462, 44617, … 64047] ┆ [101336, 161036, … 2789649] │\n",
+ "│ 777572 ┆ [48639, 6069] ┆ [1056275, 1218815] │\n",
+ "│ 1942489 ┆ [44575, 6264, 32234] ┆ [372244, 955295, 2432349] │\n",
+ "└─────────┴─────────────────────────┴───────────────────────────────┘"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "base_with_gap_interactions_grouped.head()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}