64 changes: 64 additions & 0 deletions scripts/plum-lsvd/callbacks.py
@@ -0,0 +1,64 @@
import torch

import irec.callbacks as cb
from irec.runners import TrainingRunner, TrainingRunnerContext

class InitCodebooks(cb.TrainingCallback):
    """Seeds each codebook before training: codebook i is initialized with the
    residuals that remain after quantizing a random batch against codebooks 0..i-1."""

    def __init__(self, dataloader):
        super().__init__()
        self._dataloader = dataloader

    @torch.no_grad()
    def before_run(self, runner: TrainingRunner):
        for i in range(len(runner.model.codebooks)):
            X = next(iter(self._dataloader))['embedding']
            codebook_size = len(runner.model.codebooks[i])
            # The sampled batch must be at least as large as the codebook,
            # otherwise the assignment below would silently resize the parameter.
            assert X.shape[0] >= codebook_size, (
                f'Batch of {X.shape[0]} embeddings cannot seed a codebook '
                f'of {codebook_size} centroids'
            )
            idx = torch.randperm(X.shape[0], device=X.device)[:codebook_size]
            remainder = runner.model.encoder(X[idx])

            # Subtract the nearest centroid from every previously initialized
            # codebook, so level i is seeded with the residuals it will quantize.
            for j in range(i):
                codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j])
                codebook_vectors = runner.model.codebooks[j][codebook_indices]
                remainder = remainder - codebook_vectors

            runner.model.codebooks[i].data = remainder.detach()


class FixDeadCentroids(cb.TrainingCallback):
    """After each step, reseeds 'dead' centroids (centroids that no sample in
    the dataloader was assigned to) with random residuals from a fresh batch."""

    def __init__(self, dataloader):
        super().__init__()
        self._dataloader = dataloader

    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
        for i, num_fixed in enumerate(self.fix_dead_centroids(runner)):
            context.metrics[f'num_dead/{i}'] = num_fixed

    @torch.no_grad()
    def fix_dead_centroids(self, runner: TrainingRunner):
        num_fixed = []
        for codebook_idx, codebook in enumerate(runner.model.codebooks):
            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
            random_batch = next(iter(self._dataloader))['embedding']

            # Count how often each centroid of this codebook is selected
            # across the whole dataloader.
            for batch in self._dataloader:
                remainder = runner.model.encoder(batch['embedding'])
                for level in range(codebook_idx):
                    ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[level])
                    remainder = remainder - runner.model.codebooks[level][ind]

                indices = runner.model.get_codebook_indices(remainder, codebook)
                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

            dead_mask = (centroid_counts == 0)
            num_dead = int(dead_mask.sum().item())
            num_fixed.append(num_dead)
            if num_dead == 0:
                continue

            # Reseed dead centroids with residuals from a random batch; this
            # assumes the batch holds at least `num_dead` samples.
            remainder = runner.model.encoder(random_batch)
            for level in range(codebook_idx):
                ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[level])
                remainder = remainder - runner.model.codebooks[level][ind]
            remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead]
            codebook[dead_mask] = remainder.detach()

        return num_fixed
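
Both callbacks repeat the same residual-quantization step: assign each vector to its nearest centroid, subtract that centroid, and pass the residual on to the next codebook. A minimal self-contained sketch of that step, with `nearest_indices` standing in for `model.get_codebook_indices` (whose exact distance metric in this repo is an assumption, not confirmed):

import torch

def nearest_indices(x, codebook):
    # Nearest centroid by L2 distance; a hypothetical stand-in for
    # model.get_codebook_indices.
    return torch.cdist(x, codebook).argmin(dim=1)

torch.manual_seed(0)
embeddings = torch.randn(512, 16)               # stands in for encoder output
codebooks = [torch.randn(32, 16) for _ in range(3)]

remainder = embeddings
for codebook in codebooks:
    idx = nearest_indices(remainder, codebook)
    remainder = remainder - codebook[idx]       # quantize, carry residual forward

# With trained codebooks the residual norm would shrink level by level;
# here the codebooks are random, so this only demonstrates the mechanics.
print(remainder.shape, remainder.norm(dim=1).mean())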
117 changes: 117 additions & 0 deletions scripts/plum-lsvd/cooc_data.py
@@ -0,0 +1,117 @@
import json
from collections import defaultdict, Counter

from data import InteractionsDatasetParquet


class CoocMappingDataset:
def __init__(
self,
train_sampler,
num_items,
cooccur_counter_mapping=None
):
self._train_sampler = train_sampler
self._num_items = num_items
self._cooccur_counter_mapping = cooccur_counter_mapping

@classmethod
def create(cls, inter_json_path, window_size):
max_item_id = 0
train_dataset = []

with open(inter_json_path, 'r') as f:
user_interactions = json.load(f)

        for user_id_str, item_ids in user_interactions.items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))
            if len(item_ids) < 5:
                print(f'Core-5 dataset is used, but user {user_id} has only {len(item_ids)} items')
            train_dataset.append({
                'user_ids': [user_id],
                # Hold out the last two interactions per user (presumably validation/test).
                'item_ids': item_ids[:-2],
            })

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items, but max_item_id is {max_item_id}')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

@classmethod
    def create_from_split_part(
        cls,
        train_inter_parquet_path,
        window_size,
    ):
        max_item_id = 0
        train_dataset = []

        train_interactions = InteractionsDatasetParquet(train_inter_parquet_path)

        actions_num = 0
        for session in train_interactions:
            user_id, item_ids = int(session['user_id']), session['item_ids']
            # `.size` checks that the array is non-empty; `.any()` would also
            # skip a session whose only item has id 0.
            if item_ids.size:
                max_item_id = max(max_item_id, max(item_ids))
            actions_num += len(item_ids)
            train_dataset.append({
                'user_ids': [user_id],
                'item_ids': item_ids,
            })

        print(f'Train: {len(train_dataset)} users')
        print(f'Max item ID: {max_item_id}')
        print(f'Actions num: {actions_num}')

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
            train_dataset,
            window_size=window_size
        )

        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):
        # For every item, count co-occurrences with the items that appear
        # within `window_size` positions of it in the same session, in both directions.
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item_ids']
            for i in range(len(items)):
                item_i = items[i]
                for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
                    if i != j:
                        cooccur_counts[item_i][items[j]] += 1
        max_hist_len = max(len(counter) for counter in cooccur_counts.values()) if cooccur_counts else 0
        print(f'Max cooccurrence history length is {max_hist_len}')
        return cooccur_counts

@property
def cooccur_counter_mapping(self):
return self._cooccur_counter_mapping
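
The window logic is easiest to see on a toy session. A self-contained rerun of the counting loop from build_cooccur_counter_mapping with window_size=1, where each item co-occurs only with its direct neighbours:

from collections import Counter, defaultdict

items = [1, 2, 3, 2]
window_size = 1

cooccur = defaultdict(Counter)
for i in range(len(items)):
    for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
        if i != j:
            cooccur[items[i]][items[j]] += 1

print(dict(cooccur))
# {1: Counter({2: 1}), 2: Counter({3: 2, 1: 1}), 3: Counter({2: 2})}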
87 changes: 87 additions & 0 deletions scripts/plum-lsvd/data.py
@@ -0,0 +1,87 @@
import pickle

import numpy as np
import polars as pl

from irec.data.base import BaseDataset
from irec.data.transforms import Transform


class InteractionsDatasetParquet(BaseDataset):
def __init__(self, data_path, max_items=None):
self.df = pl.read_parquet(data_path)
assert 'uid' in self.df.columns, "Missing 'uid' column"
assert 'item_ids' in self.df.columns, "Missing 'item_ids' column"
print(f"Dataset loaded: {len(self.df)} users")

        if max_items is not None:
            # Keep only the last `max_items` interactions per user
            # (the most recent, assuming chronological order).
            self.df = self.df.with_columns(
                pl.col("item_ids").list.slice(-max_items).alias("item_ids")
            )

def __getitem__(self, idx):
row = self.df.row(idx, named=True)
return {
'user_id': row['uid'],
'item_ids': np.array(row['item_ids'], dtype=np.uint32),
}

def __len__(self):
return len(self.df)

def __iter__(self):
for idx in range(len(self)):
yield self[idx]


class EmbeddingDatasetParquet(BaseDataset):
def __init__(self, data_path):
self.df = pl.read_parquet(data_path)
self.item_ids = np.array(self.df['item_id'], dtype=np.int64)
self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32)
print(f"embedding dim: {self.embeddings[0].shape}")

    def __getitem__(self, idx):
        item_id = self.item_ids[idx]
        tensor_emb = self.embeddings[idx]
        return {
            'item_id': item_id,
            'embedding': tensor_emb,
            'embedding_dim': len(tensor_emb)
        }

def __len__(self):
return len(self.embeddings)


class EmbeddingDataset(BaseDataset):
def __init__(self, data_path):
self.data_path = data_path
with open(data_path, 'rb') as f:
self.data = pickle.load(f)

self.item_ids = np.array(self.data['item_id'], dtype=np.int64)
self.embeddings = np.array(self.data['embedding'], dtype=np.float32)

    def __getitem__(self, idx):
        item_id = self.item_ids[idx]
        tensor_emb = self.embeddings[idx]
        return {
            'item_id': item_id,
            'embedding': tensor_emb,
            'embedding_dim': len(tensor_emb)
        }

def __len__(self):
return len(self.embeddings)


class ProcessEmbeddings(Transform):
    """Reshapes flat embedding buffers in a batch back to (batch, embedding_dim)."""

    def __init__(self, embedding_dim, keys):
        self.embedding_dim = embedding_dim
        self.keys = keys

def __call__(self, batch):
for key in self.keys:
batch[key] = batch[key].reshape(-1, self.embedding_dim)
return batch
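
For reference, a tiny sketch of the parquet layout EmbeddingDatasetParquet expects: an `item_id` column plus a list-typed `embedding` column. The values are illustrative only:

import tempfile

import polars as pl

df = pl.DataFrame({
    'item_id': [0, 1, 2],
    'embedding': [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
})
with tempfile.NamedTemporaryFile(suffix='.parquet') as f:
    df.write_parquet(f.name)
    ds = EmbeddingDatasetParquet(f.name)
    print(ds[1]['embedding'], ds[1]['embedding_dim'])  # [0.3 0.4] 2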
@@ -0,0 +1,50 @@
import json
import pandas as pd
from pathlib import Path


ALL_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/all_items_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless.json"
TRAIN_INTERACTIONS_PATH = "/home/jovyan/IRec/sigir/lsvd_data_filtered/15-ts-ows/base_with_gap_interactions_grouped.parquet"
OUTPUT_TRAIN_MAPPING_PATH = "/home/jovyan/IRec/results-lsvd-2/base_gap/letter/only_base_with_gap_letter_vk-lsvd-15ts_base_with_gap_e35_rqvae_1.0_cf_0.01_clusters_colisionless_from_all.json"

with open(ALL_MAPPING_PATH, 'r') as f:
all_mapping = json.load(f)
print(f"Loaded {len(all_mapping)} items from all_mapping")

train_interactions = pd.read_parquet(TRAIN_INTERACTIONS_PATH)

train_item_ids = set()
for item_ids_array in train_interactions['item_ids']:
train_item_ids.update(item_ids_array)

print(f"Found {len(train_item_ids)} unique train items")

train_mapping = {}
missing_count = 0

for item_id in train_item_ids:
item_id_str = str(item_id)
if item_id_str in all_mapping:
train_mapping[item_id_str] = all_mapping[item_id_str]
else:
missing_count += 1

if missing_count > 0:
print(f"{missing_count} items from train not found in all_mapping")

print(f"Created train_mapping with {len(train_mapping)} items")

Path(OUTPUT_TRAIN_MAPPING_PATH).parent.mkdir(parents=True, exist_ok=True)
with open(OUTPUT_TRAIN_MAPPING_PATH, 'w') as f:
json.dump(train_mapping, f, indent=2)
print(f"Saved to {OUTPUT_TRAIN_MAPPING_PATH}")

print(f"all_mapping size: {len(all_mapping)}")
print(f"train_mapping size: {len(train_mapping)}")
print(f"train_mapping/all_mapping ratio: {len(train_mapping)/len(all_mapping):.1%}")

# Sanity check: sampled train_mapping entries must match their source in all_mapping.
sample_matches = 0
for item_id_str in list(train_mapping.keys())[:100]:
    if all_mapping[item_id_str] == train_mapping[item_id_str]:
        sample_matches += 1

print(f"Verified: {sample_matches}/100 sampled items have identical codes")