2 changes: 2 additions & 0 deletions pyproject.toml
@@ -20,7 +20,9 @@ dependencies = [
"seaborn>=0.13.2",
"sentencepiece>=0.2.0",
"torch>=2.6.0",
"tqdm>=4.67.1",
"transformers==4.52.3",
"vllm>=0.7.4",
]

[build-system]
121 changes: 121 additions & 0 deletions src/core/distillation/clean_thinking_traces.py
@@ -0,0 +1,121 @@
"""LLM cleaner for thinking traces on H100 80GB."""

import os

# Must be set before importing vllm so the V0 engine is selected.
os.environ["VLLM_USE_V1"] = "0"

import pandas as pd
from pathlib import Path
from vllm import LLM, SamplingParams
from tqdm import tqdm
import torch
import gc
import time

from prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE


def format_prompt(reasoning_trace: str) -> list[dict]:
"""Format prompt for LLM cleaning."""
return [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": USER_PROMPT_TEMPLATE.format(reasoning_trace=reasoning_trace)}
]


def truncate_thinking(thinking: str, max_chars: int = 50000) -> str:
    """Truncate to max_chars, cutting at the last sentence boundary when one falls within the final 200 characters."""
if len(thinking) <= max_chars:
return thinking
truncated = thinking[:max_chars]
last_period = truncated.rfind('.')
if last_period > max_chars - 200:
return truncated[:last_period + 1]
return truncated


def clean_thinking_traces(
input_file: str | Path,
output_file: str | Path,
model_name: str = "Qwen/Qwen2.5-32B-Instruct",
batch_size: int = 24,
temperature: float = 0.0,
gpu_memory_utilization: float = 0.95,
max_model_len: int = 20480,
max_thinking_chars: int = 35000,
checkpoint_every_n_batches: int = 8
) -> Path:
    """Clean every valid thinking trace in input_file with an LLM and write the result to output_file."""
    # Free any leftover GPU memory from a previous run before loading the model.
    gc.collect()
    torch.cuda.empty_cache()

input_path = Path(input_file).resolve()
output_path = Path(output_file).resolve()

checkpoint_name = f"checkpoint_{int(time.time())}_{os.getpid()}.parquet"
checkpoint_path = output_path.parent / checkpoint_name

print(f"Input: {input_path}")
print(f"Output: {output_path}")

llm = LLM(
model=model_name,
tensor_parallel_size=1,
gpu_memory_utilization=gpu_memory_utilization,
trust_remote_code=True,
dtype="bfloat16",
max_model_len=max_model_len,
enforce_eager=False
)

sampling_params = SamplingParams(
temperature=temperature,
top_p=1.0,
max_tokens=int(max_model_len * 0.9),
repetition_penalty=1.05
)

df = pd.read_parquet(input_path)
print(f"Records: {len(df)}")

    # Collect traces that completed without error and have non-empty thinking.
    reasoning_traces = []
    indices_to_clean = []

for idx, row in df.iterrows():
out = row['output']
if isinstance(out, dict) and out.get('error') is None and 'thinking' in out:
thinking = out.get('thinking', '')
if thinking:
thinking = truncate_thinking(thinking, max_thinking_chars)
reasoning_traces.append(thinking)
indices_to_clean.append(idx)

print(f"Traces to clean: {len(reasoning_traces)}")

all_prompts = [format_prompt(trace) for trace in reasoning_traces]
df_cleaned = df.copy()

pbar = tqdm(total=len(all_prompts), desc="Cleaning")

for i in range(0, len(all_prompts), batch_size):
batch_prompts = all_prompts[i : i + batch_size]
batch_indices = indices_to_clean[i : i + batch_size]

outputs = llm.chat(messages=batch_prompts, sampling_params=sampling_params, use_tqdm=False)
batch_texts = [output.outputs[0].text.strip() for output in outputs]

for idx, cleaned_text in zip(batch_indices, batch_texts):
output_dict = df_cleaned.at[idx, 'output'].copy()
output_dict['thinking'] = cleaned_text
df_cleaned.at[idx, 'output'] = output_dict

pbar.update(len(batch_texts))

if (i // batch_size + 1) % checkpoint_every_n_batches == 0:
df_cleaned.to_parquet(checkpoint_path, index=False)

pbar.close()
df_cleaned.to_parquet(output_path, index=False)

if checkpoint_path.exists():
os.remove(checkpoint_path)

print(f"Saved: {output_path}")
return output_path
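
The script writes a checkpoint every checkpoint_every_n_batches batches but always restarts from row zero; a minimal recovery sketch (a hypothetical helper, not part of the module, assuming the checkpoint parquet keeps the output schema) could look like this:

from pathlib import Path
import pandas as pd

def latest_checkpoint(out_dir: Path) -> Path | None:
    """Return the newest checkpoint_*.parquet in out_dir, or None if none exist."""
    candidates = sorted(out_dir.glob("checkpoint_*.parquet"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None

# Usage: inspect or resume from whatever the last run managed to save.
ckpt = latest_checkpoint(Path("data/out/distillation"))
if ckpt is not None:
    df_partial = pd.read_parquet(ckpt)  # same schema as the final output file
    print(f"Recovered {len(df_partial)} rows from {ckpt.name}")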
125 changes: 125 additions & 0 deletions src/core/distillation/evaluate_cleaning_quality.py
@@ -0,0 +1,125 @@
"""LLM-as-judge evaluation of thinking-trace cleaning quality."""

import os

# Must be set before importing vllm so the V0 engine is selected.
os.environ["VLLM_USE_V1"] = "0"

import json
import pandas as pd
from pathlib import Path
from vllm import LLM, SamplingParams

from core.prompts.cleaning_judge import create_judge_prompt, FEWSHOT_EXAMPLE


def select_samples(branch_a_file: Path, branch_b_file: Path, n_fewshot: int = 3, n_test: int = 100):
df_a = pd.read_parquet(branch_a_file)
df_b = pd.read_parquet(branch_b_file)

    def _col(df, source, key, default=None):
        return df[source].apply(lambda x: x.get(key, default) if isinstance(x, dict) else default)

    df_a['question_id'] = _col(df_a, 'input', 'question_id').map(lambda v: str(v) if v is not None else None)
    df_a['answer'] = _col(df_a, 'output', 'answer')
    df_a['error'] = _col(df_a, 'output', 'error')
    df_a['gold'] = _col(df_a, 'input', 'gold')
    df_a['thinking_a'] = _col(df_a, 'output', 'thinking', '')

    df_b['question_id'] = _col(df_b, 'input', 'question_id').map(lambda v: str(v) if v is not None else None)
    df_b['thinking_b'] = _col(df_b, 'output', 'thinking', '')

    # Keep only rows where Branch A answered correctly and without error.
    correct_a = df_a[(df_a['error'].isna() | (df_a['error'] == False)) & (df_a['answer'] == df_a['gold'])].copy()

merged = correct_a.merge(df_b[['question_id', 'thinking_b']], on='question_id', how='inner')

matched = [
{
'question_id': row['question_id'],
'original': row['thinking_a'],
'cleaned': row['thinking_b']
}
for _, row in merged.iterrows()
]

return matched[:n_fewshot], matched[n_fewshot:n_fewshot+n_test]


def evaluate_batch(samples: list, model_name: str, output_file: Path, batch_size: int = 10):
llm = LLM(
model=model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.9,
trust_remote_code=True,
dtype="bfloat16",
max_model_len=32768,
disable_custom_all_reduce=True,
enforce_eager=True
)

sampling_params = SamplingParams(
temperature=0.0,
max_tokens=512,
top_p=1.0
)

    tokenizer = llm.get_tokenizer()  # fetch once instead of once per prompt
    all_results = []

    for i in range(0, len(samples), batch_size):
        batch = samples[i:i+batch_size]
        batch_results = []

        prompts = []
        for sample in batch:
            msgs = create_judge_prompt(sample['original'], sample['cleaned'], [FEWSHOT_EXAMPLE])
            prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
            prompts.append(prompt)

outputs = llm.generate(prompts, sampling_params)

for sample, output in zip(batch, outputs):
text = output.outputs[0].text.strip()
            try:
                scores = json.loads(text)
            except json.JSONDecodeError:
                scores = {"error": "parse_failed", "raw": text}

result = {
'question_id': sample['question_id'],
'scores': scores,
'original_length': len(sample['original']),
'cleaned_length': len(sample['cleaned'])
}
batch_results.append(result)
all_results.append(result)

            # Write incrementally so partial results survive a crash.
            if i == 0:
pd.DataFrame(batch_results).to_parquet(output_file, index=False)
else:
existing = pd.read_parquet(output_file)
combined = pd.concat([existing, pd.DataFrame(batch_results)], ignore_index=True)
combined.to_parquet(output_file, index=False)

print(f"Batch {i//batch_size + 1}/{(len(samples)-1)//batch_size + 1} completed, saved {len(batch_results)} results")

print(f"All results saved to {output_file}")
return all_results


def compute_stats(results_file: Path):
df = pd.read_parquet(results_file)

scores_list = []
for scores in df['scores']:
if isinstance(scores, dict) and 'logical_preservation' in scores:
scores_list.append(scores)

if not scores_list:
print("No valid scores found")
return

    criteria = ['logical_preservation', 'noise_removal']
    print("\nScores (mean ± std):")
    for c in criteria:
        vals = [s[c] for s in scores_list if c in s and s[c] is not None]
        if vals:
            mean = sum(vals) / len(vals)
            std = (sum((x - mean) ** 2 for x in vals) / len(vals)) ** 0.5
            print(f"{c}: {mean:.2f} ± {std:.2f}")

avg_original = df['original_length'].mean()
avg_cleaned = df['cleaned_length'].mean()
compression = (1 - avg_cleaned/avg_original) * 100
print(f"\nCompression: {compression:.1f}% ({avg_original:.0f} → {avg_cleaned:.0f} chars)")
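
The module exposes the three stages but no entry point; a minimal driver (a sketch — the parquet names are assumptions mirroring the runner scripts below) would chain them:

from pathlib import Path

data_dir = Path("data/out/distillation")

# select_samples also returns few-shot candidates, though evaluate_batch
# currently uses the canned FEWSHOT_EXAMPLE from core.prompts.cleaning_judge.
fewshot, test = select_samples(
    branch_a_file=data_dir / "mmlu_synth_qwen3_b_t0_8.parquet",
    branch_b_file=data_dir / "mmlu_synth_qwen3_b_t0_8_cleaned_32b.parquet",
    n_fewshot=3,
    n_test=100,
)

results_file = data_dir / "cleaning_judge_scores.parquet"
evaluate_batch(test, model_name="Qwen/Qwen2.5-32B-Instruct", output_file=results_file)
compute_stats(results_file)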
42 changes: 42 additions & 0 deletions src/core/distillation/prompts.py
@@ -0,0 +1,42 @@
"""Prompts for cleaning thinking traces with LLM."""

SYSTEM_PROMPT = """You are a text editor. Your ONLY task is to remove phrases that reveal the final answer or outcome prematurely (spoilers).

GUIDELINES:
1. REMOVE only parts where the text explicitly states the known result before deriving it (e.g., "Since we know the answer is B...", "The target value is 5, so...").
2. KEEP all conversational fillers ("Let's think", "I wonder", "Note that"), planning, and formatting.
3. PRESERVE the original wording exactly. Do not summarize or rewrite.

### EXAMPLE 1
[INPUT]
We need to verify if the function is continuous.
Let f(x) = x^2. We will check the limit at x=0.
Limit calculation: lim(x->0) x^2 = 0.
Now provide a concise explanation. The limit exists and equals the function value.
Thus, it is continuous.

[OUTPUT]
Let f(x) = x^2. Check the limit at x=0.
Limit calculation: lim(x->0) x^2 = 0.
The limit exists and equals the function value.
Thus, it is continuous.

### EXAMPLE 2
[INPUT]
Option A is wrong because 5 is not prime.
We'll explain why B is correct.
For option B: 7 is prime.
Now wrap up the answer.
Therefore, B is the answer.

[OUTPUT]
Option A is wrong because 5 is not prime.
For option B: 7 is prime.
Therefore, B is the answer.
"""

USER_PROMPT_TEMPLATE = """[INPUT]
{reasoning_trace}

[OUTPUT]
"""
19 changes: 19 additions & 0 deletions src/core/distillation/run_cleaner_gptoss_b.py
@@ -0,0 +1,19 @@
"""Clean gptoss Branch B thinking traces."""

from pathlib import Path
from clean_thinking_traces import clean_thinking_traces

if __name__ == "__main__":
project_root = Path(__file__).parent.parent.parent.parent
data_dir = project_root / "data" / "out" / "distillation"

clean_thinking_traces(
input_file=data_dir / "mmlu_synth_gptoss_b_t0_8.parquet",
output_file=data_dir / "mmlu_synth_gptoss_b_t0_8_cleaned_32b.parquet",
model_name="Qwen/Qwen2.5-32B-Instruct",
batch_size=12,
temperature=0.1,
max_model_len=20480,
max_thinking_chars=35000,
checkpoint_every_n_batches=4
)
19 changes: 19 additions & 0 deletions src/core/distillation/run_cleaner_gptoss_c.py
@@ -0,0 +1,19 @@
"""Clean gptoss Branch C thinking traces."""

from pathlib import Path
from clean_thinking_traces import clean_thinking_traces

if __name__ == "__main__":
project_root = Path(__file__).parent.parent.parent.parent
data_dir = project_root / "data" / "out" / "distillation"

clean_thinking_traces(
input_file=data_dir / "mmlu_synth_gptoss_c_t0_8.parquet",
output_file=data_dir / "mmlu_synth_gptoss_c_t0_8_cleaned_32b.parquet",
model_name="Qwen/Qwen2.5-32B-Instruct",
batch_size=12,
temperature=0.1,
max_model_len=20480,
max_thinking_chars=35000,
checkpoint_every_n_batches=4
)
19 changes: 19 additions & 0 deletions src/core/distillation/run_cleaner_qwen_b.py
@@ -0,0 +1,19 @@
"""Clean qwen3 Branch B thinking traces."""

from pathlib import Path
from clean_thinking_traces import clean_thinking_traces

if __name__ == "__main__":
project_root = Path(__file__).parent.parent.parent.parent
data_dir = project_root / "data" / "out" / "distillation"

clean_thinking_traces(
input_file=data_dir / "mmlu_synth_qwen3_b_t0_8.parquet",
output_file=data_dir / "mmlu_synth_qwen3_b_t0_8_cleaned_32b.parquet",
model_name="Qwen/Qwen2.5-32B-Instruct",
batch_size=12,
temperature=0.1,
max_model_len=20480,
max_thinking_chars=35000,
checkpoint_every_n_batches=4
)
19 changes: 19 additions & 0 deletions src/core/distillation/run_cleaner_qwen_c.py
@@ -0,0 +1,19 @@
"""Clean qwen3 Branch C thinking traces."""

from pathlib import Path
from clean_thinking_traces import clean_thinking_traces

if __name__ == "__main__":
project_root = Path(__file__).parent.parent.parent.parent
data_dir = project_root / "data" / "out" / "distillation"

clean_thinking_traces(
input_file=data_dir / "mmlu_synth_qwen3_c_t0_8.parquet",
output_file=data_dir / "mmlu_synth_qwen3_c_t0_8_cleaned_32b.parquet",
model_name="Qwen/Qwen2.5-32B-Instruct",
batch_size=12,
temperature=0.1,
max_model_len=20480,
max_thinking_chars=35000,
checkpoint_every_n_batches=4
)
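
The four runner scripts differ only in the dataset tag; an equivalent parameterized runner (a sketch, assuming the same directory layout and naming convention) would collapse them into one:

"""Hypothetical single runner covering all four branches."""

import sys
from pathlib import Path
from clean_thinking_traces import clean_thinking_traces

if __name__ == "__main__":
    tag = sys.argv[1]  # e.g. "gptoss_b", "gptoss_c", "qwen3_b", "qwen3_c"
    project_root = Path(__file__).parent.parent.parent.parent
    data_dir = project_root / "data" / "out" / "distillation"

    clean_thinking_traces(
        input_file=data_dir / f"mmlu_synth_{tag}_t0_8.parquet",
        output_file=data_dir / f"mmlu_synth_{tag}_t0_8_cleaned_32b.parquet",
        model_name="Qwen/Qwen2.5-32B-Instruct",
        batch_size=12,
        temperature=0.1,
        max_model_len=20480,
        max_thinking_chars=35000,
        checkpoint_every_n_batches=4,
    )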