diff --git a/scripts/plum-lsvd/callbacks.py b/scripts/plum-lsvd/callbacks.py new file mode 100644 index 0000000..43ec460 --- /dev/null +++ b/scripts/plum-lsvd/callbacks.py @@ -0,0 +1,64 @@
+import torch
+
+import irec.callbacks as cb
+from irec.runners import TrainingRunner, TrainingRunnerContext
+
+
+class InitCodebooks(cb.TrainingCallback):
+    """Initializes each residual codebook from encoded samples of a batch before training starts."""
+
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    @torch.no_grad()
+    def before_run(self, runner: TrainingRunner):
+        for i in range(len(runner.model.codebooks)):
+            X = next(iter(self._dataloader))['embedding']
+            idx = torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])]
+            remainder = runner.model.encoder(X[idx])
+
+            # subtract the contribution of the already-initialized codebooks
+            for j in range(i):
+                codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j])
+                codebook_vectors = runner.model.codebooks[j][codebook_indices]
+                remainder = remainder - codebook_vectors
+
+            runner.model.codebooks[i].data = remainder.detach()
+
+
+class FixDeadCentroids(cb.TrainingCallback):
+    """Re-seeds codebook entries that received no assignments over a full pass of the dataloader."""
+
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
+            context.metrics[f'num_dead/{i}'] = num_fixed
+
+    @torch.no_grad()
+    def fix_dead_codebooks(self, runner: TrainingRunner):
+        num_fixed = []
+        for codebook_idx, codebook in enumerate(runner.model.codebooks):
+            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
+            random_batch = next(iter(self._dataloader))['embedding']
+
+            for batch in self._dataloader:
+                remainder = runner.model.encoder(batch['embedding'])
+                for l in range(codebook_idx):
+                    ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
+                    remainder = remainder - runner.model.codebooks[l][ind]
+
+                indices = runner.model.get_codebook_indices(remainder, codebook)
+                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))
+
+            dead_mask = (centroid_counts == 0)
+            num_dead = int(dead_mask.sum().item())
+            num_fixed.append(num_dead)
+            if num_dead == 0:
+                continue
+
+            # replace dead centroids with encoded residuals from a random batch
+            remainder = runner.model.encoder(random_batch)
+            for l in range(codebook_idx):
+                ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l])
+                remainder = remainder - runner.model.codebooks[l][ind]
+            remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead]
+            codebook[dead_mask] = remainder.detach()
+
+        return num_fixed
diff --git a/scripts/plum-lsvd/cooc_data.py b/scripts/plum-lsvd/cooc_data.py new file mode 100644 index 0000000..7cea906 --- /dev/null +++ b/scripts/plum-lsvd/cooc_data.py @@ -0,0 +1,117 @@
+import json
+from collections import defaultdict, Counter
+from data import InteractionsDatasetParquet
+
+
+class CoocMappingDataset:
+    def __init__(
+        self,
+        train_sampler,
+        num_items,
+        cooccur_counter_mapping=None
+    ):
+        self._train_sampler = train_sampler
+        self._num_items = num_items
+        self._cooccur_counter_mapping = cooccur_counter_mapping
+
+    @classmethod
+    def create(cls, inter_json_path, window_size):
+        max_item_id = 0
+        train_dataset = []
+
+        with open(inter_json_path, 'r') as f:
+            user_interactions = json.load(f)
+
+        for user_id_str, item_ids in user_interactions.items():
+            user_id = int(user_id_str)
+            if item_ids:
+                max_item_id = max(max_item_id, max(item_ids))
+            if len(item_ids) >= 5:
+                print(f'Core-5 dataset is used, user
{user_id} has only {len(item_ids)} items') + train_dataset.append({ + 'user_ids': [user_id], + 'item_ids': item_ids[:-2], + }) + + + cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size) + print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}') + + + train_sampler = train_dataset + + + return cls( + train_sampler=train_sampler, + num_items=max_item_id + 1, + cooccur_counter_mapping=cooccur_counter_mapping + ) + + + @classmethod + def create_from_split_part( + cls, + train_inter_parquet_path, + window_size, + ): + + max_item_id = 0 + train_dataset = [] + + + train_interactions = InteractionsDatasetParquet(train_inter_parquet_path) + + actions_num = 0 + for session in train_interactions: + user_id, item_ids = int(session['user_id']), session['item_ids'] + if item_ids.any(): + max_item_id = max(max_item_id, max(item_ids)) + actions_num += len(item_ids) + train_dataset.append({ + 'user_ids': [user_id], + 'item_ids': item_ids, + }) + + + print(f'Train: {len(train_dataset)} users') + print(f'Max item ID: {max_item_id}') + print(f"Actions num: {actions_num}") + + + cooccur_counter_mapping = cls.build_cooccur_counter_mapping( + train_dataset, + window_size=window_size + ) + + + print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items') + + + return cls( + train_sampler=train_dataset, + num_items=max_item_id + 1, + cooccur_counter_mapping=cooccur_counter_mapping + ) + + + + @staticmethod + def build_cooccur_counter_mapping(train_dataset, window_size): + cooccur_counts = defaultdict(Counter) + for session in train_dataset: + items = session['item_ids'] + for i in range(len(items)): + item_i = items[i] + for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)): + if i != j: + cooccur_counts[item_i][items[j]] += 1 + max_hist_len = max(len(counter) for counter in cooccur_counts.values()) if cooccur_counts else 0 + print(f"Max cooccurrence history length is {max_hist_len}") + return cooccur_counts + + + + @property + def cooccur_counter_mapping(self): + return self._cooccur_counter_mapping \ No newline at end of file diff --git a/scripts/plum-lsvd/data.py b/scripts/plum-lsvd/data.py new file mode 100644 index 0000000..5a780fb --- /dev/null +++ b/scripts/plum-lsvd/data.py @@ -0,0 +1,87 @@ +import numpy as np +import pickle + +from irec.data.base import BaseDataset +from irec.data.transforms import Transform + + +import polars as pl + +class InteractionsDatasetParquet(BaseDataset): + def __init__(self, data_path, max_items=None): + self.df = pl.read_parquet(data_path) + assert 'uid' in self.df.columns, "Missing 'uid' column" + assert 'item_ids' in self.df.columns, "Missing 'item_ids' column" + print(f"Dataset loaded: {len(self.df)} users") + + if max_items is not None: + self.df = self.df.with_columns( + pl.col("item_ids").list.slice(-max_items).alias("item_ids") + ) + + def __getitem__(self, idx): + row = self.df.row(idx, named=True) + return { + 'user_id': row['uid'], + 'item_ids': np.array(row['item_ids'], dtype=np.uint32), + } + + def __len__(self): + return len(self.df) + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + +class EmbeddingDatasetParquet(BaseDataset): + def __init__(self, data_path): + self.df = pl.read_parquet(data_path) + self.item_ids = np.array(self.df['item_id'], dtype=np.int64) + self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32) + print(f"embedding dim: 
{self.embeddings[0].shape}") + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + + +class EmbeddingDataset(BaseDataset): + def __init__(self, data_path): + self.data_path = data_path + with open(data_path, 'rb') as f: + self.data = pickle.load(f) + + self.item_ids = np.array(self.data['item_id'], dtype=np.int64) + self.embeddings = np.array(self.data['embedding'], dtype=np.float32) + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + + +class ProcessEmbeddings(Transform): + def __init__(self, embedding_dim, keys): + self.embedding_dim = embedding_dim + self.keys = keys + + def __call__(self, batch): + for key in self.keys: + batch[key] = batch[key].reshape(-1, self.embedding_dim) + return batch \ No newline at end of file diff --git a/scripts/plum-lsvd/letter_base_gap/create_base_gap_mapping_from_all.py b/scripts/plum-lsvd/letter_base_gap/create_base_gap_mapping_from_all.py new file mode 100644 index 0000000..734bfdf --- /dev/null +++ b/scripts/plum-lsvd/letter_base_gap/create_base_gap_mapping_from_all.py @@ -0,0 +1,50 @@ +import json +import pandas as pd +from pathlib import Path + + +ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless.json" +TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/only_base_with_gap_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless_from_all.json" +with open(ALL_MAPPING_PATH, 'r') as f: + all_mapping = json.load(f) +print(f"Loaded {len(all_mapping)} items from all_mapping") + +train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH) + +train_item_ids = set() +for item_ids_array in train_interactions['item_ids']: + train_item_ids.update(item_ids_array) + +print(f"Found {len(train_item_ids)} unique train items") + +train_mapping = {} +missing_count = 0 + +for item_id in train_item_ids: + item_id_str = str(item_id) + if item_id_str in all_mapping: + train_mapping[item_id_str] = all_mapping[item_id_str] + else: + missing_count += 1 + +if missing_count > 0: + print(f"{missing_count} items from train not found in all_mapping") + +print(f"Created train_mapping with {len(train_mapping)} items") + +Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True) +with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f: + json.dump(train_mapping, f, indent=2) +print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}") + +print(f"all_mapping size: {len(all_mapping)}") +print(f"train_mapping size: {len(train_mapping)}") +print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}") + +sample_matches = 0 +for item_id_str in list(train_mapping.keys())[:100]: + if all_mapping[item_id_str] == train_mapping[item_id_str]: + sample_matches += 1 + +print(f"Verified: {sample_matches}/100 sampled items have identical codes") diff --git a/scripts/plum-lsvd/letter_base_gap/infer_letter.py b/scripts/plum-lsvd/letter_base_gap/infer_letter.py new file mode 100644 index 0000000..1e3d45a --- /dev/null +++ 
b/scripts/plum-lsvd/letter_base_gap/infer_letter.py @@ -0,0 +1,155 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed +from letter import LetterRQVAE + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from collections import Counter + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 64 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 + +RQVAE_LOSS_WEIGHT=1.0 +CF_LOSS_WEIGHT=0.01 + +EXPERIMENT_NAME = f'all_items_letter_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}_cf_{CF_LOSS_WEIGHT}' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/items_metadata_remapped.parquet" +IREC_PATH = '../../../' + +MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd/letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_best_0.0067.pth" +RESULTS_PATH = os.path.join(IREC_PATH, 'results-lsvd-2/base_gap/letter') + +# SASREC_MODEL_PATH = "" + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = LetterRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=RQVAE_LOSS_WEIGHT, + cf_loss_weight=CF_LOSS_WEIGHT, + cf_embeddings=None + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'cf_loss': model_outputs['cf_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/cf_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + collision_stats = [] + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + 
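+        # group item ids by their 3-token semantic prefix; groups with more than one item are collisions resolved below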
sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + collision_stats.append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + if collision_stats: + max_col_tok = max(collision_stats) + avg_col_tok = np.mean(collision_stats) + collision_distribution = Counter(collision_stats) + + print(f"Max collision token: {max_col_tok}") + print(f"Avg collision token: {avg_col_tok:.2f}") + print(f"Total items with collisions: {len(collision_stats)}") + print(f"Collision solver distribution: {dict(collision_distribution)}") + else: + print("No collisions detected") + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/letter_base_gap/letter.py b/scripts/plum-lsvd/letter_base_gap/letter.py new file mode 100644 index 0000000..786772f --- /dev/null +++ b/scripts/plum-lsvd/letter_base_gap/letter.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class LetterRQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + beta=0.25, + quant_loss_weight=1.0, + cf_loss_weight=1.0, + cf_embeddings=None + ): + super().__init__() + self.register_buffer('beta', torch.tensor(beta)) + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.quant_loss_weight = quant_loss_weight + + self.cf_loss_weight = cf_loss_weight + + self.cf_embeddings = None + if cf_embeddings is not None: + self.cf_embeddings = torch.tensor(cf_embeddings, dtype=torch.float32) + print(self.cf_embeddings.requires_grad) + else: + print("CF EMBEDS IS NONE!!!! 
ONLY FOR INFER") + + self.encoder = self.make_encoding_tower(input_dim, embedding_dim) + self.decoder = self.make_encoding_tower(embedding_dim, input_dim) + + self.codebooks = torch.nn.ParameterList() + for _ in range(num_codebooks): + cb = torch.FloatTensor(codebook_size, embedding_dim) + #nn.init.normal_(cb) + self.codebooks.append(cb) + + @staticmethod + def make_encoding_tower(d1, d2, bias=False): + return torch.nn.Sequential( + nn.Linear(d1, d1), + nn.ReLU(), + nn.Linear(d1, d2), + nn.ReLU(), + nn.Linear(d2, d2, bias=bias) + ) + + @staticmethod + def get_codebook_indices(remainder, codebook): + dist = torch.cdist(remainder, codebook) + return dist.argmin(dim=-1) + + def forward(self, inputs): + latent_vector = self.encoder(inputs['embedding']) + item_ids = inputs['item_id'] + + latent_restored = 0 + rqvae_loss = 0 + clusters = [] + remainder = latent_vector + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + clusters.append(codebook_indices) + + quantized = codebook[codebook_indices] + codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding']) + + if self.cf_embeddings is not None: + cf_embedding_in_batch = self.cf_embeddings[item_ids] + cf_loss = self.CF_loss(latent_restored, cf_embedding_in_batch) + else: + cf_loss = torch.as_tensor(0.0) + + loss = (recon_loss + self.quant_loss_weight * rqvae_loss + self.cf_loss_weight * cf_loss).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + 'cf_loss': cf_loss.item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } + + def CF_loss(self, quantized_rep, encoded_rep): + batch_size = quantized_rep.size(0) + labels = torch.arange(batch_size, dtype=torch.long, device=quantized_rep.device) + similarities = quantized_rep @ encoded_rep.T + cf_loss = F.cross_entropy(similarities, labels) + return cf_loss diff --git a/scripts/plum-lsvd/letter_base_gap/train_letter.py b/scripts/plum-lsvd/letter_base_gap/train_letter.py new file mode 100644 index 0000000..08b884b --- /dev/null +++ b/scripts/plum-lsvd/letter_base_gap/train_letter.py @@ -0,0 +1,165 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from letter import LetterRQVAE + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 64 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 + +RQVAE_LOSS_WEIGHT=1.0 
+CF_LOSS_WEIGHT=0.01 + +EXPERIMENT_NAME = f'letter_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}_cf_{CF_LOSS_WEIGHT}' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/only_base_with_gap_items_metadata_remapped.parquet" +IREC_PATH = '../../../' + +SASREC_MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd-transformer/sasrec_vk_lsvd_base_gap_msl_20_e_300_best_0.0223.pth" + +print(EMBEDDINGS_PATH) +print(SASREC_MODEL_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH, + ) + + train_dataloader = DataLoader( #call в основном потоке делается нужно исправить + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + cf_embeddings = torch.load( + SASREC_MODEL_PATH, + map_location=DEVICE + )['_orig_mod._item_embeddings.weight'] + + print(f"cf embeds shape: {cf_embeddings.shape}") + + model = LetterRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=RQVAE_LOSS_WEIGHT, + cf_loss_weight=CF_LOSS_WEIGHT, + cf_embeddings=cf_embeddings + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'cf_loss': model_outputs['cf_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/cf_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'cf_loss': model_outputs['cf_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/cf_loss': cb.MeanAccumulator(), + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints-lsvd', EXPERIMENT_NAME) + 
).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/models.py b/scripts/plum-lsvd/models.py new file mode 100644 index 0000000..d475712 --- /dev/null +++ b/scripts/plum-lsvd/models.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class PlumRQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + beta=0.25, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=0.0, + ): + super().__init__() + self.register_buffer('beta', torch.tensor(beta)) + self.temperature = temperature + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.quant_loss_weight = quant_loss_weight + + self.contrastive_loss_weight = contrastive_loss_weight + + self.encoder = self.make_encoding_tower(input_dim, embedding_dim) + self.decoder = self.make_encoding_tower(embedding_dim, input_dim) + + self.codebooks = torch.nn.ParameterList() + for _ in range(num_codebooks): + cb = torch.FloatTensor(codebook_size, embedding_dim) + #nn.init.normal_(cb) + self.codebooks.append(cb) + + @staticmethod + def make_encoding_tower(d1, d2, bias=False): + return torch.nn.Sequential( + nn.Linear(d1, d1), + nn.ReLU(), + nn.Linear(d1, d2), + nn.ReLU(), + nn.Linear(d2, d2, bias=bias) + ) + + @staticmethod + def get_codebook_indices(remainder, codebook): + dist = torch.cdist(remainder, codebook) + return dist.argmin(dim=-1) + + def _quantize_representation(self, latent_vector): + latent_restored = 0 + remainder = latent_vector + + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + quantized = codebook[codebook_indices] + codebook_vectors = remainder + (quantized - remainder).detach() + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + return latent_restored + + def contrastive_loss(self, p_i, p_i_star): + N_b = p_i.size(0) + + p_i = F.normalize(p_i, p=2, dim=-1) #TODO посмотреть без нормалайза + p_i_star = F.normalize(p_i_star, p=2, dim=-1) + + similarities = torch.matmul(p_i, p_i_star.T) / self.temperature + + labels = torch.arange(N_b, dtype=torch.long, device=p_i.device) + + loss = F.cross_entropy(similarities, labels) + + return loss + + def forward(self, inputs): + latent_vector = self.encoder(inputs['embedding']) + item_ids = inputs['item_id'] + + latent_restored = 0 + rqvae_loss = 0 + clusters = [] + remainder = latent_vector + + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + clusters.append(codebook_indices) + + quantized = codebook[codebook_indices] + codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding']) + + if 'cooccurrence_embedding' in inputs: + cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device)) + 
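+            # quantize the co-occurring item's latent with the same codebooks (straight-through estimator)
+            # and pull it toward the anchor's quantized latent via the in-batch contrastive (InfoNCE-style) loss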
cooccurrence_restored = self._quantize_representation(cooccurrence_latent) + con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored) + else: + con_loss = torch.as_tensor(0.0, device=latent_vector.device) + + loss = ( + recon_loss + + self.quant_loss_weight * rqvae_loss + + self.contrastive_loss_weight * con_loss + ).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + 'con_loss': con_loss.item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } \ No newline at end of file diff --git a/scripts/plum-lsvd/plum_base_gap/create_base_gap_mapping_from_all.py b/scripts/plum-lsvd/plum_base_gap/create_base_gap_mapping_from_all.py new file mode 100644 index 0000000..112c579 --- /dev/null +++ b/scripts/plum-lsvd/plum_base_gap/create_base_gap_mapping_from_all.py @@ -0,0 +1,52 @@ +import json +import pandas as pd +from pathlib import Path + + +ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless.json" +TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/only_base_with_gap_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless_from_all.json" + + +with open(ALL_MAPPING_PATH, 'r') as f: + all_mapping = json.load(f) +print(f"Loaded {len(all_mapping)} items from all_mapping") + +train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH) + +train_item_ids = set() +for item_ids_array in train_interactions['item_ids']: + train_item_ids.update(item_ids_array) + +print(f"Found {len(train_item_ids)} unique train items") + +train_mapping = {} +missing_count = 0 + +for item_id in train_item_ids: + item_id_str = str(item_id) + if item_id_str in all_mapping: + train_mapping[item_id_str] = all_mapping[item_id_str] + else: + missing_count += 1 + +if missing_count > 0: + print(f"{missing_count} items from train not found in all_mapping") + +print(f"Created train_mapping with {len(train_mapping)} items") + +Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True) +with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f: + json.dump(train_mapping, f, indent=2) +print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}") + +print(f"all_mapping size: {len(all_mapping)}") +print(f"train_mapping size: {len(train_mapping)}") +print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}") + +sample_matches = 0 +for item_id_str in list(train_mapping.keys())[:100]: + if all_mapping[item_id_str] == train_mapping[item_id_str]: + sample_matches += 1 + +print(f"Verified: {sample_matches}/100 sampled items have identical codes") diff --git a/scripts/plum-lsvd/plum_base_gap/infer_plum_on_all_items.py b/scripts/plum-lsvd/plum_base_gap/infer_plum_on_all_items.py new file mode 100644 index 0000000..d1473cc --- /dev/null +++ b/scripts/plum-lsvd/plum_base_gap/infer_plum_on_all_items.py @@ -0,0 +1,148 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import 
InferenceRunner + +from irec.utils import fix_random_seed +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 +K=3000 + +RQVAE_LOSS_WEIGHT=1.0 +CON_LOSS_WEIGHT=0.01 + +EXPERIMENT_NAME = f'all_items_plum_vk-lsvd-15ts_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_e{NUM_EPOCHS}_con_{CON_LOSS_WEIGHT}_rqvae_{RQVAE_LOSS_WEIGHT}' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/items_metadata_remapped.parquet" +IREC_PATH = '../../../' + +MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd/plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_best_0.008.pth" +RESULTS_PATH = os.path.join(IREC_PATH, 'results-lsvd-2/base_gap') + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=RQVAE_LOSS_WEIGHT, + contrastive_loss_weight=CON_LOSS_WEIGHT, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + 
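+    # Collision resolution: every item gets an extra disambiguation token, drawn without replacement within its
+    # semantic-prefix group, and token i is then offset by CODEBOOK_SIZE * i so positions never share ids.
+    # For example, with CODEBOOK_SIZE = 512, clusters [17, 3, 200] plus collision token 5 become
+    # [17, 3 + 512, 200 + 1024, 5 + 1536] = [17, 515, 1224, 1541].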
+ for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/plum_base_gap/train_plum_base_gap.py b/scripts/plum-lsvd/plum_base_gap/train_plum_base_gap.py new file mode 100644 index 0000000..3960337 --- /dev/null +++ b/scripts/plum-lsvd/plum_base_gap/train_plum_base_gap.py @@ -0,0 +1,186 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddingsVectorized +from cooc_data import CoocMappingDataset + +# ЭКСПЕРИМЕНТ С ОБРЕЗАННОЙ ИСТОРИЕЙ +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 +K=3000 + +RQVAE_LOSS_WEIGHT=1.0 +CON_LOSS_WEIGHT=0.01 + +EXPERIMENT_NAME = f'plum_vk-lsvd-15ts_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_e{NUM_EPOCHS}_con_{CON_LOSS_WEIGHT}_rqvae_{RQVAE_LOSS_WEIGHT}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/only_base_with_gap_items_metadata_remapped.parquet" +IREC_PATH = '../../' + +print(INTER_TRAIN_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH, + ) + + data = CoocMappingDataset.create_from_split_part( + train_inter_parquet_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE) + all_item_ids.append(item_id) + + # add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized(data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids, K, DEVICE,seed=42) + add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized( + cooccur_counts=data.cooccur_counter_mapping, + item_id_to_embedding=item_id_to_embedding, + all_item_ids=all_item_ids, + device=DEVICE, + max_neighbors=K, + seed=42 + ) + + train_dataloader = DataLoader( #call в основном потоке делается нужно исправить + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + 
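+        # the validation loader reuses the same item-embedding dataset (no held-out split) and also feeds InitCodebooks / FixDeadCentroids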
shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=RQVAE_LOSS_WEIGHT, + contrastive_loss_weight=CON_LOSS_WEIGHT, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints-lsvd', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/rqvae_base_gap/create_base_gap_mapping_from_all.py b/scripts/plum-lsvd/rqvae_base_gap/create_base_gap_mapping_from_all.py new file mode 100644 index 0000000..a86f0c2 --- /dev/null +++ b/scripts/plum-lsvd/rqvae_base_gap/create_base_gap_mapping_from_all.py @@ -0,0 +1,52 @@ +import json +import pandas as pd +from pathlib import Path + + +ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless.json" +TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/only_base_with_gap_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless_from_all.json" + 
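+# Restrict the full-catalogue semantic-ID mapping to the items that actually occur in the train interactions.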
+ +with open(ALL_MAPPING_PATH, 'r') as f: + all_mapping = json.load(f) +print(f"Loaded {len(all_mapping)} items from all_mapping") + +train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH) + +train_item_ids = set() +for item_ids_array in train_interactions['item_ids']: + train_item_ids.update(item_ids_array) + +print(f"Found {len(train_item_ids)} unique train items") + +train_mapping = {} +missing_count = 0 + +for item_id in train_item_ids: + item_id_str = str(item_id) + if item_id_str in all_mapping: + train_mapping[item_id_str] = all_mapping[item_id_str] + else: + missing_count += 1 + +if missing_count > 0: + print(f"{missing_count} items from train not found in all_mapping") + +print(f"Created train_mapping with {len(train_mapping)} items") + +Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True) +with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f: + json.dump(train_mapping, f, indent=2) +print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}") + +print(f"all_mapping size: {len(all_mapping)}") +print(f"train_mapping size: {len(train_mapping)}") +print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}") + +sample_matches = 0 +for item_id_str in list(train_mapping.keys())[:100]: + if all_mapping[item_id_str] == train_mapping[item_id_str]: + sample_matches += 1 + +print(f"Verified: {sample_matches}/100 sampled items have identical codes") diff --git a/scripts/plum-lsvd/rqvae_base_gap/infer_rqvae.py b/scripts/plum-lsvd/rqvae_base_gap/infer_rqvae.py new file mode 100644 index 0000000..3bd030d --- /dev/null +++ b/scripts/plum-lsvd/rqvae_base_gap/infer_rqvae.py @@ -0,0 +1,158 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from collections import Counter +from models import PlumRQVAE + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 + +RQVAE_LOSS_WEIGHT=1.0 + +EXPERIMENT_NAME = f'all_items_rqvae_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/items_metadata_remapped.parquet" +IREC_PATH = '../../../' + +MODEL_PATH = "/home/jovyan/IRec/checkpoints-lsvd/rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_best_0.0073.pth" +RESULTS_PATH = os.path.join(IREC_PATH, 'results-lsvd-2/base_gap') + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + 
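+        # used here as a plain RQ-VAE: inference batches carry no 'cooccurrence_embedding', so the contrastive term stays zero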
beta=BETA, + quant_loss_weight=RQVAE_LOSS_WEIGHT + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(MODEL_PATH), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + import json + from collections import defaultdict + import numpy as np + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: + mappings = json.load(f) + + inter = {} + sem_2_ids = defaultdict(list) + collision_stats = [] + for mapping in mappings: + item_id = mapping['item_id'] + clusters = mapping['clusters'] + inter[int(item_id)] = clusters + sem_2_ids[tuple(clusters)].append(int(item_id)) + + for semantics, items in sem_2_ids.items(): + assert len(items) <= CODEBOOK_SIZE, str(len(items)) + collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist() + for item_id, collision_solver in zip(items, collision_solvers): + inter[item_id].append(collision_solver) + collision_stats.append(collision_solver) + for i in range(len(inter[item_id])): + inter[item_id][i] += CODEBOOK_SIZE * i + + if collision_stats: + max_col_tok = max(collision_stats) + avg_col_tok = np.mean(collision_stats) + collision_distribution = Counter(collision_stats) + + print(f"Max collision token: {max_col_tok}") + print(f"Avg collision token: {avg_col_tok:.2f}") + print(f"Total items with collisions: {len(collision_stats)}") + print(f"Collision solver distribution: {dict(collision_distribution)}") + else: + print("No collisions detected") + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + json.dump(inter, f, indent=2) + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/rqvae_base_gap/train_rqvae.py b/scripts/plum-lsvd/rqvae_base_gap/train_rqvae.py new file mode 100644 index 0000000..50c0bce --- /dev/null +++ b/scripts/plum-lsvd/rqvae_base_gap/train_rqvae.py @@ -0,0 +1,161 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDatasetParquet, ProcessEmbeddings +from models import PlumRQVAE + +SEED_VALUE = 42 +DEVICE = 
torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 35 +BATCH_SIZE = 1024 + +INPUT_DIM = 64 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 + +RQVAE_LOSS_WEIGHT=1.0 + +EXPERIMENT_NAME = f'rqvae_vk-lsvd-15ts_base_with_gap_e{NUM_EPOCHS}_rqvae_{RQVAE_LOSS_WEIGHT}' +EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/only_base_with_gap_items_metadata_remapped.parquet" +IREC_PATH = '../../../' + +print(EMBEDDINGS_PATH) +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDatasetParquet( + data_path=EMBEDDINGS_PATH, + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE) + all_item_ids.append(item_id) + + train_dataloader = DataLoader( #call в основном потоке делается нужно исправить + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=RQVAE_LOSS_WEIGHT + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + 
model_path=os.path.join(IREC_PATH, 'checkpoints-lsvd', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/plum-lsvd/transform_test.py b/scripts/plum-lsvd/transform_test.py new file mode 100644 index 0000000..b977be8 --- /dev/null +++ b/scripts/plum-lsvd/transform_test.py @@ -0,0 +1,145 @@ +import torch +import numpy as np +import pytest +from transforms import AddWeightedCooccurrenceEmbeddingsVectorized + +def test_add_weighted_cooccurrence_embeddings(): + torch.manual_seed(42) + np.random.seed(42) + device = torch.device('cpu') # Тестируем на CPU для простоты + + print("\n" + "="*80) + print("TEST 1: Normal case with cooccurrences") + print("="*80) + + # Граф: + # 0 -> {1: 10, 2: 5} (чаще 1) + # 1 -> {0: 10} (только 0) + # 2 -> {} (нет соседей) + cooccur_counts = { + 0: {1: 10, 2: 5}, + 1: {0: 10}, + 2: {}, + } + + # Эмбеддинги (one-hot для наглядности) + # 0: [1, 0, 0] + # 1: [0, 1, 0] + # 2: [0, 0, 1] + item_embeddings = { + 0: torch.tensor([1.0, 0.0, 0.0]), + 1: torch.tensor([0.0, 1.0, 0.0]), + 2: torch.tensor([0.0, 0.0, 1.0]), + } + + all_item_ids = [0, 1, 2] + + # Инициализация + transform = AddWeightedCooccurrenceEmbeddingsVectorized( + cooccur_counts=cooccur_counts, + item_id_to_embedding=item_embeddings, + all_item_ids=all_item_ids, + device=device, + max_neighbors=2, # Ограничим до 2 для проверки обрезки + seed=42 + ) + + # ---------------------------------------------------------------- + # Проверка 1: Айтем 1 должен всегда выбирать соседа 0 (вероятность 1.0) + # ---------------------------------------------------------------- + batch_1 = {'item_id': torch.tensor([1, 1, 1], device=device)} + res_1 = transform(batch_1) + # Ожидаем эмбеддинг айтема 0: [1.0, 0.0, 0.0] + expected_emb_0 = item_embeddings[0].to(device) + + print("Checking deterministic neighbor (1 -> 0)...") + for emb in res_1['cooccurrence_embedding']: + assert torch.allclose(emb, expected_emb_0), \ + f"Item 1 should strictly link to 0. 
Got {emb}" + print("✅ Deterministic neighbor check passed") + + # ---------------------------------------------------------------- + # Проверка 2: Айтем 2 (без соседей) должен выдавать валидный случайный эмбеддинг + # ---------------------------------------------------------------- + batch_2 = {'item_id': torch.tensor([2] * 100, device=device)} + res_2 = transform(batch_2) + embs_2 = res_2['cooccurrence_embedding'] + + print("Checking fallback for item without neighbors...") + # Проверяем, что нет NaN + assert not torch.isnan(embs_2).any(), "NaN found in fallback embeddings" + + # Проверяем, что возвращаются реальные эмбеддинги из словаря + valid_embs_set = {tuple(e.tolist()) for e in item_embeddings.values()} + for emb in embs_2[:10]: # Проверим первые 10 + assert tuple(emb.tolist()) in valid_embs_set, f"Invalid embedding generated: {emb}" + print("✅ Fallback check passed") + + # ---------------------------------------------------------------- + # Проверка 3: Распределение вероятностей (Item 0 -> 1(66%) vs 2(33%)) + # ---------------------------------------------------------------- + # Запустим большой батч для статистики + batch_0 = {'item_id': torch.tensor([0] * 1000, device=device)} + res_0 = transform(batch_0) + embs_0 = res_0['cooccurrence_embedding'] + + # Считаем, сколько раз выпал эмбеддинг 1 (сосед с весом 10) и эмбеддинг 2 (сосед с весом 5) + # Emb 1 = [0, 1, 0], Emb 2 = [0, 0, 1] + count_1 = (embs_0[:, 1] == 1.0).sum().item() + count_2 = (embs_0[:, 2] == 1.0).sum().item() + + ratio = count_1 / (count_1 + count_2) + expected_ratio = 10 / 15 # ~0.666 + + print(f"Checking distribution for Item 0. Expected ~{expected_ratio:.2f}, Got {ratio:.2f}") + assert abs(ratio - expected_ratio) < 0.05, \ + f"Distribution mismatch! Expected {expected_ratio:.2f}, got {ratio:.2f}" + print("✅ Distribution check passed") + + print("\n" + "="*80) + print("TEST 2: Edge Cases") + print("="*80) + + # ---------------------------------------------------------------- + # Проверка 4: Пустой батч + # ---------------------------------------------------------------- + batch_empty = {'item_id': torch.tensor([], dtype=torch.long, device=device)} + res_empty = transform(batch_empty) + assert res_empty['cooccurrence_embedding'].shape[0] == 0 + print("✅ Empty batch passed") + + # ---------------------------------------------------------------- + # Проверка 5: Item ID вне списка all_item_ids (например, padding index или новый айтем) + # В текущей реализации searchsorted, индексы клемпятся. + # Проверим, что код не падает. 
+ # ---------------------------------------------------------------- + unknown_id = 999 + batch_unknown = {'item_id': torch.tensor([unknown_id], device=device)} + + try: + res_unknown = transform(batch_unknown) + print("✅ Unknown item ID handled (no crash)") + # В идеале тут надо проверить, что вернулось (скорее всего, neighbor для последнего айтема) + except Exception as e: + pytest.fail(f"Crashed on unknown item ID: {e}") + + # ---------------------------------------------------------------- + # Проверка 6: Воспроизводимость (Seed) + # ---------------------------------------------------------------- + batch_seed = {'item_id': torch.tensor([0, 2, 0, 1] * 10, device=device)} + transform_a = AddWeightedCooccurrenceEmbeddingsVectorized( + cooccur_counts, item_embeddings, all_item_ids, device, seed=42 + ) + res_a = transform_a(batch_seed)['cooccurrence_embedding'] + transform_b = AddWeightedCooccurrenceEmbeddingsVectorized( + cooccur_counts, item_embeddings, all_item_ids, device, seed=42 + ) + res_b = transform_b(batch_seed)['cooccurrence_embedding'] + + assert torch.allclose(res_a, res_b), "Results differ with same seed!" + print("✅ Seeding reproducibility passed") + + print("\n✅ ALL TESTS PASSED!") + +if __name__ == "__main__": + test_add_weighted_cooccurrence_embeddings() diff --git a/scripts/plum-lsvd/transforms.py b/scripts/plum-lsvd/transforms.py new file mode 100644 index 0000000..bc8f812 --- /dev/null +++ b/scripts/plum-lsvd/transforms.py @@ -0,0 +1,111 @@ +import numpy as np +import torch +from typing import Dict, List +import time + +import torch +from typing import Dict, List + +class AddWeightedCooccurrenceEmbeddingsVectorized: + def __init__( + self, + cooccur_counts: Dict[int, Dict[int, int]], + item_id_to_embedding: Dict[int, torch.Tensor], + all_item_ids: List[int], + device: torch.device, + max_neighbors: int = 128, + seed: int = 42, + ): + self.device = device + self.max_neighbors = max_neighbors + torch.manual_seed(seed) + + self.all_item_ids = torch.tensor(sorted(all_item_ids), dtype=torch.long, device=device) + self.num_items = len(self.all_item_ids) + + # 2. 
Эмбеддинги + emb_dim = list(item_id_to_embedding.values())[0].shape[0] + self.embedding_matrix = torch.zeros((self.num_items, emb_dim), dtype=torch.float32, device=device) + + id_to_idx = {iid.item(): i for i, iid in enumerate(self.all_item_ids.cpu())} + + for iid, emb in item_id_to_embedding.items(): + if iid in id_to_idx: + self.embedding_matrix[id_to_idx[iid]] = emb.to(device) + + + self.neighbors = torch.randint(0, self.num_items, (self.num_items, max_neighbors), device=device) + self.probs = torch.full((self.num_items, max_neighbors), 1.0 / max_neighbors, device=device) + + neighbors_cpu = self.neighbors.cpu() + probs_cpu = self.probs.cpu() + + for item_id, neighbors_dict in cooccur_counts.items(): + if item_id not in id_to_idx: continue + idx = id_to_idx[item_id] + + if not neighbors_dict: continue + + top_k = sorted(neighbors_dict.items(), key=lambda x: x[1], reverse=True)[:max_neighbors] + ids, counts = zip(*top_k) + + valid_pairs = [] + for n_id, c in zip(ids, counts): + n_idx = id_to_idx.get(n_id, -1) + if n_idx != -1: + valid_pairs.append((n_idx, c)) + + if not valid_pairs: continue + + final_indices, final_counts = zip(*valid_pairs) + k_len = len(final_indices) + + count_tensor = torch.tensor(final_counts, dtype=torch.float32) + prob_tensor = count_tensor / count_tensor.sum() + + neighbors_cpu[idx, :k_len] = torch.tensor(final_indices, dtype=torch.long) + probs_cpu[idx, :k_len] = prob_tensor + + if k_len < max_neighbors: + probs_cpu[idx, k_len:] = 0.0 + + self.neighbors = neighbors_cpu.to(device) + self.probs = probs_cpu.to(device) + + def __call__(self, batch): + item_ids = batch['item_id'].to(self.device) + + indices = torch.searchsorted(self.all_item_ids, item_ids) + indices = indices.clamp(max=self.num_items - 1) + + found_mask = (self.all_item_ids[indices] == item_ids) + + if not found_mask.all(): + missing_count = (~found_mask).sum().item() + missing_examples = item_ids[~found_mask][:5].tolist() + print(f"[WARNING] Batch contains {missing_count} unknown items! 
Examples: {missing_examples}") + print(f" Assigning RANDOM embeddings for unknown items.") + + # кого не нашли, подменим индекс на 0 временно + safe_indices = indices.clone() + safe_indices[~found_mask] = 0 + + batch_probs = self.probs[safe_indices] # (B, max_neighbors) + neighbor_local_indices = torch.multinomial(batch_probs, num_samples=1).squeeze(1) # (B) + selected_neighbor_indices = self.neighbors[safe_indices, neighbor_local_indices] # (B) + cooc_embeddings = self.embedding_matrix[selected_neighbor_indices] + + if not found_mask.all(): + # случайные индексы айтемов без истории + random_indices = torch.randint(0, self.num_items, (item_ids.shape[0],), device=self.device) + random_embeddings = self.embedding_matrix[random_indices] + + # шум + cooc_embeddings = torch.where( + found_mask.unsqueeze(1), + cooc_embeddings, + random_embeddings + ) + + batch['cooccurrence_embedding'] = cooc_embeddings + return batch diff --git a/scripts/sasrec-lsvd/base_gap/train_bg-v-t.py b/scripts/sasrec-lsvd/base_gap/train_bg-v-t.py new file mode 100644 index 0000000..2861b2c --- /dev/null +++ b/scripts/sasrec-lsvd/base_gap/train_bg-v-t.py @@ -0,0 +1,202 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.models import AutoCast +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import ArrowBatchDataset +from models import SasRecModel + +SEED_VALUE = 42 +DEVICE = 'cuda' + +IREC_PATH = '../../../' +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer') + +EXPERIMENT_NAME = 'sasrec_vk_lsvd_base_gap_msl_20_e_300' + +NUM_EPOCHS = 10 +MAX_SEQ_LEN = 20 +# TRAIN_BATCH_SIZE = 256 +# VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 64 +NUM_HEADS = 2 +NUM_LAYERS = 2 +FEEDFORWARD_DIM = 256 +DROPOUT = 0.3 +LR = 1e-4 + +NUM_ITEMS = 66000 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=None + ), + batch_size=1, + shuffle=True, + num_workers=16, + prefetch_factor=16, + pin_memory=True, + persistent_workers=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = SasRecModel( + num_items=NUM_ITEMS, + max_sequence_length=MAX_SEQ_LEN, + embedding_dim=EMBEDDING_DIM, + num_heads=NUM_HEADS, + num_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + activation='relu', + topk_k=20, + dropout=DROPOUT, + layer_norm_eps=1e-8, + initializer_range=0.02 + ) + model = torch.compile(model, mode="default", fullgraph=False) + model = model.to('cuda') + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + 
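+    # Rough sanity check for the numbers logged below (derived from the constants above, not
+    # measured): with NUM_ITEMS = 66000 and EMBEDDING_DIM = 64 the item embedding table alone
+    # contributes 66_000 * 64 = 4_224_000 parameters, so it is expected to dominate total_params.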
logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40 * 4, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/sasrec-lsvd/base_gap/varka_bg-v-t.py b/scripts/sasrec-lsvd/base_gap/varka_bg-v-t.py new file mode 100644 index 0000000..8f131ad --- /dev/null +++ b/scripts/sasrec-lsvd/base_gap/varka_bg-v-t.py @@ -0,0 +1,104 @@ +from collections import defaultdict +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +from irec.data.transforms import Collate +from irec.data.dataloader import DataLoader + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import Dataset + +# ПУТИ +IREC_PATH = 
'../../../' + +INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet" +INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet" + +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/sasrec_base_gap/eval_batches/') + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + length_groups[length][key] = value + + for length, fields in length_groups.items(): + arrow_dict = {k: pa.array(v) for k, v in fields.items()} + table = pa.table(arrow_dict) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + +def main(): + data = Dataset.create_timestamp_based_parquet( + train_parquet_path=INTERACTIONS_TRAIN_PATH, + validation_parquet_path=INTERACTIONS_VALID_PATH, + test_parquet_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='sasrec', + min_sample_len=2, + is_extended=False, + max_train_events=MAX_SEQ_LEN + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ).map(Collate()).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ).map(Collate()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ).map(Collate()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + +if __name__ == '__main__': + main() diff --git a/scripts/sasrec-lsvd/data.py b/scripts/sasrec-lsvd/data.py new file mode 100644 index 0000000..4ae0e18 --- /dev/null +++ b/scripts/sasrec-lsvd/data.py @@ -0,0 +1,359 @@ +from collections import defaultdict +import json +from loguru import logger +import numpy as np +from pathlib import Path + +import copy +import pyarrow as pa +import pyarrow.feather as feather + +import torch +import polars as pl +from irec.data.base import BaseDataset + + +class InteractionsDatasetParquet(BaseDataset): + def __init__(self, data_path, max_items=None): + self.df = pl.read_parquet(data_path) + assert 'uid' in self.df.columns, "Missing 'uid' column" + assert 'item_ids' in self.df.columns, "Missing 'item_ids' column" + print(f"Dataset loaded: {len(self.df)} users") + + if max_items is not None: + self.df = self.df.with_columns( + pl.col("item_ids").list.slice(-max_items).alias("item_ids") 
+ ) + + def __getitem__(self, idx): + row = self.df.row(idx, named=True) + return { + 'user_id': row['uid'], + 'item_ids': np.array(row['item_ids'], dtype=np.uint32), + } + + def __len__(self): + return len(self.df) + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + +class Dataset: + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items, + max_sequence_length + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create_timestamp_based_parquet( + cls, + train_parquet_path, + validation_parquet_path, + test_parquet_path, + max_sequence_length, + sampler_type, + min_sample_len=2, + is_extended=False, + max_train_events=20 + ): + """ + Загружает данные из parquet файлов с timestamp-based сплитом. + + Ожидает структуру parquet: + - uid: int (user id) + - item_ids: list[int] (список item ids) + + Аналогично create_timestamp_based, но для parquet формата. + """ + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + print(f"started to load datasets from parquet with max train length {max_train_events}") + + # Загружаем parquet файлы + train_df = pl.read_parquet(train_parquet_path) + validation_df = pl.read_parquet(validation_parquet_path) + test_df = pl.read_parquet(test_parquet_path) + + # Проверяем наличие необходимых колонок + for df, name in [(train_df, "train"), (validation_df, "validation"), (test_df, "test")]: + assert 'uid' in df.columns, f"Missing 'uid' column in {name}" + assert 'item_ids' in df.columns, f"Missing 'item_ids' column in {name}" + + # Создаем словари для быстрого доступа + train_data = {str(row['uid']): row['item_ids'] for row in train_df.iter_rows(named=True)} + validation_data = {str(row['uid']): row['item_ids'] for row in validation_df.iter_rows(named=True)} + test_data = {str(row['uid']): row['item_ids'] for row in test_df.iter_rows(named=True)} + + all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys()) + print(f"all users count: {len(all_users)}") + + us_count = 0 + for user_id_str in all_users: + if us_count % 100 == 0: + print(f"user id {us_count}/{len(all_users)}: {user_id_str}") + + user_id = int(user_id_str) + + # Получаем последовательности для каждого сплита + train_items = list(train_data.get(user_id_str, [])) + validation_items = list(validation_data.get(user_id_str, [])) + test_items = list(test_data.get(user_id_str, [])) + + # Обрезаем train на последние max_train_events событий + train_items = train_items[-max_train_events:] + + full_sequence = train_items + validation_items + test_items + if full_sequence: + max_item_id = max(max_item_id, max(full_sequence)) + + if us_count % 100 == 0: + print(f"full sequence len: {len(full_sequence)}") + + us_count += 1 + if len(train_items) < 2: + print(f'Core-2 train dataset is used, user {user_id} has only {len(full_sequence)} items') + continue + + if len(validation_items) < 1 or len(test_items) < 1: + print(f'Test or validation is empty, user {user_id} has only {len(full_sequence)} items') + continue + + + if is_extended: + for prefix_length in range(min_sample_len, len(train_items) + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(train_items[:prefix_length]), + }) + else: + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(train_items), + }) + + # валидация + + 
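+        # current_history below starts as a reference to train_items rather than a copy; this is
+        # safe because train_items has already been deep-copied into train_dataset above, and every
+        # validation/test sample built from the growing history is deep-copied again when appended.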
+ # разворачиваем каждый айтем из валидации в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + + current_history = train_items + valid_small_history = 0 + for item in validation_items: + current_history.append(item) + + if len(current_history) >= min_sample_len: + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]), + }) + else: + valid_small_history += 1 + + + # разворачиваем каждый айтем из теста в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + + test_small_history = 0 + for item in test_items: + current_history.append(item) + + if len(current_history) >= min_sample_len: + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]), + }) + else: + test_small_history += 1 + + + print(f"Train dataset size: {len(train_dataset)}") + print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}") + print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}") + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + + def get_datasets(self): + return self._train_sampler, self._validation_sampler, self._test_sampler + + @property + def num_items(self): + return self._num_items + + @property + def max_sequence_length(self): + return self._max_sequence_length + + +class TrainDataset(BaseDataset): + def __init__(self, dataset, prediction_type, max_sequence_length): + self._dataset = dataset + self._prediction_type = prediction_type + self._max_sequence_length = max_sequence_length + + self._transforms = { + 'sasrec': self._all_items_transform, + 'tiger': self._last_item_transform + } + + def _all_items_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array(next_item_sequence, dtype=np.int64), + 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64) + } + + def _last_item_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + last_item = sample['item.ids'][-self._max_sequence_length:][-1] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([last_item], dtype=np.int64), + 
'labels.length': np.array([1], dtype=np.int64), + } + + def __getitem__(self, index): + return self._transforms[self._prediction_type](self._dataset[index]) + + def __len__(self): + return len(self._dataset) + + +class EvalDataset(BaseDataset): + def __init__(self, dataset, max_sequence_length): + self._dataset = dataset + self._max_sequence_length = max_sequence_length + + @property + def dataset(self): + return self._dataset + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, index): + sample = self._dataset[index] + + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item = sample['item.ids'][-self._max_sequence_length:][-1] + + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([next_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + 'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64), + 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64), + } + + +class ArrowBatchDataset(BaseDataset): + def __init__(self, batch_dir, device='cuda', preload=False): + self.batch_dir = Path(batch_dir) + self.device = device + + all_files = list(self.batch_dir.glob('batch_*_len_*.arrow')) + + batch_files_map = defaultdict(list) + for f in all_files: + batch_id = int(f.stem.split('_')[1]) + batch_files_map[batch_id].append(f) + + for batch_id in batch_files_map: + batch_files_map[batch_id].sort() + + self.batch_indices = sorted(batch_files_map.keys()) + + if preload: + print(f"Preloading {len(self.batch_indices)} batches...") + self.cached_batches = [] + + for idx in range(len(self.batch_indices)): + batch = self._load_batch(batch_files_map[self.batch_indices[idx]]) + self.cached_batches.append(batch) + else: + self.cached_batches = None + self.batch_files_map = batch_files_map + + def _load_batch(self, arrow_files): + batch = {} + + for arrow_file in arrow_files: + table = feather.read_table(arrow_file) + metadata = table.schema.metadata or {} + + for col_name in table.column_names: + col = table.column(col_name) + + shape_key = f'{col_name}_shape' + dtype_key = f'{col_name}_dtype' + + if shape_key.encode() in metadata: + shape = eval(metadata[shape_key.encode()].decode()) + dtype = np.dtype(metadata[dtype_key.encode()].decode()) + + # Проверяем тип колонки + if pa.types.is_list(col.type) or pa.types.is_large_list(col.type): + arr = np.array(col.to_pylist(), dtype=dtype) + else: + arr = col.to_numpy().reshape(shape).astype(dtype) + else: + if pa.types.is_list(col.type) or pa.types.is_large_list(col.type): + arr = np.array(col.to_pylist()) + else: + arr = col.to_numpy() + + batch[col_name] = torch.from_numpy(arr.copy()).to(self.device) + + return batch + + def __len__(self): + return len(self.batch_indices) + + def __getitem__(self, idx): + if self.cached_batches is not None: + return self.cached_batches[idx] + else: + batch_id = self.batch_indices[idx] + arrow_files = self.batch_files_map[batch_id] + return self._load_batch(arrow_files) diff --git a/scripts/sasrec-lsvd/models.py b/scripts/sasrec-lsvd/models.py new file mode 100644 index 0000000..0e1ea00 --- /dev/null +++ b/scripts/sasrec-lsvd/models.py @@ -0,0 +1,244 @@ +import torch +import torch.nn as nn + +from irec.models import TorchModel, create_masked_tensor + +import torch._dynamo +import torch +import 
torch.nn as nn + +from irec.models import create_masked_tensor + + + +class TransformerEncoder(nn.Module): + def __init__( + self, + embedding_dim, + layers, + dim_feedforward, + num_heads, + dropout, + activation, + causal, + prenorm=False + ): + super().__init__() + self.causal = causal + layer = nn.TransformerEncoderLayer( + d_model=embedding_dim, + nhead=num_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=1e-5, + batch_first=True, + norm_first=prenorm + ) + self.encoder = nn.TransformerEncoder(layer, num_layers=layers) + self._init_weights(initializer_range=0.02) + + @torch.no_grad() + def _init_weights(self, initializer_range: float) -> None: + for key, value in self.named_parameters(): + if 'weight' in key: + if 'norm' in key: + nn.init.ones_(value.data) + else: + nn.init.trunc_normal_( + value.data, + std=initializer_range, + a=-2 * initializer_range, + b=2 * initializer_range + ) + elif 'bias' in key: + nn.init.zeros_(value.data) + else: + raise ValueError(f'Unknown transformer weight: {key}') + + def forward(self, embeddings, lengths, max_seqlen): + padded_embeddings, mask = create_masked_tensor( + data=embeddings, + lengths=lengths + ) + + if self.causal: + causal_mask = nn.Transformer.generate_square_subsequent_mask( + sz=mask.shape[-1], + device=mask.device, + dtype=mask.dtype + ) + else: + causal_mask = None + + padded_embeddings = self.encoder( + padded_embeddings, + mask=causal_mask, + src_key_padding_mask=~mask, + is_causal=self.causal + ) + return padded_embeddings[mask] + + +class SasRecModel(TorchModel): + def __init__( + self, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + activation, + topk_k, + dropout=0.0, + layer_norm_eps=1e-9, + initializer_range=0.02 + ): + super().__init__() + self._num_items = num_items + self._num_heads = num_heads + self._embedding_dim = embedding_dim + + self._item_embeddings = nn.Embedding( + num_embeddings=num_items, + embedding_dim=embedding_dim + ) + self._position_embeddings = nn.Embedding( + num_embeddings=max_sequence_length, + embedding_dim=embedding_dim + ) + + self._topk_k = topk_k + + self._encoder = TransformerEncoder( + embedding_dim=embedding_dim, + dim_feedforward=dim_feedforward, + layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation, + causal=True, + ) + + self._init_weights(initializer_range) + + def forward(self, inputs): + with torch._dynamo.config.patch(suppress_errors=True): + all_sample_events = inputs['item.ids'] # (total_batch_items) + all_sample_lengths = inputs['item.length'] # (batch_size) + max_seqlen = int(all_sample_lengths.max().item()) + + embeddings = self._item_embeddings(all_sample_events) + + end_indices = all_sample_lengths.cumsum(dim=0) # (batch_size) + start_indices = end_indices - all_sample_lengths # (batch_size) + + sample_indices = torch.arange( + all_sample_lengths.shape[0], + device=all_sample_lengths.device + ).repeat_interleave(all_sample_lengths) # (total_batch_items) + + positions = torch.arange( + all_sample_events.shape[0], + device=all_sample_events.device + ) - start_indices[sample_indices] # (total_batch_items) + + position_embeddings = self._position_embeddings(positions) # (total_batch_items, embedding_dim) + + embeddings = embeddings + position_embeddings # (total_batch_items, embedding_dim) + + all_sample_embeddings = self._encoder(embeddings=embeddings, lengths=all_sample_lengths, max_seqlen=max_seqlen) # (total_batch_items, embedding_dim) + + 
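+            # Note on the layout used above and for the labels below: the batch is "ragged", i.e. all
+            # sequences are concatenated along dim 0 and item.length recovers the per-sample boundaries.
+            # For example, lengths [3, 2] give start_indices [0, 3], sample_indices [0, 0, 0, 1, 1]
+            # and positions [0, 1, 2, 0, 1], so every event receives its within-sequence position.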
all_positive_sample_events = inputs['labels.ids'] # (total_batch_items) + + if not self.training: + offsets = torch.cumsum(all_sample_lengths, dim=-1) + all_sample_embeddings = all_sample_embeddings[offsets - 1] + + all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim) + + # a -- total_batch_items, n -- num_items, d -- embedding_dim + all_scores = torch.einsum( + 'ad,nd->an', + all_sample_embeddings, + all_embeddings + ) # (total_batch_items, num_items) + + positive_scores = torch.gather( + input=all_scores, + dim=1, + index=all_positive_sample_events[..., None] + )[:, 0] # (total_batch_items) + + # Compute loss + negative_scores = torch.gather( + input=all_scores, + dim=1, + index=torch.randint( + low=0, + high=all_scores.shape[1], + size=all_positive_sample_events.shape, + device=all_positive_sample_events.device + )[..., None] + )[:, 0] # (total_batch_items) + + with torch.autocast(device_type='cuda', enabled=False): + loss = self._compute_loss( + positive_scores.float(), + negative_scores.float() + ) + + metrics = { + 'loss': loss.detach() + } + + if not self.training: + batch_size = all_sample_lengths.shape[0] + num_items = all_embeddings.shape[0] + + padded_items, _ = create_masked_tensor( + data=all_sample_events, + lengths=all_sample_lengths, + ) # (batch_size, max_seq_len) + + visited_mask = torch.zeros( + batch_size, num_items, + dtype=torch.bool, + device=all_sample_events.device + ) + + batch_indices = torch.arange(batch_size, device=all_sample_events.device)[:, None] + batch_indices = batch_indices.expand(-1, padded_items.shape[1]) + + visited_mask.scatter_( + dim=1, + index=padded_items.long(), + value=True + ) + + all_scores = all_scores.masked_fill(visited_mask, float('-inf')) + + positive_position = (all_scores > positive_scores[:, None]).float().sum(dim=-1) # (batch_size или total_batch_items) + dcg_score = 1. / (torch.log2(positive_position + 1) + 1.) 
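+                # positive_position is the 0-based rank of the target item: the number of items that
+                # scored strictly higher (with items already present in the input masked to -inf at
+                # eval time). recall@k below is 1 exactly when the target lands in the top-k, and
+                # ndcg@k keeps the rank discount computed above (1.0 at rank 0, 1 / (log2(rank + 1) + 1)
+                # in general) only for those hits.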
+ + for k in [5, 10, 20]: + metrics[f'recall@{k}'] = (positive_position < k).float() + metrics[f'ndcg@{k}'] = torch.where( + positive_position < k, + dcg_score, + torch.zeros_like(dcg_score) + ).float() + + return loss, metrics + + def _compute_loss(self, positive_scores, negative_scores): + assert positive_scores.shape[0] == negative_scores.shape[0] + + loss = torch.nn.functional.binary_cross_entropy_with_logits( + positive_scores, torch.ones_like(positive_scores) + ) + torch.nn.functional.binary_cross_entropy_with_logits( + negative_scores, torch.zeros_like(negative_scores) + ) + + return loss diff --git a/scripts/tiger-lsvd/data.py b/scripts/tiger-lsvd/data.py new file mode 100644 index 0000000..4ae0e18 --- /dev/null +++ b/scripts/tiger-lsvd/data.py @@ -0,0 +1,359 @@ +from collections import defaultdict +import json +from loguru import logger +import numpy as np +from pathlib import Path + +import copy +import pyarrow as pa +import pyarrow.feather as feather + +import torch +import polars as pl +from irec.data.base import BaseDataset + + +class InteractionsDatasetParquet(BaseDataset): + def __init__(self, data_path, max_items=None): + self.df = pl.read_parquet(data_path) + assert 'uid' in self.df.columns, "Missing 'uid' column" + assert 'item_ids' in self.df.columns, "Missing 'item_ids' column" + print(f"Dataset loaded: {len(self.df)} users") + + if max_items is not None: + self.df = self.df.with_columns( + pl.col("item_ids").list.slice(-max_items).alias("item_ids") + ) + + def __getitem__(self, idx): + row = self.df.row(idx, named=True) + return { + 'user_id': row['uid'], + 'item_ids': np.array(row['item_ids'], dtype=np.uint32), + } + + def __len__(self): + return len(self.df) + + def __iter__(self): + for idx in range(len(self)): + yield self[idx] + + +class Dataset: + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items, + max_sequence_length + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create_timestamp_based_parquet( + cls, + train_parquet_path, + validation_parquet_path, + test_parquet_path, + max_sequence_length, + sampler_type, + min_sample_len=2, + is_extended=False, + max_train_events=20 + ): + """ + Загружает данные из parquet файлов с timestamp-based сплитом. + + Ожидает структуру parquet: + - uid: int (user id) + - item_ids: list[int] (список item ids) + + Аналогично create_timestamp_based, но для parquet формата. 
+ """ + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + print(f"started to load datasets from parquet with max train length {max_train_events}") + + # Загружаем parquet файлы + train_df = pl.read_parquet(train_parquet_path) + validation_df = pl.read_parquet(validation_parquet_path) + test_df = pl.read_parquet(test_parquet_path) + + # Проверяем наличие необходимых колонок + for df, name in [(train_df, "train"), (validation_df, "validation"), (test_df, "test")]: + assert 'uid' in df.columns, f"Missing 'uid' column in {name}" + assert 'item_ids' in df.columns, f"Missing 'item_ids' column in {name}" + + # Создаем словари для быстрого доступа + train_data = {str(row['uid']): row['item_ids'] for row in train_df.iter_rows(named=True)} + validation_data = {str(row['uid']): row['item_ids'] for row in validation_df.iter_rows(named=True)} + test_data = {str(row['uid']): row['item_ids'] for row in test_df.iter_rows(named=True)} + + all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys()) + print(f"all users count: {len(all_users)}") + + us_count = 0 + for user_id_str in all_users: + if us_count % 100 == 0: + print(f"user id {us_count}/{len(all_users)}: {user_id_str}") + + user_id = int(user_id_str) + + # Получаем последовательности для каждого сплита + train_items = list(train_data.get(user_id_str, [])) + validation_items = list(validation_data.get(user_id_str, [])) + test_items = list(test_data.get(user_id_str, [])) + + # Обрезаем train на последние max_train_events событий + train_items = train_items[-max_train_events:] + + full_sequence = train_items + validation_items + test_items + if full_sequence: + max_item_id = max(max_item_id, max(full_sequence)) + + if us_count % 100 == 0: + print(f"full sequence len: {len(full_sequence)}") + + us_count += 1 + if len(train_items) < 2: + print(f'Core-2 train dataset is used, user {user_id} has only {len(full_sequence)} items') + continue + + if len(validation_items) < 1 or len(test_items) < 1: + print(f'Test or validation is empty, user {user_id} has only {len(full_sequence)} items') + continue + + + if is_extended: + for prefix_length in range(min_sample_len, len(train_items) + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(train_items[:prefix_length]), + }) + else: + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(train_items), + }) + + # валидация + + + # разворачиваем каждый айтем из валидации в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + + current_history = train_items + valid_small_history = 0 + for item in validation_items: + current_history.append(item) + + if len(current_history) >= min_sample_len: + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]), + }) + else: + valid_small_history += 1 + + + # разворачиваем каждый айтем из теста в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + + test_small_history = 0 + for item in test_items: + current_history.append(item) + + if len(current_history) >= min_sample_len: + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': copy.deepcopy(current_history[-max_sequence_length:]), + }) + else: + test_small_history += 1 + + + print(f"Train dataset size: {len(train_dataset)}") + print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}") 
+ print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}") + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + + def get_datasets(self): + return self._train_sampler, self._validation_sampler, self._test_sampler + + @property + def num_items(self): + return self._num_items + + @property + def max_sequence_length(self): + return self._max_sequence_length + + +class TrainDataset(BaseDataset): + def __init__(self, dataset, prediction_type, max_sequence_length): + self._dataset = dataset + self._prediction_type = prediction_type + self._max_sequence_length = max_sequence_length + + self._transforms = { + 'sasrec': self._all_items_transform, + 'tiger': self._last_item_transform + } + + def _all_items_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array(next_item_sequence, dtype=np.int64), + 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64) + } + + def _last_item_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + last_item = sample['item.ids'][-self._max_sequence_length:][-1] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([last_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + } + + def __getitem__(self, index): + return self._transforms[self._prediction_type](self._dataset[index]) + + def __len__(self): + return len(self._dataset) + + +class EvalDataset(BaseDataset): + def __init__(self, dataset, max_sequence_length): + self._dataset = dataset + self._max_sequence_length = max_sequence_length + + @property + def dataset(self): + return self._dataset + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, index): + sample = self._dataset[index] + + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item = sample['item.ids'][-self._max_sequence_length:][-1] + + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([next_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + 'visited.ids': 
np.array(sample['item.ids'][:-1], dtype=np.int64), + 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64), + } + + +class ArrowBatchDataset(BaseDataset): + def __init__(self, batch_dir, device='cuda', preload=False): + self.batch_dir = Path(batch_dir) + self.device = device + + all_files = list(self.batch_dir.glob('batch_*_len_*.arrow')) + + batch_files_map = defaultdict(list) + for f in all_files: + batch_id = int(f.stem.split('_')[1]) + batch_files_map[batch_id].append(f) + + for batch_id in batch_files_map: + batch_files_map[batch_id].sort() + + self.batch_indices = sorted(batch_files_map.keys()) + + if preload: + print(f"Preloading {len(self.batch_indices)} batches...") + self.cached_batches = [] + + for idx in range(len(self.batch_indices)): + batch = self._load_batch(batch_files_map[self.batch_indices[idx]]) + self.cached_batches.append(batch) + else: + self.cached_batches = None + self.batch_files_map = batch_files_map + + def _load_batch(self, arrow_files): + batch = {} + + for arrow_file in arrow_files: + table = feather.read_table(arrow_file) + metadata = table.schema.metadata or {} + + for col_name in table.column_names: + col = table.column(col_name) + + shape_key = f'{col_name}_shape' + dtype_key = f'{col_name}_dtype' + + if shape_key.encode() in metadata: + shape = eval(metadata[shape_key.encode()].decode()) + dtype = np.dtype(metadata[dtype_key.encode()].decode()) + + # Проверяем тип колонки + if pa.types.is_list(col.type) or pa.types.is_large_list(col.type): + arr = np.array(col.to_pylist(), dtype=dtype) + else: + arr = col.to_numpy().reshape(shape).astype(dtype) + else: + if pa.types.is_list(col.type) or pa.types.is_large_list(col.type): + arr = np.array(col.to_pylist()) + else: + arr = col.to_numpy() + + batch[col_name] = torch.from_numpy(arr.copy()).to(self.device) + + return batch + + def __len__(self): + return len(self.batch_indices) + + def __getitem__(self, idx): + if self.cached_batches is not None: + return self.cached_batches[idx] + else: + batch_id = self.batch_indices[idx] + arrow_files = self.batch_files_map[batch_id] + return self._load_batch(arrow_files) diff --git a/scripts/tiger-lsvd/letter_base_gap/lsvd_train_letter_base_gap.py b/scripts/tiger-lsvd/letter_base_gap/lsvd_train_letter_base_gap.py new file mode 100644 index 0000000..475d1fe --- /dev/null +++ b/scripts/tiger-lsvd/letter_base_gap/lsvd_train_letter_base_gap.py @@ -0,0 +1,226 @@ +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../../' +TRAIN_PART_SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/only_base_with_gap_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless_from_all.json" +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 
'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer') + +EXPERIMENT_NAME = 'tiger_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 100 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_POSITIONS = 80 +NUM_USER_HASH = 8000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + +USE_MICROBATCHING = True +MICROBATCH_SIZE = 256 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(TRAIN_PART_SEMANTIC_MAPPING_PATH, 'r') as f: + train_part_mapping = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=CorrectItemsLogitsProcessor(NUM_CODEBOOKS, CODEBOOK_SIZE, train_part_mapping, NUM_BEAMS), + use_microbatching=USE_MICROBATCHING, + microbatch_size=MICROBATCH_SIZE + ).to(DEVICE) + + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 
'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/letter_base_gap/lsvd_varka_letter_base_gap_on_sem_all_items.py b/scripts/tiger-lsvd/letter_base_gap/lsvd_varka_letter_base_gap_on_sem_all_items.py new file mode 100644 index 0000000..1a5c68f --- /dev/null +++ b/scripts/tiger-lsvd/letter_base_gap/lsvd_varka_letter_base_gap_on_sem_all_items.py @@ -0,0 +1,305 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import Dataset + +print("tiger arrow varka") + +# ПУТИ + +IREC_PATH = '../../../' +INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet" +INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet" + +SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless.json" +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/letter_base_gap/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01/eval_batches/') 
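+
+# Assumed layout of the mapping at SEMANTIC_MAPPING_PATH, inferred from how SemanticIdsMapper
+# consumes it below: a JSON object keyed by item id (as a string) whose values are the
+# NUM_CODEBOOKS semantic codes of that item, e.g. {"0": [17, 402, 3, 511], "1": [17, 9, 240, 88]}
+# (the code values here are purely illustrative).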
+ +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 8000 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? + + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + max_item_id = max(int(k) for k in mapping.keys()) + print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys())) + # print(mapping["280052"]) #304781 + # assert False + data = [] + for i in range(max_item_id + 1): + if str(i) in mapping: + data.append(mapping[str(i)]) + else: + data.append([-1] * NUM_CODEBOOKS) + + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + missing_count = (max_item_id + 1) - len(mapping) + print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)") + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + semantic_ids = self._mapping_tensor[ids].flatten() + + assert (semantic_ids != -1).all(), \ + f"Missing mappings detected in {name}! 
Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}" + + batch[f'{name}.semantic.ids'] = semantic_ids.numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + print("варка может начать умирать") + data = Dataset.create_timestamp_based_parquet( + train_parquet_path=INTERACTIONS_TRAIN_PATH, + validation_parquet_path=INTERACTIONS_VALID_PATH, + test_parquet_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True, + max_train_events=MAX_SEQ_LEN + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + print("варка не умерла") + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + 
.map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/models.py b/scripts/tiger-lsvd/models.py new file mode 100644 index 0000000..03c9b74 --- /dev/null +++ b/scripts/tiger-lsvd/models.py @@ -0,0 +1,227 @@ +import torch +from transformers import T5ForConditionalGeneration, T5Config, LogitsProcessor + +from irec.models import TorchModel + + +class CorrectItemsLogitsProcessor(LogitsProcessor): + def __init__(self, num_codebooks, codebook_size, mapping, num_beams): + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.num_beams = num_beams + + semantic_ids = [] + for codes in mapping.values(): + assert len(codes) == num_codebooks, 'All semantic ids must have the same length' + semantic_ids.append(codes) + + print(f"semantic ids count (allowed for generation): {len(semantic_ids)}") + + self.index_semantic_ids = torch.tensor(semantic_ids, dtype=torch.long, device='cuda') # (num_items, semantic_ids) + + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if (input_ids == -1).any(): + raise AssertionError(f"LogitsProcessor received -1 in input_ids. 
\nIndices: {torch.nonzero(input_ids == -1)}") + + batch_size = input_ids.shape[0]//self.num_beams + + next_sid_codebook_num = (torch.minimum((input_ids[:, -1].max() // self.codebook_size), torch.as_tensor(self.num_codebooks - 1)).item() + 1) % self.num_codebooks + a = torch.tile(self.index_semantic_ids[None, None, :, next_sid_codebook_num], dims=[batch_size, self.num_beams, 1]) # (batch_size, num_beams, num_items) + a = a.reshape(a.shape[0] * a.shape[1], a.shape[2]) # (batch_size * num_beams, num_items) + + if next_sid_codebook_num != 0: + b = torch.tile(self.index_semantic_ids[None, None :, :next_sid_codebook_num], dims=[batch_size, self.num_beams, 1, 1]) # (batch_size, num_beams, num_items, sid_len) + b = b.reshape(b.shape[0] * b.shape[1], b.shape[2], b.shape[3]) # (batch_size * num_beams, num_items, sid_len) + + current_prefixes = input_ids[:, -next_sid_codebook_num:] # (batch_size * num_beams, sid_len) + possible_next_items_mask = ( + torch.eq(current_prefixes[:, None, :], b).long().sum(dim=-1) == next_sid_codebook_num + ) # (batch_size * num_beams, num_items) + a[~possible_next_items_mask] = (next_sid_codebook_num + 1) * self.codebook_size + + scores_mask = torch.zeros_like(scores).bool() # (batch_size * num_beams, num_items) + scores_mask = torch.scatter_add( + input=scores_mask, + dim=-1, + index=a, + src=torch.ones_like(a).bool() + ) + + scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf + scores[:, (next_sid_codebook_num + 1) * self.codebook_size:] = -torch.inf + scores[~(scores_mask.bool())] = -torch.inf + + return scores + + + +class TigerModel(TorchModel): + def __init__( + self, + embedding_dim, + codebook_size, + sem_id_len, + num_positions, + user_ids_count, + num_heads, + num_encoder_layers, + num_decoder_layers, + dim_feedforward, + num_beams=100, + num_return_sequences=20, + d_kv=64, + layer_norm_eps=1e-6, + activation='relu', + dropout=0.1, + initializer_range=0.02, + logits_processor=None, + use_microbatching=False, + microbatch_size=128 + ): + super().__init__() + self._embedding_dim = embedding_dim + self._codebook_size = codebook_size + self._num_positions = num_positions + self._num_heads = num_heads + self._num_encoder_layers = num_encoder_layers + self._num_decoder_layers = num_decoder_layers + self._dim_feedforward = dim_feedforward + self._num_beams = num_beams + self._num_return_sequences = num_return_sequences + self._d_kv = d_kv + self._layer_norm_eps = layer_norm_eps + self._activation = activation + self._dropout = dropout + self._sem_id_len = sem_id_len + self.user_ids_count = user_ids_count + self.logits_processor = logits_processor + self._use_microbatching = use_microbatching + self._microbatch_size = microbatch_size + + unified_vocab_size = codebook_size * self._sem_id_len + self.user_ids_count + 10 # 10 for utilities + self.config = T5Config( + vocab_size=unified_vocab_size, + d_model=self._embedding_dim, + d_kv=self._d_kv, + d_ff=self._dim_feedforward, + num_layers=self._num_encoder_layers, + num_decoder_layers=self._num_decoder_layers, + num_heads=self._num_heads, + dropout_rate=self._dropout, + is_encoder_decoder=True, + use_cache=False, + pad_token_id=unified_vocab_size - 1, + eos_token_id=unified_vocab_size - 2, + decoder_start_token_id=unified_vocab_size - 3, + layer_norm_epsilon=self._layer_norm_eps, + feed_forward_proj=self._activation, + tie_word_embeddings=False + ) + self.model = T5ForConditionalGeneration(config=self.config) + self._init_weights(initializer_range) + + self.model = torch.compile( + self.model, + 
mode='reduce-overhead', + fullgraph=False, + dynamic=True + ) + + def forward(self, inputs): + input_semantic_ids = inputs['input.data'] + attention_mask = inputs['input.mask'] + target_semantic_ids = inputs['output.data'] + + assert (input_semantic_ids != -1).all(), \ + f"Found -1 in inputs['input.data']. Check your DataLoader/Collator." + assert (target_semantic_ids != -1).all(), \ + f"Found -1 in inputs['output.data']. Check your DataLoader/Collator." + + decoder_input_ids = target_semantic_ids[:, :-1].contiguous() + labels = target_semantic_ids[:, 1:].contiguous() + + model_output = self.model( + input_ids=input_semantic_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + labels=labels + ) + loss = model_output['loss'] + + metrics = {'loss': loss.detach()} + + if not self.training and not self._use_microbatching: + # visited_batch = inputs['visited.padded'] + + output = self.model.generate( + input_ids=input_semantic_ids, + attention_mask=attention_mask, + num_beams=self._num_beams, + num_return_sequences=self._num_return_sequences, + max_length=self._sem_id_len + 1, + decoder_start_token_id=self.config.decoder_start_token_id, + eos_token_id=self.config.eos_token_id, + pad_token_id=self.config.pad_token_id, + do_sample=False, + early_stopping=False, + logits_processor=[self.logits_processor] if self.logits_processor is not None else [], #попробовать не маскировать + ) + + assert (output != -1).all(), "Model.generate returned -1 in raw output" + + predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len) + + all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1)) # (batch_size, top_k) + elif not self.training and self._use_microbatching: + # visited_batch = inputs['visited.padded'] + batch_size = input_semantic_ids.shape[0] + + inference_batch_size = self._microbatch_size # вместо полного batch_size + + all_predictions = [] + all_labels = [] + # print(f"start to infer batch of shape {input_semantic_ids.shape} with new batch {inference_batch_size}") + for batch_idx in range(0, batch_size, inference_batch_size): + batch_end = min(batch_idx + inference_batch_size, batch_size) + batch_slice = slice(batch_idx, batch_end) + + input_ids_batch = input_semantic_ids[batch_slice] + attention_mask_batch = attention_mask[batch_slice] + # visited_batch_subset = visited_batch[batch_slice] + labels_batch = labels[batch_slice] + + with torch.inference_mode(): + output = self.model.generate( + input_ids=input_ids_batch, + attention_mask=attention_mask_batch, + num_beams=self._num_beams, + num_return_sequences=self._num_return_sequences, + max_length=self._sem_id_len + 1, + decoder_start_token_id=self.config.decoder_start_token_id, + eos_token_id=self.config.eos_token_id, + pad_token_id=self.config.pad_token_id, + do_sample=False, + early_stopping=False, + logits_processor=[self.logits_processor] if self.logits_processor is not None else [], + ) + + predictions_batch = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len) + all_predictions.append(predictions_batch) + all_labels.append(labels_batch) + # print("end infer of batch") + + predictions = torch.cat(all_predictions, dim=0) # (batch_size, num_return_sequences, sem_id_len) + labels_full = torch.cat(all_labels, dim=0) # (batch_size, sem_id_len) + all_hits = (torch.eq(predictions, labels_full[:, None]).sum(dim=-1)) # (batch_size, top_k) + + if not self.training: + for k in [5, 10, 20]: + hits = (all_hits[:, :k] == self._sem_id_len).float() # (batch_size, k) + recall = 
hits.sum(dim=-1) # (batch_size) + discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device) # (k) + + metrics[f'recall@{k}'] = recall.cpu().float() + metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float() + + return loss, metrics \ No newline at end of file diff --git a/scripts/tiger-lsvd/plum_base_gap/lsvd_train_plum_base_gap.py b/scripts/tiger-lsvd/plum_base_gap/lsvd_train_plum_base_gap.py new file mode 100644 index 0000000..4d3479d --- /dev/null +++ b/scripts/tiger-lsvd/plum_base_gap/lsvd_train_plum_base_gap.py @@ -0,0 +1,226 @@ +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../../' +TRAIN_PART_SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/plum/only_base_with_gap_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless_from_all.json" +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer') + +EXPERIMENT_NAME = 'TEST_tiger_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 200 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_POSITIONS = 80 +NUM_USER_HASH = 8000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + +USE_MICROBATCHING = True +MICROBATCH_SIZE = 256 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(TRAIN_PART_SEMANTIC_MAPPING_PATH, 'r') as f: + train_part_mapping = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + 
dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=CorrectItemsLogitsProcessor(NUM_CODEBOOKS, CODEBOOK_SIZE, train_part_mapping, NUM_BEAMS), + use_microbatching=USE_MICROBATCHING, + microbatch_size=MICROBATCH_SIZE + ).to(DEVICE) + + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=12, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + # cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git 
a/scripts/tiger-lsvd/plum_base_gap/lsvd_varka_plum_base_gap_on_sem_all_items.py b/scripts/tiger-lsvd/plum_base_gap/lsvd_varka_plum_base_gap_on_sem_all_items.py new file mode 100644 index 0000000..120d343 --- /dev/null +++ b/scripts/tiger-lsvd/plum_base_gap/lsvd_varka_plum_base_gap_on_sem_all_items.py @@ -0,0 +1,305 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import Dataset + +print("tiger arrow varka") + +# ПУТИ + +IREC_PATH = '../../../' +INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet" +INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet" + +SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0_clusters_colisionless.json" +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/plum_base_gap/all_items_plum_vk-lsvd-15ts_base_with_gap_cb_512_ws_2_k_3000_e35_con_0.01_rqvae_1.0/eval_batches/') + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 8000 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
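+        # Padding positions in the right-aligned semantic-id matrix are overwritten with
+        # PAD_TOKEN_ID; just below, a hashed user token (offset past the
+        # NUM_CODEBOOKS * CODEBOOK_SIZE semantic-id vocabulary) is appended to each sequence,
+        # together with an extra column of ones in the attention mask.
+        # Note: PAD_TOKEN_ID, EOS_TOKEN_ID and DECODER_START_TOKEN_ID are defined above with
+        # trailing commas, so they are 1-element tuples; NumPy broadcasting hides this here,
+        # but dropping the commas would make the intent explicit.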
+ + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + max_item_id = max(int(k) for k in mapping.keys()) + print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys())) + # print(mapping["280052"]) #304781 + # assert False + data = [] + for i in range(max_item_id + 1): + if str(i) in mapping: + data.append(mapping[str(i)]) + else: + data.append([-1] * NUM_CODEBOOKS) + + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + missing_count = (max_item_id + 1) - len(mapping) + print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)") + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + semantic_ids = self._mapping_tensor[ids].flatten() + + assert (semantic_ids != -1).all(), \ + f"Missing mappings detected in {name}! 
Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}" + + batch[f'{name}.semantic.ids'] = semantic_ids.numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + print("варка может начать умирать") + data = Dataset.create_timestamp_based_parquet( + train_parquet_path=INTERACTIONS_TRAIN_PATH, + validation_parquet_path=INTERACTIONS_VALID_PATH, + test_parquet_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True, + max_train_events=MAX_SEQ_LEN + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + print("варка не умерла") + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + 
.map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/rqvae_base_gap/lsvd_train_rqvae_base_gap.py b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_train_rqvae_base_gap.py new file mode 100644 index 0000000..a1bccf5 --- /dev/null +++ b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_train_rqvae_base_gap.py @@ -0,0 +1,226 @@ +import json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + + +# ПУТИ +IREC_PATH = '../../../' +TRAIN_PART_SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/only_base_with_gap_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless_from_all.json" +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints-lsvd-transformer') + +EXPERIMENT_NAME = 'tiger_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0' + +# ОСТАЛЬНОЕ +SEED_VALUE = 42 +DEVICE = 'cuda' + +NUM_EPOCHS = 100 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 512 +NUM_POSITIONS = 80 +NUM_USER_HASH = 8000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.2 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 0.0001 + +USE_MICROBATCHING = True +MICROBATCH_SIZE = 256 + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(TRAIN_PART_SEMANTIC_MAPPING_PATH, 'r') as f: + train_part_mapping = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + TRAIN_BATCHES_DIR, + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + VALID_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + eval_dataloder = 
ArrowBatchDataset( + EVAL_BATCHES_DIR, + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=CorrectItemsLogitsProcessor(NUM_CODEBOOKS, CODEBOOK_SIZE, train_part_mapping, NUM_BEAMS), + use_microbatching=USE_MICROBATCHING, + microbatch_size=MICROBATCH_SIZE + ).to(DEVICE) + + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _:{ + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS * 4), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40 * 4, + minimize=False, + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=TENSORBOARD_LOGDIR + # ), + 
# cb.StopAfterNumSteps(40) + + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger-lsvd/rqvae_base_gap/lsvd_varka_rqvae_base_gap_on_sem_all_items.py b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_varka_rqvae_base_gap_on_sem_all_items.py new file mode 100644 index 0000000..3b06988 --- /dev/null +++ b/scripts/tiger-lsvd/rqvae_base_gap/lsvd_varka_rqvae_base_gap_on_sem_all_items.py @@ -0,0 +1,305 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from data import Dataset + +print("tiger arrow varka") + +# ПУТИ + +IREC_PATH = '../../../' +INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet" +INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/val_interactions_grouped.parquet" +INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/test_interactions_grouped.parquet" + +SEMANTIC_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_clusters_colisionless.json" +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd-2/rqvae_base_gap/all_items_rqvae_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0/eval_batches/') + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 8000 +CODEBOOK_SIZE = 512 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
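+        # Same preprocessing as in the plum_base_gap varka script: padding positions are
+        # filled with PAD_TOKEN_ID here, and the hashed user token is appended just below.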
+ + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + max_item_id = max(int(k) for k in mapping.keys()) + print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys())) + # print(mapping["280052"]) #304781 + # assert False + data = [] + for i in range(max_item_id + 1): + if str(i) in mapping: + data.append(mapping[str(i)]) + else: + data.append([-1] * NUM_CODEBOOKS) + + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + missing_count = (max_item_id + 1) - len(mapping) + print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)") + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + semantic_ids = self._mapping_tensor[ids].flatten() + + assert (semantic_ids != -1).all(), \ + f"Missing mappings detected in {name}! 
Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}" + + batch[f'{name}.semantic.ids'] = semantic_ids.numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists (2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + print("варка может начать умирать") + data = Dataset.create_timestamp_based_parquet( + train_parquet_path=INTERACTIONS_TRAIN_PATH, + validation_parquet_path=INTERACTIONS_VALID_PATH, + test_parquet_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True, + max_train_events=MAX_SEQ_LEN + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + print("варка не умерла") + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + 
.map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/sigir/lsvd_processing/LsvdDownload2.ipynb b/sigir/lsvd_processing/LsvdDownload2.ipynb new file mode 100644 index 0000000..d121678 --- /dev/null +++ b/sigir/lsvd_processing/LsvdDownload2.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "SbkKok0dfjjS" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import polars as pl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'1.8.2'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pl.__version__" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"metadata/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n", + "\n", + "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"subsamples/ur0.01_ir0.01/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Разбиение сабсэмплов на базовую, гэп, вал и тест части\n", + "\n", + "Добавляется колонка original_order чтобы сохранять порядок внутри каждой из частей" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(22, 1, 1, 1, 25, 23)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "subsample_name = 'ur0.01_ir0.01'\n", + "content_embedding_size = 256\n", + "DATASET_PATH = \"/home/jovyan/IRec/sigir/lsvd_data_filtered/raw\"\n", + "\n", + "metadata_files = ['metadata/users_metadata.parquet',\n", + " 'metadata/items_metadata.parquet',\n", + " 'metadata/item_embeddings.npz']\n", + "\n", + "BASE_WEEKS = (0, 22)\n", + "GAP_WEEKS = (22, 23) #увеличить гэп\n", + "VAL_WEEKS = (23, 24)\n", + "\n", + "base_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n", + " for i in range(BASE_WEEKS[0], BASE_WEEKS[1])]\n", + "\n", + "gap_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n", + " for i in range(GAP_WEEKS[0], GAP_WEEKS[1])]\n", + "\n", + "val_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n", + " for i in range(VAL_WEEKS[0], VAL_WEEKS[1])]\n", + "\n", + "test_interactions_files = [f'subsamples/{subsample_name}/train/week_24.parquet']\n", + "\n", + "all_interactions_files = base_interactions_files + gap_interactions_files + val_interactions_files + test_interactions_files\n", + "\n", + "base_with_gap_interactions_files = 
base_interactions_files + gap_interactions_files\n", + "\n", + "len(base_interactions_files), len(gap_interactions_files), len(val_interactions_files), len(test_interactions_files), len(all_interactions_files), len(base_with_gap_interactions_files)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('subsamples/ur0.01_ir0.01/train/week_21.parquet',\n", + " ['subsamples/ur0.01_ir0.01/train/week_22.parquet'],\n", + " ['subsamples/ur0.01_ir0.01/train/week_23.parquet'],\n", + " ['subsamples/ur0.01_ir0.01/train/week_24.parquet'])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "base_interactions_files[-1], gap_interactions_files, val_interactions_files, test_interactions_files" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "def get_parquet_interactions(data_files, positive_event_timespent):\n", + " data_interactions = pl.concat([pl.scan_parquet(f'{DATASET_PATH}/{file}')\n", + " for file in data_files])\n", + " data_interactions = data_interactions.collect(streaming=True)\n", + " data_interactions = data_interactions.with_row_index(\"original_order\")\n", + " data_interactions = data_interactions.filter(pl.col('timespent') > positive_event_timespent)\n", + " return data_interactions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "POSITIVE_EVENT_TIMESPENT = 15\n", + "base_interactions = get_parquet_interactions(base_interactions_files, POSITIVE_EVENT_TIMESPENT)\n", + "gap_interactions = get_parquet_interactions(gap_interactions_files, POSITIVE_EVENT_TIMESPENT)\n", + "val_interactions = get_parquet_interactions(val_interactions_files, POSITIVE_EVENT_TIMESPENT)\n", + "test_interactions = get_parquet_interactions(test_interactions_files, POSITIVE_EVENT_TIMESPENT)\n", + "all_data_interactions = get_parquet_interactions(all_interactions_files, POSITIVE_EVENT_TIMESPENT)\n", + "base_with_gap_interactions = get_parquet_interactions(base_with_gap_interactions_files, POSITIVE_EVENT_TIMESPENT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Загрузка и фильтрация эмбеддингов" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "all_data_users = all_data_interactions.select('user_id').unique()\n", + "all_data_items = all_data_interactions.select('item_id').unique()\n", + "\n", + "item_ids = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['item_id']\n", + "item_embeddings = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['embedding']\n", + "\n", + "mask = np.isin(item_ids, all_data_items.to_numpy())\n", + "item_ids = item_ids[mask]\n", + "item_embeddings = item_embeddings[mask]\n", + "item_embeddings = item_embeddings[:, :content_embedding_size]\n", + "\n", + "users_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/users_metadata.parquet\")\n", + "items_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/items_metadata.parquet\")\n", + "\n", + "users_metadata = users_metadata.join(all_data_users, on='user_id')\n", + "items_metadata = items_metadata.join(all_data_items, on='item_id')\n", + "items_metadata = items_metadata.join(pl.DataFrame({'item_id': item_ids, \n", + " 'embedding': item_embeddings}), on='item_id')\n", + "\n", + "only_base_items_metadata = items_metadata.join(base_interactions.select('item_id').unique(), 
on='item_id')\n", + "only_base_with_gap_items_metadata = items_metadata.join(base_with_gap_interactions.select('item_id').unique(), on='item_id')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Сжатие айтем айди и ремапинг" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total users: 77403, Total items: 65659\n" + ] + } + ], + "source": [ + "all_data_items = all_data_interactions.select('item_id').unique()\n", + "all_data_users = all_data_interactions.select('user_id').unique()\n", + "\n", + "unique_items_sorted = all_data_items.sort('item_id').with_row_index('new_item_id')\n", + "global_item_mapping = dict(zip(unique_items_sorted['item_id'], unique_items_sorted['new_item_id']))\n", + "\n", + "print(f\"Total users: {all_data_users.shape[0]}, Total items: {len(global_item_mapping)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "def remap_interactions(df, mapping):\n", + " return df.with_columns(\n", + " pl.col('item_id')\n", + " .map_elements(lambda x: mapping.get(x, None), return_dtype=pl.UInt32)\n", + " )\n", + "\n", + "base_interactions_remapped = remap_interactions(base_interactions, global_item_mapping)\n", + "gap_interactions_remapped = remap_interactions(gap_interactions, global_item_mapping)\n", + "test_interactions_remapped = remap_interactions(test_interactions, global_item_mapping)\n", + "val_interactions_remapped = remap_interactions(val_interactions, global_item_mapping)\n", + "all_data_interactions_remapped = remap_interactions(all_data_interactions, global_item_mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "del base_interactions, gap_interactions, test_interactions, val_interactions, all_data_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "base_with_gap_interactions_remapped = remap_interactions(base_with_gap_interactions, global_item_mapping)\n", + "del base_with_gap_interactions" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "items_metadata_remapped = remap_interactions(items_metadata, global_item_mapping)\n", + "only_base_items_metadata_remapped = remap_interactions(only_base_items_metadata, global_item_mapping)\n", + "only_base_with_gap_items_metadata_remapped = remap_interactions(only_base_with_gap_items_metadata, global_item_mapping)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Группировка по юзер айди" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interactions count: (1206762, 13)\n", + "users count: (75281, 3)\n", + "interactions count: (66879, 13)\n", + "users count: (28610, 3)\n", + "interactions count: (69887, 13)\n", + "users count: (28982, 3)\n", + "interactions count: (73637, 13)\n", + "users count: (30231, 3)\n", + "interactions count: (1417165, 13)\n", + "users count: (77403, 3)\n" + ] + } + ], + "source": [ + "def get_grouped_interactions(data_interactions):\n", + " print(f\"interactions count: {data_interactions.shape}\")\n", + " data_res = (\n", + " data_interactions\n", + " .select(['original_order', 'user_id', 'item_id'])\n", + " .group_by('user_id')\n", + " .agg(\n", + " 
pl.col('item_id')\n", + " .sort_by(pl.col('original_order'))\n", + " .alias('item_ids'),\n", + " pl.col('original_order').alias('timestamps')\n", + " )\n", + " .rename({'user_id': 'uid'})\n", + " )\n", + " print(f\"users count: {data_res.shape}\")\n", + " return data_res\n", + "base_interactions_grouped = get_grouped_interactions(base_interactions_remapped)\n", + "gap_interactions_grouped = get_grouped_interactions(gap_interactions_remapped)\n", + "test_interactions_grouped = get_grouped_interactions(test_interactions_remapped)\n", + "val_interactions_grouped = get_grouped_interactions(val_interactions_remapped)\n", + "all_data_interactions_grouped = get_grouped_interactions(all_data_interactions_remapped)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1, 3)
uiditem_idstimestamps
u32list[u32]list[u32]
4651912[39251, 11631, … 18532][181275, 1237407, … 3130963]
" + ], + "text/plain": [ + "shape: (1, 3)\n", + "┌─────────┬─────────────────────────┬──────────────────────────────┐\n", + "│ uid ┆ item_ids ┆ timestamps │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[u32] ┆ list[u32] │\n", + "╞═════════╪═════════════════════════╪══════════════════════════════╡\n", + "│ 4651912 ┆ [39251, 11631, … 18532] ┆ [181275, 1237407, … 3130963] │\n", + "└─────────┴─────────────────────────┴──────────────────────────────┘" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "base_interactions_grouped.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "del base_interactions_remapped, gap_interactions_remapped, test_interactions_remapped, val_interactions_remapped" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "interactions count: (1273641, 13)\n", + "users count: (76011, 3)\n" + ] + } + ], + "source": [ + "base_with_gap_interactions_grouped = get_grouped_interactions(base_with_gap_interactions_remapped)\n", + "del base_with_gap_interactions_remapped" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Сохранение" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Сохранён маппинг: /home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/global_item_mapping.json\n" + ] + } + ], + "source": [ + "import json\n", + "OUTPUT_DIR = \"/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows\"\n", + "\n", + "mapping_output_path = f\"{OUTPUT_DIR}/global_item_mapping.json\"\n", + "\n", + "with open(mapping_output_path, 'w') as f:\n", + " json.dump({str(k): v for k, v in global_item_mapping.items()}, f, indent=2)\n", + "\n", + "print(f\"Сохранён маппинг: {mapping_output_path}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "размерность: (65659, 5)\n", + "Сохранен файл: items_metadata_remapped\n", + "размерность: (65659, 5)\n", + "Сохранен файл: items_metadata_old\n", + "размерность: (60552, 5)\n", + "Сохранен файл: only_base_items_metadata_remapped\n", + "размерность: (60552, 5)\n", + "Сохранен файл: only_base_items_metadata_old\n", + "размерность: (62362, 5)\n", + "Сохранен файл: only_base_with_gap_items_metadata_remapped\n", + "размерность: (62362, 5)\n", + "Сохранен файл: only_base_with_gap_items_metadata_old\n", + "размерность: (75281, 3)\n", + "Сохранен файл: base_interactions_grouped\n", + "размерность: (28610, 3)\n", + "Сохранен файл: gap_interactions_grouped\n", + "размерность: (28982, 3)\n", + "Сохранен файл: test_interactions_grouped\n", + "размерность: (30231, 3)\n", + "Сохранен файл: val_interactions_grouped\n", + "размерность: (76011, 3)\n", + "Сохранен файл: base_with_gap_interactions_grouped\n", + "размерность: (77403, 3)\n", + "Сохранен файл: all_data_interactions_grouped\n", + "размерность: (1417165, 13)\n", + "Сохранен файл: all_data_interactions_remapped\n" + ] + } + ], + "source": [ + "def write_parquet(output_dir, data, file_name):\n", + " print(f\"размерность: {data.shape}\")\n", + " output_parquet_path = f\"{output_dir}/{file_name}.parquet\"\n", + " data.write_parquet(output_parquet_path)\n", + " print(f\"Сохранен файл: {file_name}\")\n", + "\n", + "write_parquet(OUTPUT_DIR, 
items_metadata_remapped, \"items_metadata_remapped\")\n", + "write_parquet(OUTPUT_DIR, items_metadata, \"items_metadata_old\")\n", + "\n", + "write_parquet(OUTPUT_DIR, only_base_items_metadata_remapped, \"only_base_items_metadata_remapped\")\n", + "write_parquet(OUTPUT_DIR, only_base_items_metadata, \"only_base_items_metadata_old\")\n", + "\n", + "write_parquet(OUTPUT_DIR, only_base_with_gap_items_metadata_remapped, \"only_base_with_gap_items_metadata_remapped\")\n", + "write_parquet(OUTPUT_DIR, only_base_with_gap_items_metadata, \"only_base_with_gap_items_metadata_old\")\n", + "\n", + "write_parquet(OUTPUT_DIR, base_interactions_grouped, \"base_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, gap_interactions_grouped, \"gap_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, test_interactions_grouped, \"test_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, val_interactions_grouped, \"val_interactions_grouped\")\n", + "write_parquet(OUTPUT_DIR, base_with_gap_interactions_grouped, \"base_with_gap_interactions_grouped\")\n", + "\n", + "write_parquet(OUTPUT_DIR, all_data_interactions_grouped, \"all_data_interactions_grouped\")\n", + "\n", + "write_parquet(OUTPUT_DIR, all_data_interactions_remapped, \"all_data_interactions_remapped\")" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(64, 64)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(list(items_metadata_remapped.head(1)['embedding'].item())), len(list(items_metadata.head(1)['embedding'].item()))" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((65659, 5), (60552, 5), (62362, 5))" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_metadata_remapped.shape, only_base_items_metadata_remapped.shape, only_base_with_gap_items_metadata_remapped.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1, 5)
item_idauthor_iddurationtrain_interactions_rankembedding
u32u32u8u32array[f32, 64]
06517057114592477[-0.281006, -0.145508, … -0.058258]
" + ], + "text/plain": [ + "shape: (1, 5)\n", + "┌─────────┬───────────┬──────────┬─────────────────────────┬─────────────────────────────────┐\n", + "│ item_id ┆ author_id ┆ duration ┆ train_interactions_rank ┆ embedding │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ u32 ┆ u8 ┆ u32 ┆ array[f32, 64] │\n", + "╞═════════╪═══════════╪══════════╪═════════════════════════╪═════════════════════════════════╡\n", + "│ 0 ┆ 651705 ┆ 71 ┆ 14592477 ┆ [-0.281006, -0.145508, … -0.05… │\n", + "└─────────┴───────────┴──────────┴─────────────────────────┴─────────────────────────────────┘" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "items_metadata_remapped.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
uiditem_idstimestamps
u32list[u32]list[u32]
1450406[65092, 7570][2309584, 3257670]
3736304[39094, 42731, … 10818][2691900, 2743200, … 3232861]
775866[52462, 44617, … 64047][101336, 161036, … 2789649]
777572[48639, 6069][1056275, 1218815]
1942489[44575, 6264, 32234][372244, 955295, 2432349]
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────────────────────┬───────────────────────────────┐\n", + "│ uid ┆ item_ids ┆ timestamps │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u32 ┆ list[u32] ┆ list[u32] │\n", + "╞═════════╪═════════════════════════╪═══════════════════════════════╡\n", + "│ 1450406 ┆ [65092, 7570] ┆ [2309584, 3257670] │\n", + "│ 3736304 ┆ [39094, 42731, … 10818] ┆ [2691900, 2743200, … 3232861] │\n", + "│ 775866 ┆ [52462, 44617, … 64047] ┆ [101336, 161036, … 2789649] │\n", + "│ 777572 ┆ [48639, 6069] ┆ [1056275, 1218815] │\n", + "│ 1942489 ┆ [44575, 6264, 32234] ┆ [372244, 955295, 2432349] │\n", + "└─────────┴─────────────────────────┴───────────────────────────────┘" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "base_with_gap_interactions_grouped.head()" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}