-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
330 lines (266 loc) · 11.5 KB
/
main.py
File metadata and controls
330 lines (266 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
#!/usr/bin/env python3
# Copyright 2025 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
"""
CompCon: Discovering Divergent Representations in Text-2-Image Models
This is the main orchestration script that runs the complete CompCon pipeline:
1. Discovers visual differences between two models
2. Generates prompt descriptions for the top differences
Usage:
python main.py --config configs/base.yaml --data_file data/results.csv --models model1 model2
For more details, see README.md
"""
import argparse
import logging
import os
from pathlib import Path
from typing import Dict, List
import pandas as pd
import weave
import wandb
# Import our pipeline components
from get_visual_differences import load_config, rank, load_data, propose, cluster_strings
from generate_prompt_descriptions import compcon
def setup_logging() -> None:
    """Configure root logging: INFO level with a timestamped message format."""
    log_format = '%(asctime)s - %(levelname)s - %(message)s'
    logging.basicConfig(level=logging.INFO, format=log_format)
def parse_arguments():
    """Build the CLI parser and return the parsed argument namespace."""
    p = argparse.ArgumentParser(
        description="CompCon: Complete pipeline for discovering divergent visual representations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Run full pipeline with default config
python main.py --data_file data/results.csv --models sd3.5-large playground
# Run with custom config and overrides
python main.py --config configs/base.yaml --data_file data/results.csv --models model1 model2 --overrides proposer.num_rounds=5
# Run only visual differences discovery
python main.py --data_file data/results.csv --models model1 model2 --differences_only
"""
    )
    # Dataset selection; either flag, when given, takes precedence over the config file.
    p.add_argument("--data_file", type=str,
                   help="Path to CSV file containing image data (overrides config)")
    p.add_argument("--models", nargs=2,
                   help="Two model names to compare (overrides config)")
    # Configuration file and ad-hoc key=value overrides.
    p.add_argument("--config", type=str, default="configs/base.yaml",
                   help="Path to configuration file (default: configs/base.yaml)")
    p.add_argument("--overrides", nargs="*", default=[],
                   help="Configuration overrides in format key=value")
    # Which pipeline stages to run.
    p.add_argument("--differences_only", action="store_true",
                   help="Only run visual differences discovery, skip prompt descriptions")
    p.add_argument("--top_k", type=int, default=5,
                   help="Number of top differences to analyze (default: 5)")
    # Where results and experiment logs go.
    p.add_argument("--output_dir", type=str, default="results",
                   help="Directory to save results (default: results)")
    p.add_argument("--project", type=str, default="CompCon",
                   help="WandB project name (default: CompCon)")
    p.add_argument("--name", type=str,
                   help="Run name for logging (default: auto-generated)")
    # Tunables forwarded to the CompCon algorithm itself.
    p.add_argument("--threshold", type=float, default=0.2,
                   help="Threshold for separable prompts (default: 0.2)")
    p.add_argument("--delta", type=float, default=0.05,
                   help="Delta for separable prompts (default: 0.05)")
    p.add_argument("--num_iterations", type=int, default=3,
                   help="Number of CompCon iterations (default: 3)")
    return p.parse_args()
def get_visual_differences(config_path: str, overrides: List[str], data_file: str = None, models: List[str] = None) -> List[str]:
    """Discover and rank visual differences between the two model groups.

    Args:
        config_path: Path to the YAML configuration file.
        overrides: key=value configuration overrides.
        data_file: Optional CSV path; when given, its stem/parent override
            the config's data.name/data.root.
        models: Optional pair of model names overriding data.group1/group2.

    Returns:
        Hypothesised visual differences, ordered by the ranking step.
    """
    logging.info("=" * 50)
    logging.info("STEP 1: Discovering Visual Differences")
    logging.info("=" * 50)

    # Translate the optional CLI-style arguments into config overrides.
    extra_overrides = []
    if data_file:
        src = Path(data_file)
        extra_overrides += [f"data.name={src.stem}", f"data.root={src.parent}"]
    if models:
        extra_overrides += [f"data.group1={models[0]}", f"data.group2={models[1]}"]

    cfg = load_config(config_path, overrides + extra_overrides)

    logging.info("Loading data...")
    dataset1, dataset2, group_names = load_data(cfg)
    logging.info(f"Loaded {len(dataset1)} samples for {group_names[0]}, {len(dataset2)} samples for {group_names[1]}")

    logging.info("Proposing visual difference hypotheses...")
    hypotheses, _ = propose(cfg, dataset1, dataset2)

    # Large hypothesis sets get compressed to 50 representatives before ranking.
    if len(hypotheses) >= 100:
        logging.info(f"Clustering {len(hypotheses)} hypotheses to 50...")
        hypotheses = cluster_strings(hypotheses, 50)

    # Strip whitespace and quote characters the proposer may have left in.
    cleaned = [h.strip().replace('"', '').replace("'", '') for h in hypotheses]

    logging.info("Ranking hypotheses by importance...")
    ranked = rank(cfg, cleaned, dataset1, dataset2, group_names[0])

    logging.info(f"Discovered {len(ranked)} visual differences")
    for idx, diff in enumerate(ranked[:5], start=1):
        logging.info(f" {idx}. {diff}")
    return ranked
def generate_prompt_descriptions_for_differences(
    differences: List[str],
    args: Dict,
    output_dir: str,
    threshold: float = 0.2,
    delta: float = 0.05,
    num_iterations: int = 3,
    top_k: int = 5
) -> Dict[str, Dict]:
    """Run the CompCon algorithm on each of the top-ranked visual differences.

    Args:
        differences: Ranked visual differences; only the first top_k are used.
        args: Config dict providing data.root/name/group1/group2.
        output_dir: Base directory where per-attribute images are saved.
        threshold: Separability threshold passed to compcon.
        delta: Separability delta passed to compcon.
        num_iterations: Number of CompCon refinement iterations.
        top_k: How many differences to analyze.

    Returns:
        Mapping from each analyzed difference to its CompCon results, or to
        {"error": ...} when analysis of that difference raised.
    """
    logging.info("=" * 50)
    logging.info("STEP 2: Generating Prompt Descriptions")
    logging.info("=" * 50)

    # Reconstruct the CSV location and the two group names from the config.
    csv_path = f"{args['data']['root']}/{args['data']['name']}.csv"
    models = [args['data']['group1'], args['data']['group2']]

    frame = pd.read_csv(csv_path)

    # One entry per prompt that has at least one image from BOTH models;
    # only the first image path per model is kept.
    prompt_data = []
    for prompt in frame["Prompt"].unique():
        rows = frame[frame["Prompt"] == prompt]
        paths_a = rows[rows["group_name"] == models[0]]["path"].tolist()
        paths_b = rows[rows["group_name"] == models[1]]["path"].tolist()
        if paths_a and paths_b:
            prompt_data.append({
                "prompt": prompt,
                "paths": [paths_a[0], paths_b[0]],
                "models": models
            })

    results = {}
    for idx, attribute in enumerate(differences[:top_k]):
        logging.info(f"Analyzing difference {idx+1}/{top_k}: '{attribute}'")
        try:
            # Attribute name doubles as a directory name; sanitize separators.
            save_dir = f"{output_dir}/{attribute.replace(' ', '_').replace('/', '_')}"
            P_a, P_a_new, description = compcon(
                prompt_data,
                models,
                attribute,
                save_dir,
                threshold=threshold,
                delta=delta,
                num_iterations=num_iterations
            )
            results[attribute] = {
                "separable_prompts": P_a,
                "new_prompts": P_a_new,
                "description": description,
                "num_separable": len(P_a),
                "save_dir": save_dir
            }
            logging.info(f" Found {len(P_a)} separable prompts")
            logging.info(f" Generated {len(P_a_new)} new prompts")
            logging.info(f" Description: {description}")
        except Exception as e:
            # A failure on one attribute is recorded but does not stop the rest.
            logging.error(f"Error analyzing '{attribute}': {e}")
            results[attribute] = {"error": str(e)}
    return results
def main():
    """Entry point: run discovery, then (optionally) prompt description generation."""
    setup_logging()
    cli = parse_arguments()

    # Config supplies defaults; explicit CLI flags win.
    cfg = load_config(cli.config, cli.overrides)
    data_file = cli.data_file or f"{cfg['data']['root']}/{cfg['data']['name']}.csv"
    models = cli.models or [cfg['data']['group1'], cfg['data']['group2']]

    logging.info("Starting CompCon pipeline...")
    logging.info(f"Data file: {data_file}")
    logging.info(f"Models: {models[0]} vs {models[1]}")
    logging.info(f"Config: {cli.config}")

    os.makedirs(cli.output_dir, exist_ok=True)

    # Experiment tracking; run name defaults to "<model1>_vs_<model2>".
    wandb.init(
        project=cli.project,
        name=cli.name or f"{models[0]}_vs_{models[1]}",
        config=vars(cli)
    )

    try:
        # Step 1: discover ranked visual differences. Pass the CLI values
        # through only when the user actually supplied them.
        differences = get_visual_differences(
            cli.config,
            cli.overrides,
            data_file if cli.data_file else None,
            models if cli.models else None
        )
        if not differences:
            logging.error("No visual differences found!")
            return

        wandb.summary["top_differences"] = differences[:cli.top_k]

        if cli.differences_only:
            logging.info("Differences-only mode: skipping prompt descriptions")
            logging.info(f"Top {cli.top_k} differences saved to WandB")
            return

        # Step 2: run CompCon on the top differences.
        outcomes = generate_prompt_descriptions_for_differences(
            differences,
            cfg,
            cli.output_dir,
            cli.threshold,
            cli.delta,
            cli.num_iterations,
            cli.top_k
        )

        # Record per-attribute summaries, skipping failed attributes.
        for attribute, outcome in outcomes.items():
            if "error" not in outcome:
                wandb.summary[f"{attribute}_num_separable"] = outcome["num_separable"]
                wandb.summary[f"{attribute}_description"] = outcome["description"]

        logging.info("=" * 50)
        logging.info("PIPELINE COMPLETE")
        logging.info("=" * 50)
        logging.info(f"Results saved to: {cli.output_dir}")
        logging.info(f"WandB logs: {wandb.run.url}")
    except Exception as e:
        logging.error(f"Pipeline failed: {e}")
        raise
    finally:
        # Always close the wandb run, even on failure.
        wandb.finish()
if __name__ == "__main__":
main()