diff --git a/data/out/distillation/mmlu_synth_gptoss_b_t0_8_cleaned_32b.parquet b/data/out/distillation/mmlu_synth_gptoss_b_t0_8_cleaned_32b.parquet new file mode 100644 index 0000000..9a1ca6f Binary files /dev/null and b/data/out/distillation/mmlu_synth_gptoss_b_t0_8_cleaned_32b.parquet differ diff --git a/data/out/distillation/mmlu_synth_gptoss_b_t0_8_llm_cleaned.parquet b/data/out/distillation/mmlu_synth_gptoss_b_t0_8_llm_cleaned.parquet new file mode 100644 index 0000000..8740aee Binary files /dev/null and b/data/out/distillation/mmlu_synth_gptoss_b_t0_8_llm_cleaned.parquet differ diff --git a/data/out/distillation/mmlu_synth_gptoss_c_t0_8.parquet b/data/out/distillation/mmlu_synth_gptoss_c_t0_8.parquet new file mode 100644 index 0000000..55f5e82 Binary files /dev/null and b/data/out/distillation/mmlu_synth_gptoss_c_t0_8.parquet differ diff --git a/data/out/distillation/mmlu_synth_qwen3_b_t0_8_cleaned_32b.parquet b/data/out/distillation/mmlu_synth_qwen3_b_t0_8_cleaned_32b.parquet new file mode 100644 index 0000000..06a3e15 Binary files /dev/null and b/data/out/distillation/mmlu_synth_qwen3_b_t0_8_cleaned_32b.parquet differ diff --git a/pyproject.toml b/pyproject.toml index 5c61300..b09f9bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,9 @@ dependencies = [ "seaborn>=0.13.2", "sentencepiece>=0.2.0", "torch>=2.6.0", + "tqdm>=4.67.1", "transformers==4.52.3", + "vllm>=0.7.4", ] [build-system] diff --git a/src/core/distillation/clean_thinking_traces.py b/src/core/distillation/clean_thinking_traces.py new file mode 100644 index 0000000..172cf7b --- /dev/null +++ b/src/core/distillation/clean_thinking_traces.py @@ -0,0 +1,121 @@ +"""LLM cleaner for thinking traces on H100 80GB.""" + +import os +os.environ["VLLM_USE_V1"] = "0" + +import pandas as pd +from pathlib import Path +from vllm import LLM, SamplingParams +from tqdm import tqdm +import torch +import gc +import time + +from prompts import SYSTEM_PROMPT, USER_PROMPT_TEMPLATE + + +def format_prompt(reasoning_trace: str) -> list: + """Format prompt for LLM cleaning.""" + return [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": USER_PROMPT_TEMPLATE.format(reasoning_trace=reasoning_trace)} + ] + + +def truncate_thinking(thinking: str, max_chars: int = 50000) -> str: + if len(thinking) <= max_chars: + return thinking + truncated = thinking[:max_chars] + last_period = truncated.rfind('.') + if last_period > max_chars - 200: + return truncated[:last_period + 1] + return truncated + + +def clean_thinking_traces( + input_file: str | Path, + output_file: str | Path, + model_name: str = "Qwen/Qwen2.5-32B-Instruct", + batch_size: int = 24, + temperature: float = 0.0, + gpu_memory_utilization: float = 0.95, + max_model_len: int = 20480, + max_thinking_chars: int = 35000, + checkpoint_every_n_batches: int = 8 +) -> Path: + gc.collect() + torch.cuda.empty_cache() + + input_path = Path(input_file).resolve() + output_path = Path(output_file).resolve() + + checkpoint_name = f"checkpoint_{int(time.time())}_{os.getpid()}.parquet" + checkpoint_path = output_path.parent / checkpoint_name + + print(f"Input: {input_path}") + print(f"Output: {output_path}") + + llm = LLM( + model=model_name, + tensor_parallel_size=1, + gpu_memory_utilization=gpu_memory_utilization, + trust_remote_code=True, + dtype="bfloat16", + max_model_len=max_model_len, + enforce_eager=False + ) + + sampling_params = SamplingParams( + temperature=temperature, + top_p=1.0, + max_tokens=int(max_model_len * 0.9), + repetition_penalty=1.05 + ) + + df 
= pd.read_parquet(input_path) + print(f"Records: {len(df)}") + + reasoning_traces = [] + indices_to_clean = [] + + for idx, row in df.iterrows(): + out = row['output'] + if isinstance(out, dict) and out.get('error') is None and 'thinking' in out: + thinking = out.get('thinking', '') + if thinking: + thinking = truncate_thinking(thinking, max_thinking_chars) + reasoning_traces.append(thinking) + indices_to_clean.append(idx) + + print(f"Traces to clean: {len(reasoning_traces)}") + + all_prompts = [format_prompt(trace) for trace in reasoning_traces] + df_cleaned = df.copy() + + pbar = tqdm(total=len(all_prompts), desc="Cleaning") + + for i in range(0, len(all_prompts), batch_size): + batch_prompts = all_prompts[i : i + batch_size] + batch_indices = indices_to_clean[i : i + batch_size] + + outputs = llm.chat(messages=batch_prompts, sampling_params=sampling_params, use_tqdm=False) + batch_texts = [output.outputs[0].text.strip() for output in outputs] + + for idx, cleaned_text in zip(batch_indices, batch_texts): + output_dict = df_cleaned.at[idx, 'output'].copy() + output_dict['thinking'] = cleaned_text + df_cleaned.at[idx, 'output'] = output_dict + + pbar.update(len(batch_texts)) + + if (i // batch_size + 1) % checkpoint_every_n_batches == 0: + df_cleaned.to_parquet(checkpoint_path, index=False) + + pbar.close() + df_cleaned.to_parquet(output_path, index=False) + + if checkpoint_path.exists(): + os.remove(checkpoint_path) + + print(f"Saved: {output_path}") + return output_path diff --git a/src/core/distillation/evaluate_cleaning_quality.py b/src/core/distillation/evaluate_cleaning_quality.py new file mode 100644 index 0000000..40e97b0 --- /dev/null +++ b/src/core/distillation/evaluate_cleaning_quality.py @@ -0,0 +1,125 @@ +import os +os.environ["VLLM_USE_V1"] = "0" + +import json +import pandas as pd +from pathlib import Path +from vllm import LLM, SamplingParams +from tqdm import tqdm + +from core.prompts.cleaning_judge import create_judge_prompt, FEWSHOT_EXAMPLE + + +def select_samples(branch_a_file: Path, branch_b_file: Path, n_fewshot: int = 3, n_test: int = 100): + df_a = pd.read_parquet(branch_a_file) + df_b = pd.read_parquet(branch_b_file) + + df_a['question_id'] = df_a['input'].apply(lambda x: str(x.get('question_id')) if isinstance(x, dict) and x.get('question_id') is not None else None) + df_a['answer'] = df_a['output'].apply(lambda x: x.get('answer') if isinstance(x, dict) else None) + df_a['error'] = df_a['output'].apply(lambda x: x.get('error') if isinstance(x, dict) else None) + df_a['gold'] = df_a['input'].apply(lambda x: x.get('gold') if isinstance(x, dict) else None) + df_a['thinking_a'] = df_a['output'].apply(lambda x: x.get('thinking', '') if isinstance(x, dict) else '') + + df_b['question_id'] = df_b['input'].apply(lambda x: str(x.get('question_id')) if isinstance(x, dict) and x.get('question_id') is not None else None) + df_b['thinking_b'] = df_b['output'].apply(lambda x: x.get('thinking', '') if isinstance(x, dict) else '') + + correct_a = df_a[(df_a['error'].isna() | (df_a['error'] == False)) & (df_a['answer'] == df_a['gold'])].copy() + + merged = correct_a.merge(df_b[['question_id', 'thinking_b']], on='question_id', how='inner') + + matched = [ + { + 'question_id': row['question_id'], + 'original': row['thinking_a'], + 'cleaned': row['thinking_b'] + } + for _, row in merged.iterrows() + ] + + return matched[:n_fewshot], matched[n_fewshot:n_fewshot+n_test] + + +def evaluate_batch(samples: list, model_name: str, output_file: Path, batch_size: int = 10): + llm = LLM( + 
model=model_name, + tensor_parallel_size=1, + gpu_memory_utilization=0.9, + trust_remote_code=True, + dtype="bfloat16", + max_model_len=32768, + disable_custom_all_reduce=True, + enforce_eager=True + ) + + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=512, + top_p=1.0 + ) + + all_results = [] + + for i in range(0, len(samples), batch_size): + batch = samples[i:i+batch_size] + batch_results = [] + + prompts = [] + for sample in batch: + msgs = create_judge_prompt(sample['original'], sample['cleaned'], [FEWSHOT_EXAMPLE]) + prompt = llm.get_tokenizer().apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) + prompts.append(prompt) + + outputs = llm.generate(prompts, sampling_params) + + for sample, output in zip(batch, outputs): + text = output.outputs[0].text.strip() + try: + scores = json.loads(text) + except: + scores = {"error": "parse_failed", "raw": text} + + result = { + 'question_id': sample['question_id'], + 'scores': scores, + 'original_length': len(sample['original']), + 'cleaned_length': len(sample['cleaned']) + } + batch_results.append(result) + all_results.append(result) + + if i == 0: + pd.DataFrame(batch_results).to_parquet(output_file, index=False) + else: + existing = pd.read_parquet(output_file) + combined = pd.concat([existing, pd.DataFrame(batch_results)], ignore_index=True) + combined.to_parquet(output_file, index=False) + + print(f"Batch {i//batch_size + 1}/{(len(samples)-1)//batch_size + 1} completed, saved {len(batch_results)} results") + + print(f"All results saved to {output_file}") + return all_results + + +def compute_stats(results_file: Path): + df = pd.read_parquet(results_file) + + scores_list = [] + for scores in df['scores']: + if isinstance(scores, dict) and 'logical_preservation' in scores: + scores_list.append(scores) + + if not scores_list: + print("No valid scores found") + return + + criteria = ['logical_preservation', 'noise_removal'] + print("\nScores (mean ± std):") + for c in criteria: + vals = [s[c] for s in scores_list if c in s and s[c] is not None] + if vals: + print(f"{c}: {sum(vals)/len(vals):.2f} ± {(sum((x-sum(vals)/len(vals))**2 for x in vals)/len(vals))**0.5:.2f}") + + avg_original = df['original_length'].mean() + avg_cleaned = df['cleaned_length'].mean() + compression = (1 - avg_cleaned/avg_original) * 100 + print(f"\nCompression: {compression:.1f}% ({avg_original:.0f} → {avg_cleaned:.0f} chars)") diff --git a/src/core/distillation/prompts.py b/src/core/distillation/prompts.py new file mode 100644 index 0000000..32422e2 --- /dev/null +++ b/src/core/distillation/prompts.py @@ -0,0 +1,42 @@ +"""Prompts for cleaning thinking traces with LLM.""" + +SYSTEM_PROMPT = """You are a text editor. Your ONLY task is to remove phrases that reveal the final answer or outcome prematurely (spoilers). + +GUIDELINES: +1. REMOVE only parts where the text explicitly states the known result before deriving it (e.g., "Since we know the answer is B...", "The target value is 5, so..."). +2. KEEP all conversational fillers ("Let's think", "I wonder", "Note that"), planning, and formatting. +3. PRESERVE the original wording exactly. Do not summarize or rewrite. + +### EXAMPLE 1 +[INPUT] +We need to verify if the function is continuous. +Let f(x) = x^2. We will check the limit at x=0. +Limit calculation: lim(x->0) x^2 = 0. +Now provide a concise explanation. The limit exists and equals the function value. +Thus, it is continuous. + +[OUTPUT] +Let f(x) = x^2. Check the limit at x=0. +Limit calculation: lim(x->0) x^2 = 0. 
+The limit exists and equals the function value. +Thus, it is continuous. + +### EXAMPLE 2 +[INPUT] +Option A is wrong because 5 is not prime. +We'll explain why B is correct. +For option B: 7 is prime. +Now wrap up the answer. +Therefore, B is the answer. + +[OUTPUT] +Option A is wrong because 5 is not prime. +For option B: 7 is prime. +Therefore, B is the answer. +""" + +USER_PROMPT_TEMPLATE = """[INPUT] +{reasoning_trace} + +[OUTPUT] +""" diff --git a/src/core/distillation/run_cleaner_gptoss_b.py b/src/core/distillation/run_cleaner_gptoss_b.py new file mode 100644 index 0000000..b166c4d --- /dev/null +++ b/src/core/distillation/run_cleaner_gptoss_b.py @@ -0,0 +1,19 @@ +"""Clean gptoss Branch B thinking traces.""" + +from pathlib import Path +from clean_thinking_traces import clean_thinking_traces + +if __name__ == "__main__": + project_root = Path(__file__).parent.parent.parent.parent + data_dir = project_root / "data" / "out" / "distillation" + + clean_thinking_traces( + input_file=data_dir / "mmlu_synth_gptoss_b_t0_8.parquet", + output_file=data_dir / "mmlu_synth_gptoss_b_t0_8_cleaned_32b.parquet", + model_name="Qwen/Qwen2.5-32B-Instruct", + batch_size=12, + temperature=0.1, + max_model_len=20480, + max_thinking_chars=35000, + checkpoint_every_n_batches=4 + ) diff --git a/src/core/distillation/run_cleaner_gptoss_c.py b/src/core/distillation/run_cleaner_gptoss_c.py new file mode 100644 index 0000000..9477643 --- /dev/null +++ b/src/core/distillation/run_cleaner_gptoss_c.py @@ -0,0 +1,19 @@ +"""Clean gptoss Branch C thinking traces.""" + +from pathlib import Path +from clean_thinking_traces import clean_thinking_traces + +if __name__ == "__main__": + project_root = Path(__file__).parent.parent.parent.parent + data_dir = project_root / "data" / "out" / "distillation" + + clean_thinking_traces( + input_file=data_dir / "mmlu_synth_gptoss_c_t0_8.parquet", + output_file=data_dir / "mmlu_synth_gptoss_c_t0_8_cleaned_32b.parquet", + model_name="Qwen/Qwen2.5-32B-Instruct", + batch_size=12, + temperature=0.1, + max_model_len=20480, + max_thinking_chars=35000, + checkpoint_every_n_batches=4 + ) diff --git a/src/core/distillation/run_cleaner_qwen_b.py b/src/core/distillation/run_cleaner_qwen_b.py new file mode 100644 index 0000000..0fc8817 --- /dev/null +++ b/src/core/distillation/run_cleaner_qwen_b.py @@ -0,0 +1,19 @@ +"""Clean qwen3 Branch B thinking traces.""" + +from pathlib import Path +from clean_thinking_traces import clean_thinking_traces + +if __name__ == "__main__": + project_root = Path(__file__).parent.parent.parent.parent + data_dir = project_root / "data" / "out" / "distillation" + + clean_thinking_traces( + input_file=data_dir / "mmlu_synth_qwen3_b_t0_8.parquet", + output_file=data_dir / "mmlu_synth_qwen3_b_t0_8_cleaned_32b.parquet", + model_name="Qwen/Qwen2.5-32B-Instruct", + batch_size=12, + temperature=0.1, + max_model_len=20480, + max_thinking_chars=35000, + checkpoint_every_n_batches=4 + ) diff --git a/src/core/distillation/run_cleaner_qwen_c.py b/src/core/distillation/run_cleaner_qwen_c.py new file mode 100644 index 0000000..6370192 --- /dev/null +++ b/src/core/distillation/run_cleaner_qwen_c.py @@ -0,0 +1,19 @@ +"""Clean qwen3 Branch C thinking traces.""" + +from pathlib import Path +from clean_thinking_traces import clean_thinking_traces + +if __name__ == "__main__": + project_root = Path(__file__).parent.parent.parent.parent + data_dir = project_root / "data" / "out" / "distillation" + + clean_thinking_traces( + input_file=data_dir / "mmlu_synth_qwen3_c_t0_8.parquet", + 
output_file=data_dir / "mmlu_synth_qwen3_c_t0_8_cleaned_32b.parquet",
+        model_name="Qwen/Qwen2.5-32B-Instruct",
+        batch_size=12,
+        temperature=0.1,
+        max_model_len=20480,
+        max_thinking_chars=35000,
+        checkpoint_every_n_batches=4
+    )
diff --git a/src/core/prompts/cleaning_judge.py b/src/core/prompts/cleaning_judge.py
new file mode 100644
index 0000000..13ae944
--- /dev/null
+++ b/src/core/prompts/cleaning_judge.py
@@ -0,0 +1,88 @@
+JUDGE_SYSTEM_PROMPT = """### Project Context
+We are creating a high-quality synthetic dataset for training small models.
+
+1. "Original Reasoning" (branch A) is our "gold standard." It represents the raw, correct chain of reasoning that led to the correct answer.
+
+2. "Metadata-stripped Reasoning with Correct Answer Explanation" (branch B) is a version that explains the correct answer instead of trying to "guess" the correct answer by generating a reasoning trace.
+
+### Problem
+
+When we generated branch B, we found that the reasoning chains include statements revealing that the chain was written with the answer already given. This would ruin the training objective, so we asked an LLM to clean the generated chains for branch B.
+
+### Your goal
+
+Rate the 'Cleaned reasoning' on 2 criteria (1-10 scale):
+
+1. logical_preservation:
+   - 10: All logical steps, calculations, and crucial evidence are preserved in the exact original sequence. The deductive flow (Step A -> Step B -> Conclusion) remains unchanged.
+   - 1: Critical steps are missing, or the reasoning order is distorted (e.g., stating the final answer before the logic/derivations, whereas the original built up to it).
+
+2. noise_removal:
+   - 10: Perfectly stripped of conversational fillers ("Let's think", "We need to", "Step 1"), meta-language, and self-corrections. Contains ONLY declarative facts and deductions.
+   - 1: Still contains significant conversational filler, headers like "Why this is wrong", or meta-commentary.
+
+
+The output for branch A is provided as the golden reference training signal.
+
+### Examples
+
+Input:
+Original: "We need to select correct option letter. The second law of thermodynamics says that heat does not spontaneously flow from a colder body to a hotter body without external energy. Option j: 'cold to hot without external energy'. That matches. So answer is j."
+Cleaned: "The second law states heat flows spontaneously from hot to cold; it cannot flow from cold to hot without external work. Option J says 'cold to hot without external energy' – this is contrary to the law. Therefore, J is correct."
+Output:
+{"logical_preservation": 10, "noise_removal": 10}
+
+Input:
+Original: "We need to answer multiple choice. The question: In personnel selection, the selection ratio is defined as the ... Options: a) ratio of successful..., b) ratio of number... Let's think. Usually selection ratio is applicants hired vs available."
+Cleaned: "D is correct. A, B, C are wrong."
+Output:
+{"logical_preservation": 1, "noise_removal": 10}
+
+Input:
+Original: "We need to compute equivalent capacitance of series: 1/C_eq = sum(1/C_i). Here n=5, C=5. So 1/C_eq = 5/5 = 1. So answer is h."
+Cleaned: "Why other options are wrong and why H (1 µF) is correct: series formula C_eq = 1/sum. For five identical 5 µF in series: C_eq = 1 µF. Rule out other options."
+Output:
+{"logical_preservation": 3, "noise_removal": 6}
+
+### Task
+
+Return ONLY a JSON object.
+Format: {"logical_preservation": X, "noise_removal": X} +""" + + +JUDGE_USER_TEMPLATE = """Original reasoning: +{original} + +Cleaned reasoning: +{cleaned} + +Evaluate the cleaning quality.""" + + +FEWSHOT_EXAMPLE = { + "original": """Let me analyze this question step by step. We need to find the area of a circle with radius 5. +Note that the formula for circle area is A = πr². Let's calculate: r = 5, so r² = 25. +Therefore A = π × 25 = 25π ≈ 78.54. Now let me check the options. Option A says 78.5, which matches our calculation.""", + "cleaned": """We need to find the area of a circle with radius 5. +The formula for circle area is A = πr². r = 5, so r² = 25. +Therefore A = π × 25 = 25π ≈ 78.54. Option A says 78.5, which matches our calculation.""", + "rating": '{"logical_preservation": 5, "noise_removal": 5}' +} + + +def create_judge_prompt(original: str, cleaned: str, examples: list = None) -> list: + messages = [{"role": "system", "content": JUDGE_SYSTEM_PROMPT}] + + if examples: + for ex in examples: + messages.append({"role": "user", "content": JUDGE_USER_TEMPLATE.format( + original=ex["original"], cleaned=ex["cleaned"] + )}) + messages.append({"role": "assistant", "content": ex["rating"]}) + + messages.append({"role": "user", "content": JUDGE_USER_TEMPLATE.format( + original=original, cleaned=cleaned + )}) + + return messages diff --git a/src/experiments/distill/mmlu_synth_gptoss_c_t0_8.py b/src/experiments/distill/mmlu_synth_gptoss_c_t0_8.py new file mode 100644 index 0000000..a3227f6 --- /dev/null +++ b/src/experiments/distill/mmlu_synth_gptoss_c_t0_8.py @@ -0,0 +1,23 @@ +from pathlib import Path +from multiprocessing import freeze_support + +from core.distillation.synth_aug_mmlu import synth_on_dataset + +def main(): + freeze_support() + + synth_on_dataset( + in_filename=Path(__file__).parent.joinpath("../../../data/source/mmlu_pro_stem_shuffled.tsv").resolve(), + out_filename=Path(__file__).parent.joinpath("../../../data/out/distillation/mmlu_synth_gptoss_c_t0_8.parquet").resolve(), + model="openai/gpt-oss-120b", + max_tokens=16384, + dump_every=1000, + limit=None, + branch="C", + chunk_size=30, + a_file_path=Path(__file__).parent.joinpath("../../../data/out/distillation/mmlu_synth_gptoss_a_t0_8.parquet").resolve(), + temperature=0.8, + ) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/experiments/distill/run_cleaning_evaluation.py b/src/experiments/distill/run_cleaning_evaluation.py new file mode 100644 index 0000000..0fc95cd --- /dev/null +++ b/src/experiments/distill/run_cleaning_evaluation.py @@ -0,0 +1,40 @@ +from pathlib import Path +import pandas as pd +from core.distillation.evaluate_cleaning_quality import select_samples, evaluate_batch, compute_stats + + +if __name__ == "__main__": + project_root = Path(__file__).parent.parent.parent.parent + data_dir = project_root / "data" / "out" / "distillation" + eval_dir = project_root / "data" / "out" / "evaluation" + eval_dir.mkdir(parents=True, exist_ok=True) + + branch_a = data_dir / "mmlu_synth_gptoss_a_t0_8.parquet" + branch_b = data_dir / "mmlu_synth_gptoss_b_t0_8_cleaned_32b.parquet" + + print("Selecting samples...") + _, test = select_samples(branch_a, branch_b, n_test=100) + print(f"Test samples: {len(test)}") + + output = eval_dir / "cleaning_quality_results.parquet" + + import os + os.environ["CUDA_VISIBLE_DEVICES"] = "5" + os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" + + if output.exists(): + print(f"Found existing results with {len(pd.read_parquet(output))} samples") + 
print("Computing statistics on existing results...") + compute_stats(output) + print("To resume evaluation, remove the existing file or modify the script") + else: + print("\nRunning evaluation with kosbu/Llama-3.3-70B-Instruct-AWQ on H100...") + evaluate_batch( + samples=test, + model_name="kosbu/Llama-3.3-70B-Instruct-AWQ", + output_file=output, + batch_size=10 + ) + + print("\nComputing statistics...") + compute_stats(output)