-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathteam_code.py
More file actions
542 lines (425 loc) · 22.6 KB
/
team_code.py
File metadata and controls
542 lines (425 loc) · 22.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
#!/usr/bin/env python
# Edit this script to add your team's code. Some functions are *required*, but you can edit most parts of the required functions,
# change or remove non-required functions, and add your own functions.
################################################################################
#
# Optional libraries, functions, and variables. You can change or remove them.
#
################################################################################
import joblib
import numpy as np
import os
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
import sys
from tqdm import tqdm
from helper_code import *
################################################################################
# Path & Constant Configuration (Added for Robustness)
################################################################################
# Get the absolute directory where this script is located
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Build the absolute path to the CSV file relative to the script location
DEFAULT_CSV_PATH = os.path.join(SCRIPT_DIR, 'channel_table.csv')
################################################################################
#
# Required functions. Edit these functions to add your code, but do not change the arguments for the functions.
#
################################################################################
# Train your models. This function is *required*. You should edit this function to add your code, but do *not* change the arguments
# of this function. If you do not train one of the models, then you can return None for the model.
# Train your model.
def train_model(data_folder, model_folder, verbose, csv_path=DEFAULT_CSV_PATH):
# Find the data files.
if verbose:
print('Finding the Challenge data...')
patient_data_file = os.path.join(data_folder, DEMOGRAPHICS_FILE)
patient_metadata_list = find_patients(patient_data_file)
num_records = len(patient_metadata_list)
if num_records == 0:
raise FileNotFoundError('No data were provided.')
# Extract the features and labels from the data.
if verbose:
print('Extracting features and labels from the data...')
# Iterate over the records to extract the features and labels.
features = list()
labels = list()
pbar = tqdm(range(num_records), desc="Extracting Features", unit="record", disable=not verbose)
for i in pbar:
try:
# Extract identifiers for this specific record
record = patient_metadata_list[i]
patient_id = record[HEADERS['bids_folder']]
site_id = record[HEADERS['site_id']]
session_id = record[HEADERS['session_id']]
if verbose:
pbar.set_postfix({"patient": patient_id})
# Load the patient data.
patient_data_file = os.path.join(data_folder, DEMOGRAPHICS_FILE)
patient_data = load_demographics(patient_data_file, patient_id, session_id)
demographic_features = extract_demographic_features(patient_data)
# Load signal data.
# Load the physiological signal.
physiological_data_file = os.path.join(data_folder, PHYSIOLOGICAL_DATA_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}.edf")
# --- Check if the file actually exists before proceeding ---
if not os.path.exists(physiological_data_file):
if verbose:
print(f" ! Missing physiological data for {patient_id}. Skipping...")
continue # skip record
physiological_data, physiological_fs = load_signal_data(physiological_data_file)
physiological_features = extract_physiological_features(physiological_data, physiological_fs, csv_path=csv_path) # This function can rename, re-reference, resample, etc. the signal data.
# Load the algorithmic annotations.
algorithmic_annotations_file = os.path.join(data_folder, ALGORITHMIC_ANNOTATIONS_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}_caisr_annotations.edf")
algorithmic_annotations, algorithmic_fs = load_signal_data(algorithmic_annotations_file)
algorithmic_features = extract_algorithmic_annotations_features(algorithmic_annotations)
# Load the human annotations; these data will not be available in the hidden validation and test sets.
human_annotations_file = os.path.join(data_folder, HUMAN_ANNOTATIONS_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}_expert_annotations.edf")
human_annotations, human_fs = load_signal_data(human_annotations_file)
human_features = extract_human_annotations_features(human_annotations)
# Load the diagnoses; these data will not be available in the hidden validation and test sets.
diagnosis_file = os.path.join(data_folder, DEMOGRAPHICS_FILE)
label = load_diagnoses(diagnosis_file, patient_id)
# Store the features and labels, but
# the human annotations are not available on the hidden validation and test sets, but you
# may want to consider how to use them for training.
if label == 0 or label == 1:
features.append(np.hstack([demographic_features, physiological_features, algorithmic_features]))
labels.append(label)
if 'physiological_data' in locals(): del physiological_data
if 'algorithmic_annotations' in locals(): del algorithmic_annotations
except Exception as e:
# If an error occurs (e.g., a record is corrupted), log it and move to the next
tqdm.write(f" !!! Error processing record {i+1} ({patient_id}): {e}")
continue
pbar.close()
features = np.asarray(features, dtype=np.float32)
labels = np.asarray(labels, dtype=bool)
# Train the models on the features.
if verbose:
print('Training the model on the data...')
# This very simple model trains a random forest model with very simple features.
# Define the parameters for the random forest classifier and regressor.
n_estimators = 12 # Number of trees in the forest.
max_leaf_nodes = 34 # Maximum number of leaf nodes in each tree.
random_state = 56 # Random state; set for reproducibility.
# Fit the model.
model = RandomForestClassifier(
n_estimators=n_estimators, max_leaf_nodes=max_leaf_nodes, random_state=random_state).fit(features, labels)
# Create a folder for the model if it does not already exist.
os.makedirs(model_folder, exist_ok=True)
# Save the model.
save_model(model_folder, model)
if verbose:
print('Done.')
print()
# Load your trained models. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function. If you do not train one of the models, then you can return None for the model.
def load_model(model_folder, verbose):
model_filename = os.path.join(model_folder, 'model.sav')
model = joblib.load(model_filename)
return model
# Run your trained model. This function is *required*. You should edit this function to add your code, but do *not* change the
# arguments of this function.
def run_model(model, record, data_folder, verbose):
# Load the model.
model = model['model']
# Extract identifiers from the record dictionary
patient_id = record[HEADERS['bids_folder']]
site_id = record[HEADERS['site_id']]
session_id = record[HEADERS['session_id']]
# Load the patient data.
patient_data_file = os.path.join(data_folder, DEMOGRAPHICS_FILE)
patient_data = load_demographics(patient_data_file, patient_id, session_id)
demographic_features = extract_demographic_features(patient_data)
# Load signal data.
phys_file = os.path.join(data_folder, PHYSIOLOGICAL_DATA_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}.edf")
if os.path.exists(phys_file):
phys_data, phys_fs = load_signal_data(phys_file)
# Ensure csv_path is accessible or defined
physiological_features = extract_physiological_features(phys_data, phys_fs)
else:
# Fallback to zeros if file is missing (length 49)
physiological_features = np.zeros(49)
# Load Algorithmic Annotations
algo_file = os.path.join(data_folder, ALGORITHMIC_ANNOTATIONS_SUBFOLDER, site_id, f"{patient_id}_ses-{session_id}_caisr_annotations.edf")
if os.path.exists(algo_file):
algo_data, _ = load_signal_data(algo_file)
algorithmic_features = extract_algorithmic_annotations_features(algo_data)
else:
# Fallback to zeros (length 12)
algorithmic_features = np.zeros(12)
features = np.hstack([demographic_features, physiological_features, algorithmic_features]).reshape(1, -1)
# Get the model outputs.
binary_output = model.predict(features)[0]
probability_output = model.predict_proba(features)[0][1]
return binary_output, probability_output
################################################################################
#
# Optional functions. You can change or remove these functions and/or add new functions.
#
################################################################################
def extract_demographic_features(data):
"""
Extracts and encodes demographic features from a metadata dictionary.
Inputs:
data (dict): A dictionary containing patient metadata (e.g., from a CSV row).
Returns:
np.array: A feature vector of length 11:
- [0]: Age (Continuous)
- [1:4]: Sex (One-hot: Female, Male, Other/Unknown)
- [4:9]: Race (One-hot: Asian, Black, Other, Unavailable, White)
- [9]: BMI (Continuous)
"""
# 1. Age Feature (1 dimension)
# Convert 'Age' to a float; default to 0 if missing
age = np.array([load_age(data)])
# 2. Sex One-Hot Encoding (3 dimensions: Female, Male, Other/Unknown)
# Uses lowercase prefix matching to handle variants like 'F', 'Female', 'M', or 'Male'
sex = load_sex(data)
sex_vec = np.zeros(3)
if sex == 'Female':
sex_vec[0] = 1 # Index 0: Female
elif sex == 'Male':
sex_vec[1] = 1 # Index 1: Male
else:
sex_vec[2] = 1 # Index 2: Other/Unknown
# 3. Race One-Hot Encoding (6 dimensions)
# Standardizes the raw text into one of six categories using the helper function
race_category = get_standardized_race(data).lower()
race_vec = np.zeros(5)
# Pre-defined mapping for index consistency
race_mapping = {'asian': 0, 'black': 1, 'others': 2, 'unavailable': 3, 'white': 4}
race_vec[race_mapping.get(race_category, 2)] = 1
# 4. Body Mass Index (BMI) Feature (1 dimension)
# Extracts the pre-calculated mean BMI; handles strings, NaNs, and missing keys
bmi = np.array([load_bmi(data)])
# 5. Concatenate all components into a single vector (1 + 3 + 5 + 1 = 10)
return np.concatenate([age, sex_vec, race_vec, bmi])
def extract_physiological_features(physiological_data, physiological_fs, csv_path=DEFAULT_CSV_PATH):
"""
Standardizes channels and extracts statistical/spectral features.
"""
original_labels = list(physiological_data.keys())
# Step 1: Load rules and standardize names
# Note: Use script-relative path or absolute path for robustness
rename_rules = load_rename_rules(os.path.abspath(csv_path))
rename_map, cols_to_drop = standardize_channel_names_rename_only(original_labels, rename_rules)
# Step 2: Apply renaming to BOTH signals and their corresponding FS
processed_channels = {}
processed_fs = {}
for old_label, data in physiological_data.items():
if old_label in cols_to_drop:
continue
new_label = rename_map.get(old_label, old_label.lower())
processed_channels[new_label] = data
# Mapping the sampling rate to the new label
processed_fs[new_label] = physiological_fs.get(old_label, 200.0) # Default to 200 if missing
if 'physiological_data' in locals(): del physiological_data
# Step 3: Construct Bipolar Derivations
bipolar_configs = [
('f3-m2', 'f3', ['m2']), ('f4-m1', 'f4', ['m1']),
('c3-m2', 'c3', ['m2']), ('c4-m1', 'c4', ['m1']),
('o1-m2', 'o1', ['m2']), ('o2-m1', 'o2', ['m1']),
('e1-m2', 'e1', ['m2']), ('e2-m1', 'e2', ['m1']),
('chin1-chin2', 'chin1', ['chin2']),
('lat', 'lleg+', ['lleg-']), ('rat', 'rleg+', ['rleg-'])
]
for target, pos, neg_list in bipolar_configs:
# 1. Skip if target already exists or pos channel missing
if target in processed_channels or pos not in processed_channels:
continue
# 2. Check all neg channels exist
if not all(n in processed_channels for n in neg_list):
continue
# 3. Check sampling rate consistency
all_involved = [pos] + neg_list
fs_values = [processed_fs[ch] for ch in all_involved]
if len(set(fs_values)) > 1:
raise ValueError(f"Sampling rate mismatch for {target}: {dict(zip(all_involved, fs_values))}")
# 4. Derive bipolar signal
ref_sig = processed_channels[neg_list[0]] if len(neg_list) == 1 else tuple(processed_channels[n] for n in neg_list)
derived = derive_bipolar_signal(processed_channels[pos], ref_sig)
if derived is not None:
processed_channels[target] = derived
processed_fs[target] = processed_fs[pos]
leads_to_check = {
'eeg': ['f3-m2', 'f4-m1', 'c3-m2', 'c4-m1'],
'eog': ['e1-m2', 'e2-m1'],
'chin': ['chin1-chin2', 'chin'],
'leg': ['lat', 'rat'],
'ecg': ['ecg', 'ekg'],
'resp': ['airflow', 'ptaf', 'abd', 'chest'],
'spo2': ['spo2', 'sao2'] # Added sao2 as fallback for spo2
}
final_features = []
for lead_type, candidates in leads_to_check.items():
sig = None
fs = None
# Identify the first available candidate
for candidate in candidates:
if candidate in processed_channels and processed_channels[candidate] is not None:
sig = processed_channels[candidate]
fs = processed_fs.get(candidate)
break
# if sig is not None and len(sig) > 0 and fs is not None:
# # --- 1. Time Domain Features ---
# std_val = np.std(sig)
# mav_val = np.mean(np.abs(sig))
# energy_val = np.sum(sig**2) / len(sig)
# # --- 2. Frequency Domain Features (Spectral) ---
# n = len(sig)
# # Correct spacing for frequency axis based on channel-specific fs
# freqs = np.fft.rfftfreq(n, d=1/fs)
# # Compute Power Spectral Density (PSD)
# # Multiplied by 2 for rfft (except DC/Nyquist) and divided by fs for density
# fft_res = np.abs(np.fft.rfft(sig))
# psd = (fft_res**2) / (n * fs)
# # Define band masks
# delta_mask = (freqs >= 0.5) & (freqs <= 4)
# theta_mask = (freqs > 4) & (freqs <= 8)
# alpha_mask = (freqs > 8) & (freqs <= 12)
# # Calculate power in bands using trapezoidal integration for physical accuracy
# delta_p = np.trapezoid(psd[delta_mask], freqs[delta_mask]) if np.any(delta_mask) else 0.0
# theta_p = np.trapezoid(psd[theta_mask], freqs[theta_mask]) if np.any(theta_mask) else 0.0
# alpha_p = np.trapezoid(psd[alpha_mask], freqs[alpha_mask]) if np.any(alpha_mask) else 0.0
# # Ratio biomarker: Delta/Theta (Indicator of cognitive slowing)
# dt_ratio = delta_p / theta_p if theta_p > 0 else 0.0
# final_features.extend([std_val, mav_val, energy_val, delta_p, theta_p, alpha_p, dt_ratio])
if sig is not None and len(sig) > 1:
# --- Time Domain Features (Very Fast) ---
std_val = np.std(sig)
mav_val = np.mean(np.abs(sig))
# Zero Crossing Rate (Proxy for frequency/slowing)
zcr = np.mean(np.diff(np.sign(sig)) != 0)
# Root Mean Square
rms = np.sqrt(np.mean(sig**2))
# Signal Activity (Variance)
activity = np.var(sig)
# Mobility (Hjorth Parameter) - Proxy for mean frequency
# sqrt(var(diff(sig)) / var(sig))
diff_sig = np.diff(sig)
mobility = np.sqrt(np.var(diff_sig) / activity) if activity > 0 else 0.0
# Complexity (Hjorth Parameter) - Proxy for bandwidth
diff2_sig = np.diff(diff_sig)
var_d2 = np.var(diff2_sig)
var_d1 = np.var(diff_sig)
complexity = (np.sqrt(var_d2 / var_d1) / mobility) if (var_d1 > 0 and mobility > 0) else 0.0
final_features.extend([std_val, mav_val, zcr, rms, activity, mobility, complexity])
else:
# Padding: 7 features per lead type
final_features.extend([0.0] * 7)
if 'processed_channels' in locals(): del processed_channels
return np.array(final_features)
def extract_algorithmic_annotations_features(algo_data):
"""
Extracts sleep architecture and event density features from CAISR outputs.
Output vector length: 12
"""
if not algo_data:
return np.zeros(12)
features = []
# --- 1. Respiratory & Arousal Event Densities ---
# Total duration in hours (assuming 1Hz for event traces)
# If the signal exists, we calculate events per hour (Index)
total_hours = len(algo_data.get('resp_caisr', [])) / 3600.0
def count_discrete_events(key):
if key not in algo_data or total_hours <= 0:
return 0.0
sig = algo_data[key].astype(float)
# Create a binary mask: 1 if there is an event, 0 if not
binary_sig = (sig > 0).astype(int)
# Detect rising edges: 0 to 1 transition
# diff will be 1 at the start of an event, -1 at the end
diff = np.diff(binary_sig, prepend=0)
num_events = np.count_nonzero(diff == 1)
return num_events / total_hours
ahi_auto = count_discrete_events('resp_caisr') # Automated Apnea-Hypopnea Index
arousal_auto = count_discrete_events('arousal_caisr') # Automated Arousal Index
limb_auto = count_discrete_events('limb_caisr') # Automated Limb Movement Index
features.extend([ahi_auto, arousal_auto, limb_auto])
# --- 2. Sleep Architecture (from stage_caisr) ---
# Standard labels: 0=W, 1=N1, 2=N2, 3=N3, 4=R (or similar mapping)
stages = algo_data.get('stage_caisr', np.array([]))
# Filter out invalid/background values (like the 9.0 in your sample)
valid_stages = stages[stages < 9.0]
if len(valid_stages) > 0:
total_epochs = len(valid_stages)
# Percentage of each stage
w_pct = np.mean(valid_stages == 5)
r_pct = np.mean(valid_stages == 4)
n1_pct = np.mean(valid_stages == 3)
n2_pct = np.mean(valid_stages == 2)
n3_pct = np.mean(valid_stages == 1)
# Sleep Efficiency: (N1+N2+N3+R) / Total
efficiency = np.mean(valid_stages > 0)
else:
w_pct = n1_pct = n2_pct = n3_pct = r_pct = efficiency = 0.0
features.extend([w_pct, n1_pct, n2_pct, n3_pct, r_pct, efficiency])
# --- 3. Model Confidence / Uncertainty ---
# Mean probability of Wake and REM (indicators of sleep stability)
# We use the raw probability traces
prob_w = np.mean(algo_data.get('caisr_prob_w', [0]))
prob_n3 = np.mean(algo_data.get('caisr_prob_n3', [0]))
prob_arous = np.mean(algo_data.get('caisr_prob_arous', [0]))
# Standardize '9.0' or other filler values to 0
clean_prob = lambda x: x if x < 1.0 else 0.0
features.extend([clean_prob(prob_w), clean_prob(prob_n3), clean_prob(prob_arous)])
return np.array(features)
def extract_human_annotations_features(human_data):
"""
Extracts features from expert-scored human annotations.
Output vector length: 12 (to match algorithmic feature length)
"""
# If data is missing (common in hidden test sets), return a zero vector
if not human_data or 'resp_expert' not in human_data:
return np.zeros(12)
features = []
# --- 1. Human Event Indices (Events per Hour) ---
# Total duration in hours based on 1Hz signal
total_seconds = len(human_data.get('resp_expert', []))
total_hours = total_seconds / 3600.0
def count_discrete_events(key):
if key not in human_data or total_hours <= 0:
return 0.0
sig = (human_data[key] > 0).astype(int)
# Identify the start of each continuous event block
diff = np.diff(sig, prepend=0)
return np.count_nonzero(diff == 1) / total_hours
ahi_human = count_discrete_events('resp_expert') # Human AHI
arousal_human = count_discrete_events('arousal_expert') # Human Arousal Index
limb_human = count_discrete_events('limb_expert') # Human PLMI
features.extend([ahi_human, arousal_human, limb_human])
# --- 2. Human Sleep Architecture ---
# Standard labels: 0=W, 1=N1, 2=N2, 3=N3, 4=R, 5=Unknown/Movement
stages = human_data.get('stage_expert', np.array([]))
# Filter out label 5 (often used by experts for movement/unscored)
valid_mask = (stages < 9.0)
valid_stages = stages[valid_mask]
if len(valid_stages) > 0:
w_pct = np.mean(valid_stages == 5)
r_pct = np.mean(valid_stages == 4)
n1_pct = np.mean(valid_stages == 3)
n2_pct = np.mean(valid_stages == 2)
n3_pct = np.mean(valid_stages == 1)
efficiency = np.mean(valid_stages > 0)
else:
w_pct = n1_pct = n2_pct = n3_pct = r_pct = efficiency = 0.0
features.extend([w_pct, n1_pct, n2_pct, n3_pct, r_pct, efficiency])
# --- 3. Fragmentation & Stability (Replacing Probabilities) ---
# These metrics quantify how "broken" the sleep is, which is a key marker.
if len(valid_stages) > 1:
# Number of stage transitions
transitions = np.count_nonzero(np.diff(valid_stages)) / total_hours
# Wake After Sleep Onset (WASO) proxy: non-zero stages followed by zero
waso_minutes = (np.count_nonzero(valid_stages == 0) * 30) / 60.0
# REM Latency (epochs until first REM)
rem_indices = np.where(valid_stages == 4)[0]
rem_latency = rem_indices[0] if len(rem_indices) > 0 else 0.0
else:
transitions = waso_minutes = rem_latency = 0.0
features.extend([transitions, waso_minutes, rem_latency])
return np.array(features)
# Save your trained model.
def save_model(model_folder, model):
d = {'model': model}
filename = os.path.join(model_folder, 'model.sav')
joblib.dump(d, filename, protocol=0)