comparison mil_bag.py @ 0:e6e9ea0703ef draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
author goeckslab
date Thu, 19 Jun 2025 23:31:55 +0000
1 """
2 A script for creating bags of instances from embeddings
3 and metadata for Multiple Instance Learning (MIL) tasks.
4
5 Processes embedding and metadata CSV files to generate
6 bags of instances, saved as a single CSV file. Supports
7 bagging strategies (by sample, in turns, or random),
8 pooling methods, and options for balancing, preventing
9 data leakage, and Ludwig formatting. Handles large
10 datasets efficiently using temporary Parquet files,
11 sequential processing, and multiprocessing.
12
13 Dependencies:
14 - gc: For manual garbage collection to manage memory.
15 - argparse: For parsing command-line arguments.
16 - logging: For logging progress and errors.
17 - multiprocessing (mp): For parallel processing.
18 - os: For file operations and temporary file management.
19 - tempfile: For creating temporary files.
20 - numpy (np): For numerical operations and arrays.
21 - pandas (pd): For data manipulation and I/O (CSV, Parquet).
22 - torch: For tensor operations (attention pooling).
23 - torch.nn: For NN components (attention pooling).
24 - fastparquet: For reading and writing Parquet files.
25
26 Key Features:
27 - Multiple bagging: by sample (`bag_by_sample`), in
28 turns (`bag_in_turns`), or random (`bag_random`).
29 - Various pooling methods (e.g., max, mean, attention).
30 - Prevents data leakage by splitting at sample level.
31 - Balances bags via an imbalance cap or by truncating excess bags.
32 - Outputs in Ludwig format (whitespace-separated vectors).
33 - Efficient large dataset processing (temp Parquet,
34 sequential CSV write).
35 - GPU acceleration for certain pooling (e.g., attention).
36
37 Usage:
38 Run the script from the command line with arguments:
39
40 ```bash
41 python mil_bag.py --embeddings_csv <path_to_embeddings.csv>
42 --metadata_csv <path_to_metadata.csv> --bag_size <bag_size>
43 --pooling_method <method> --output_csv <output.csv>
44 [--split_proportions <train,val,test>] [--dataleak]
45 [--balance_enforced] [--by_sample <splits>] [--repeats <num>]
46 [--ludwig_format] [--random_seed <seed>]
47 [--imbalance_cap <percentage>] [--truncate_bags] [--use_gpu]
48 """
49
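# Example invocation (hypothetical file names; the flags match the argparse
# options defined at the bottom of this script):
#   python mil_bag.py --embeddings_csv embeddings.csv --metadata_csv metadata.csv \
#       --bag_size 3-5 --pooling_method mean_pooling --output_csv bags.csv \
#       --split_proportions 0.7,0.1,0.2 --dataleak --ludwig_format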
50 import argparse
51 import gc
52 import logging
53 import multiprocessing as mp
54 import os
55 import tempfile
56
57 import numpy as np
58 import pandas as pd
59 import torch
60 import torch.nn as nn
61
62
63 def parse_bag_size(bag_size_str):
64 """Parses bag size string into a range or single value."""
65 try:
66 if '-' in bag_size_str:
67 start, end = map(int, bag_size_str.split('-'))
68 return [start, end]  # [min, max]; downstream code samples sizes within these bounds
69 return [int(bag_size_str)]
70 except ValueError:
71 logging.error("Invalid bag_size format: %s", bag_size_str)
72 raise
73
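# Expected outputs (illustrative, assuming the [min, max] return format above):
#   parse_bag_size("4")   -> [4]
#   parse_bag_size("3-5") -> [3, 5]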
74
75 def parse_by_sample(value):
76 """Parses by_sample string into a set of split values."""
77 try:
78 value = str(value)
79 splits = [int(x) for x in value.split(",")]
80 valid_splits = {0, 1, 2}
81 if not all(x in valid_splits for x in splits):
82 logging.warning("Invalid splits in by_sample: %s", splits)
83 return None
84 return splits
85 except (ValueError, AttributeError):
86 logging.warning("By_Sample not used")
87 return None
88
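# Expected outputs (illustrative):
#   parse_by_sample("0,2") -> [0, 2]   # bag by sample for train (0) and test (2)
#   parse_by_sample("3")   -> None     # 3 is not a valid split index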
89
90 class BaggingConfig:
91 """Configuration class for bagging parameters."""
92
93 def __init__(self, params):
94 self.embeddings_csv = params.embeddings_csv
95 self.metadata_csv = params.metadata_csv
96 self.split_proportions = params.split_proportions
97 self.prevent_leakage = params.dataleak
98 self.balance_enforced = params.balance_enforced
99 self.bag_size = parse_bag_size(params.bag_size)
100 self.pooling_method = params.pooling_method
101 self.by_sample = parse_by_sample(params.by_sample)
102 self.repeats = params.repeats
103 self.ludwig_format = params.ludwig_format
104 self.output_csv = params.output_csv
105 self.random_seed = params.random_seed
106 self.imbalance_cap = params.imbalance_cap
107 self.truncate_bags = params.truncate_bags
108 self.use_gpu = params.use_gpu
109
110 def __str__(self):
111 """String representation of the config for logging."""
112 return (
113 f"embeddings_csv={self.embeddings_csv}, "
114 f"metadata_csv={self.metadata_csv}, "
115 f"split_proportions={self.split_proportions}, "
116 f"prevent_leakage={self.prevent_leakage}, "
117 f"balance_enforced={self.balance_enforced}, "
118 f"bag_size={self.bag_size}, "
119 f"pooling_method={self.pooling_method}, "
120 f"by_sample={self.by_sample}, "
121 f"repeats={self.repeats}, "
122 f"ludwig_format={self.ludwig_format}, "
123 f"output_csv={self.output_csv}, "
124 f"random_seed={self.random_seed}, "
125 f"imbalance_cap={self.imbalance_cap}, "
126 f"truncate_bags={self.truncate_bags}, "
127 f"use_gpu={self.use_gpu}"
128 )
129
130
131 def set_random_seed(configs):
132 """Sets random seeds for reproducibility."""
133 np.random.seed(configs.random_seed)
134 torch.manual_seed(configs.random_seed)
135 if torch.cuda.is_available():
136 torch.cuda.manual_seed_all(configs.random_seed)
137 torch.backends.cudnn.deterministic = True
138 torch.backends.cudnn.benchmark = False
139 logging.info("Random seed set to %d", configs.random_seed)
140
141
142 def validate_metadata(metadata):
143 """Validates metadata for required columns."""
144 required_cols = {"sample_name", "label"}
145 if not required_cols.issubset(metadata.columns):
146 missing = required_cols - set(metadata.columns)
147 raise ValueError(f"Metadata missing columns: {missing}")
148 return metadata
149
150
151 def load_metadata(file_path):
152 """Loads metadata from a CSV file."""
153 metadata = pd.read_csv(file_path)
154 validate_metadata(metadata)
155 logging.info("Metadata loaded with %d samples, cols: %s",
156 len(metadata), list(metadata.columns))
157 logging.info("Unique samples: %d, labels: %d",
158 metadata["sample_name"].nunique(),
159 metadata["label"].nunique())
160 return metadata
161
162
163 def convert_proportions(proportion_string):
164 """Converts a string of split proportions into a list of floats."""
165 proportion_list = [float(p) for p in proportion_string.split(",")]
166 logging.debug("Parsed split proportions: %s", proportion_list)
167 if len(proportion_list) == 2:
168 proportion_list = [proportion_list[0], 0.0, proportion_list[1]]
169
170 for proportion in proportion_list:
171 if proportion < 0 or proportion > 1:
172 raise ValueError("Each proportion must be between 0 and 1")
173
174 if abs(sum(proportion_list) - 1.0) > 1e-6:
175 raise ValueError("Proportions must sum to approximately 1.0")
176
177 return proportion_list
178
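# Expected outputs (illustrative):
#   convert_proportions("0.7,0.1,0.2") -> [0.7, 0.1, 0.2]
#   convert_proportions("0.7,0.3")     -> [0.7, 0.0, 0.3]   # val defaults to 0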
179
180 def calculate_split_counts(total_samples, proportions):
181 """Calculates sample counts for each split."""
182 counts = [int(p * total_samples) for p in proportions]
183 calculated_total = sum(counts)
184 if calculated_total < total_samples:
185 counts[-1] += total_samples - calculated_total
186 elif calculated_total > total_samples:
187 counts[0] -= calculated_total - total_samples
188 return counts
189
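# Worked example (illustrative): calculate_split_counts(10, [0.34, 0.33, 0.33])
# truncates to [3, 3, 3]; the shortfall of 1 is added to the last split,
# so the function returns [3, 3, 4].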
190
191 def assign_split_labels(proportions, sample_count):
192 """Assigns split labels based on proportions."""
193 proportion_values = convert_proportions(proportions)
194 train_fraction, val_fraction, test_fraction = proportion_values
195
196 if val_fraction == 0 and test_fraction == 0:
197 labels = np.zeros(sample_count, dtype=int)
198 elif val_fraction == 0:
199 train_size = int(train_fraction * sample_count)
200 test_size = sample_count - train_size
201 labels = np.array([0] * train_size + [2] * test_size)
202 else:
203 split_counts = calculate_split_counts(sample_count, proportion_values)
204 labels = np.concatenate([
205 np.zeros(split_counts[0], dtype=int),
206 np.ones(split_counts[1], dtype=int),
207 2 * np.ones(split_counts[2], dtype=int)
208 ])
209 return labels
210
211
212 def split_dataset(metadata, configs):
213 """Splits dataset into train, val, test sets if prevent_leakage is True."""
214 if configs.prevent_leakage:
215 logging.info("No data leakage allowed")
216 unique_samples = metadata["sample_name"].unique()
217 sample_count = len(unique_samples)
218 split_labels = assign_split_labels(configs.split_proportions,
219 sample_count)
220 shuffled_samples = np.random.permutation(unique_samples)
221 label_series = pd.Series(split_labels, index=shuffled_samples)
222 metadata["split"] = metadata["sample_name"].map(label_series)
223 train_count = (metadata["split"] == 0).sum()
224 val_count = (metadata["split"] == 1).sum()
225 test_count = (metadata["split"] == 2).sum()
226 logging.info("Dataset split: train %d, val %d, test %d",
227 train_count, val_count, test_count)
228 else:
229 logging.info("Data leakage allowed setup")
230 return metadata
231
232
233 def assign_chunk_splits(chunk, split_counts, current_counts):
234 """Assigns split labels to a chunk of embeddings."""
235 chunk_size = len(chunk)
236 remaining = {
237 0: split_counts[0] - current_counts[0],
238 1: split_counts[1] - current_counts[1],
239 2: split_counts[2] - current_counts[2]
240 }
241 available_splits = [s for s, count in remaining.items() if count > 0]
242 if not available_splits:
243 return chunk, current_counts
244
245 total_remaining = sum(remaining.values())
246 assign_count = min(chunk_size, total_remaining)
247 if assign_count == 0:
248 return chunk, current_counts
249
250 weights = [remaining[s] / total_remaining for s in available_splits]
251 splits = np.random.choice(available_splits, size=assign_count, p=weights)
252 chunk["split"] = pd.Series(splits, index=chunk.index[:assign_count])
253 chunk["split"] = chunk["split"].fillna(0).astype(int)
254
255 for split in available_splits:
256 current_counts[split] += np.sum(splits == split)
257
258 return chunk, current_counts
259
260
261 def setup_temp_files():
262 """Sets up temporary Parquet files for splits and bag outputs."""
263 splits = [0, 1, 2]
264 split_files = {}
265 for split in splits:
266 fd, path = tempfile.mkstemp(prefix=f"split_{split}_",
267 suffix=".parquet",
268 dir=os.getcwd())
269 os.close(fd) # Explicitly close the file descriptor
270 split_files[split] = path
271
272 bag_outputs = {}
273 for split in splits:
274 fd, path = tempfile.mkstemp(prefix=f"MIL_bags_{split}_",
275 suffix=".parquet",
276 dir=os.getcwd())
277 os.close(fd) # Explicitly close the file descriptor
278 bag_outputs[split] = path
279
280 return split_files, bag_outputs
281
282
283 def distribute_embeddings(configs, metadata, split_files):
284 embeddings_path = configs.embeddings_csv
285 proportion_string = configs.split_proportions
286 prevent_leakage = configs.prevent_leakage
287
288 logging.info("Distributing embeddings from %s to Parquet files",
289 embeddings_path)
290 buffer_size = 50000
291 merged_header = None
292 non_sample_columns = None
293
294 if not prevent_leakage:
295 logging.warning(
296 "Counting rows in %s; may be slow for large files",
297 embeddings_path
298 )
299 total_rows = sum(1 for _ in open(embeddings_path)) - 1
300 proportions = convert_proportions(proportion_string)
301 split_counts = calculate_split_counts(total_rows, proportions)
302 current_counts = {0: 0, 1: 0, 2: 0}
303 else:
304 sample_to_split = dict(zip(metadata["sample_name"], metadata["split"]))
305 sample_to_label = dict(zip(metadata["sample_name"], metadata["label"]))
306
307 first_write = {split: True for split in split_files}
308
309 try:
310 first_header_read = True
311 for chunk in pd.read_csv(embeddings_path, chunksize=buffer_size):
312 # Modify 'sample_name' to remove part after the last underscore
313 chunk['sample_name'] = chunk['sample_name'].apply(lambda x: x.rsplit('_', 1)[0])
314
315 if first_header_read:
316 orig_header = list(chunk.columns)
317 non_sample_columns = [
318 col for col in orig_header if col != "sample_name"
319 ]
320 merged_header = ["sample_name", "label"] + non_sample_columns
321 logging.info("Merged header: %s", merged_header)
322 first_header_read = False
323
324 if prevent_leakage:
325 chunk["split"] = chunk["sample_name"].map(sample_to_split)
326 chunk["label"] = chunk["sample_name"].map(sample_to_label)
327 else:
328 chunk, current_counts = assign_chunk_splits(chunk,
329 split_counts,
330 current_counts)
331 chunk = chunk.merge(metadata[["sample_name", "label"]],
332 on="sample_name",
333 how="left")
334
335 chunk = chunk.dropna(subset=["split", "label"])
336 for split in split_files:
337 split_chunk = chunk[chunk["split"] == split]
338 if not split_chunk.empty:
339 temp_file = split_files[split]
340 split_chunk[merged_header].to_parquet(
341 temp_file,
342 engine="fastparquet",
343 append=not first_write[split],
344 index=False
345 )
346 first_write[split] = False
347 del chunk
348 gc.collect()
349
350 except Exception as e:
351 logging.error("Error distributing embeddings to Parquet: %s", e)
352 raise
353
354
355 def aggregate_embeddings(embeddings, pooling_method, use_gpu=False):
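"""Aggregates a 2-D array of instance embeddings into one vector using the
selected pooling method (a toy illustration follows this function)."""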
356 # Convert embeddings to a float32 array explicitly.
357 embeddings = np.asarray(embeddings, dtype=np.float32)
358
359 if embeddings.ndim == 1:
360 embeddings = embeddings.reshape(1, -1)
361 elif embeddings.ndim == 0:
362 embeddings = embeddings.reshape(1, 1)
363
364 logging.debug("Aggregating embeddings with shape: %s", embeddings.shape)
365
366 if pooling_method == "max_pooling":
367 result = np.max(embeddings, axis=0)
368 elif pooling_method == "mean_pooling":
369 result = np.mean(embeddings, axis=0)
370 elif pooling_method == "sum_pooling":
371 result = np.sum(embeddings, axis=0)
372 elif pooling_method == "min_pooling":
373 result = np.min(embeddings, axis=0)
374 elif pooling_method == "median_pooling":
375 result = np.median(embeddings, axis=0)
376 elif pooling_method == "l2_norm_pooling":
377 norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
378 if norm.any():
379 result = np.mean(embeddings / (norm + 1e-8), axis=0)
380 else:
381 result = np.mean(embeddings, axis=0)
382 elif pooling_method == "geometric_mean_pooling":
383 clipped = np.clip(embeddings, 1e-10, None)
384 result = np.exp(np.mean(np.log(clipped), axis=0))
385 elif pooling_method == "first_embedding":
386 result = embeddings[0]
387 elif pooling_method == "last_embedding":
388 result = embeddings[-1]
389 elif pooling_method == "attention_pooling":
390 device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
391 tensor = torch.tensor(embeddings, dtype=torch.float32).to(device)
392 with torch.no_grad():
393 linear = nn.Linear(tensor.shape[1], 1).to(device)
394 weights = nn.Softmax(dim=0)(linear(tensor))
395 result = torch.sum(weights * tensor, dim=0).cpu().detach().numpy()
396 else:
397 raise ValueError(f"Unknown pooling method: {pooling_method}")
398
399 logging.debug("Aggregated embedding shape: %s", result.shape)
400 return result
401
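# Toy illustration of aggregate_embeddings (values chosen for this sketch):
#   aggregate_embeddings([[1.0, 3.0], [3.0, 5.0]], "mean_pooling") -> array([2., 4.], dtype=float32)
#   aggregate_embeddings([[1.0, 3.0], [3.0, 5.0]], "max_pooling")  -> array([3., 5.], dtype=float32)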
402
403 def bag_by_sample(df, split, bag_file, config, batch_size=1000,
404 fixed_target_bags=None):
405 """
406 Processes the provided DataFrame by grouping rows by sample,
407 constructs bags from each sample group using the configured bag_size,
408 and writes the bag rows directly to bag_file (a Parquet file) in batches.
409
410 Args:
411 df (pd.DataFrame): The DataFrame containing the data.
412 split (int): The split identifier (0=train, 1=val, 2=test).
413 bag_file (str): The path to the Parquet file to write the bags.
414 config (object): Configuration object with bag_size, pooling_method...
415 batch_size (int, optional): The number of rows to write in each batch.
416 fixed_target_bags (tuple, optional): (target_label, num_bags)
417 to generate bags only for target_label.
418
419 Output row format:
420 sample_name, bag_label, split, bag_size, vector_0, vector_1, ..., vector_N
421 """
422 log_msg = f"Processing by sample for split: {split}"
423 if fixed_target_bags:
424 log_msg += f" with fixed target {fixed_target_bags}"
425 logging.info(log_msg)
426
427 batch_rows = []
428 bag_count = 0
429 vector_columns = [
430 col for col in df.columns
431 if col not in ["sample_name", "label", "split"]
432 ]
433
434 if fixed_target_bags is not None:
435 target_label, target_needed = fixed_target_bags
436 target_samples = list(
437 df[df["label"] == target_label]["sample_name"].unique()
438 )
439 df = df[df["sample_name"].isin(target_samples)]
440
441 if df.empty:
442 logging.warning(
443 "No samples available for target label %d in split %s",
444 target_label,
445 split
446 )
447 return
448
449 available_samples = target_samples.copy()
450 np.random.shuffle(available_samples)
451
452 while bag_count < target_needed:
453 if len(available_samples) == 0:
454 available_samples = target_samples.copy()
455 np.random.shuffle(available_samples)
456 logging.info(
457 "Reusing samples for target label %d in split %s",
458 target_label,
459 split
460 )
461
462 sample_name = available_samples.pop()
463 group = df[df["sample_name"] == sample_name]
464 embeddings = group[vector_columns].values
465 num_instances = len(group)
466
467 current_bag_size = config.bag_size[0] \
468 if len(config.bag_size) == 1 else \
469 np.random.randint(config.bag_size[0], config.bag_size[1] + 1)
470 current_bag_size = min(current_bag_size, num_instances)
471
472 selected = group.sample(n=current_bag_size, replace=True)
473 bag_embeddings = selected[vector_columns].values
474
475 aggregated_embedding = aggregate_embeddings(
476 bag_embeddings,
477 config.pooling_method,
478 config.use_gpu
479 )
480
481 bag_label = int(any(selected["label"] == 1))
482 if bag_label != target_label:
483 logging.warning(
484 "Generated bag for target %d but got label %d",
485 target_label, bag_label
486 )
487 continue
488
489 row = {
490 "sample_name": sample_name,
491 "bag_label": bag_label,
492 "split": split,
493 "bag_size": current_bag_size
494 }
495 for j, val in enumerate(aggregated_embedding):
496 row[f"vector_{j}"] = val
497
498 batch_rows.append(row)
499 bag_count += 1
500
501 if len(batch_rows) >= batch_size:
502 df_batch = pd.DataFrame(batch_rows)
503 # Check if the file has data to determine append mode
504 append_mode = os.path.getsize(bag_file) > 0
505 df_batch.to_parquet(
506 bag_file,
507 engine="fastparquet",
508 append=append_mode,
509 index=False
510 )
511 logging.debug(
512 "Fixed mode: Wrote batch of %d rows to %s",
513 len(batch_rows),
514 bag_file
515 )
516 batch_rows = []
517 del df_batch
518 gc.collect()
519
520 else:
521 # Standard mode: process all samples
522 groups = df.groupby("sample_name")
523 for sample_name, group in groups:
524 embeddings = group[vector_columns].values
525 labels = group["label"].values
526 num_instances = len(group)
527
528 current_bag_size = config.bag_size[0] \
529 if len(config.bag_size) == 1 else \
530 np.random.randint(
531 config.bag_size[0],
532 config.bag_size[1] + 1
533 )
534 num_bags = (
535 num_instances + current_bag_size - 1
536 ) // current_bag_size
537 logging.info(
538 "Sample %s: %d instances, creating %d bags (bag size %d)",
539 sample_name,
540 num_instances,
541 num_bags,
542 current_bag_size
543 )
544
545 for i in range(num_bags):
546 start_idx = i * current_bag_size
547 end_idx = min(start_idx + current_bag_size, num_instances)
548 bag_embeddings = embeddings[start_idx:end_idx]
549 bag_labels = labels[start_idx:end_idx]
550
551 aggregated_embedding = aggregate_embeddings(
552 bag_embeddings,
553 config.pooling_method,
554 config.use_gpu
555 )
556 bag_label = int(any(bag_labels == 1))
557
558 row = {
559 "sample_name": sample_name,
560 "bag_label": bag_label,
561 "split": split,
562 "bag_size": end_idx - start_idx
563 }
564 for j, val in enumerate(aggregated_embedding):
565 row[f"vector_{j}"] = val
566
567 batch_rows.append(row)
568 bag_count += 1
569
570 if len(batch_rows) >= batch_size:
571 df_batch = pd.DataFrame(batch_rows)
572 # Check if the file has data to determine append mode
573 append_mode = os.path.getsize(bag_file) > 0
574 df_batch.to_parquet(
575 bag_file,
576 engine="fastparquet",
577 append=append_mode,
578 index=False
579 )
580 logging.debug(
581 "Wrote batch of %d rows to %s",
582 len(batch_rows),
583 bag_file
584 )
585 batch_rows = []
586 del df_batch
587 gc.collect()
588
589 # Write any remaining rows
590 if batch_rows:
591 df_batch = pd.DataFrame(batch_rows)
592 append_mode = os.path.getsize(bag_file) > 0
593 df_batch.to_parquet(
594 bag_file,
595 engine="fastparquet",
596 append=append_mode,
597 index=False
598 )
599 logging.debug(
600 "Wrote final batch of %d rows to %s",
601 len(batch_rows),
602 bag_file
603 )
604 del df_batch
605 gc.collect()
606
607 logging.info("Created %d bags for split: %s", bag_count, split)
608
609
610 def bag_in_turns(df, split, bag_file, config, batch_size=500,
611 fixed_target_bags=None, allow_reuse=True):
612 """
613 Generate bags of instances from a DataFrame, with optional
614 fixed-target mode, data reuse, and enhanced diversity.
615
616 Parameters:
617 - df (pd.DataFrame): Input DataFrame with columns including
618 'sample_name', 'label', 'split', and embedding vectors.
619 - split (int): Dataset split (0=train, 1=val, 2=test).
620 - bag_file (str): Path to save the output Parquet file.
621 - config (object): Configuration object with attributes
622 'bag_size', 'pooling_method', and 'use_gpu'.
623 - batch_size (int): Number of bags to process before writing
624 to file (default: 500).
625 - fixed_target_bags (tuple): Optional (label, num_bags) to
626 generate bags for a specific label (e.g., (0, 100)).
627 - allow_reuse (bool): Allow resampling instances with
628 replacement if True (default: True).
629
630 Returns:
631 - None: Saves bags to the specified Parquet file.
632 """
633 logging.info(
634 "Processing bag in turns for split %s%s",
635 split,
636 (" with fixed target " + str(fixed_target_bags))
637 if fixed_target_bags is not None else ""
638 )
639
640 # Identify embedding columns (exclude non-vector columns).
641 vector_columns = [
642 col for col in df.columns
643 if col not in ["sample_name", "label", "split"]
644 ]
645
646 # Convert the DataFrame to a NumPy array for faster processing.
647 df_np = df.to_numpy()
648
649 # Determine bag size range from config.
650 if len(config.bag_size) == 1:
651 bag_min = bag_max = config.bag_size[0]
652 else:
653 bag_min, bag_max = config.bag_size
654
655 batch_rows = []
656 bag_count = 0
657
658 if fixed_target_bags is not None:
659 # Fixed-target mode: generate bags for a specific label.
660 target, target_needed = fixed_target_bags # e.g., (0, 100)
661 if target == 0:
662 # Optimize for target label 0: remove all label 1 instances
663 indices = np.where(df_np[:, 1] == 0)[0]
664 logging.info(
665 "Fixed mode: target label 0, using only label 0 instances, \
666 total available %d rows",
667 len(indices)
668 )
669 else:
670 # For target label 1, use all instances to allow mixing
671 indices = np.arange(len(df_np))
672 logging.info(
673 "Fixed mode: target label 1, using all instances, \
674 total available %d rows",
675 len(indices)
676 )
677
678 total_available = len(indices)
679
680 while bag_count < target_needed:
681 current_bag_size = np.random.randint(bag_min, bag_max + 1) \
682 if bag_min != bag_max else bag_min
683
684 if total_available < current_bag_size and not allow_reuse:
685 logging.warning(
686 "Not enough instances (%d) for bag size %d and \
687 target label %d",
688 total_available, current_bag_size, target
689 )
690 break
691
692 # Sample instances
693 selected = np.random.choice(
694 indices,
695 size=current_bag_size,
696 replace=allow_reuse
697 )
698 bag_data = df_np[selected]
699
700 if target == 1:
701 # For positive bags, ensure at least one instance has label 1
702 if not np.any(bag_data[:, 1] == 1):
703 continue # Skip if no positive instance
704 bag_label = 1
705 else:
706 # For negative bags, all instances are label 0 due to filtering
707 bag_label = 0
708
709 # Aggregate embeddings.
710 vec_col_indices = [
711 df.columns.get_loc(col) for col in vector_columns
712 ]
713 embeddings = bag_data[:, vec_col_indices].astype(np.float32)
714 aggregated_embedding = aggregate_embeddings(
715 embeddings,
716 config.pooling_method,
717 config.use_gpu
718 )
719
720 # Set bag metadata.
721 bsize = bag_data.shape[0]
722 samples = np.unique(bag_data[:, 0])
723 merged_sample_name = ",".join(map(str, samples))
724
725 # Create row for the bag.
726 row = {
727 "sample_name": merged_sample_name,
728 "bag_label": bag_label,
729 "split": split,
730 "bag_size": bsize
731 }
732 for j, val in enumerate(aggregated_embedding):
733 row[f"vector_{j}"] = val
734
735 batch_rows.append(row)
736 bag_count += 1
737
738 if len(batch_rows) >= batch_size:
739 df_batch = pd.DataFrame(batch_rows)
740 df_batch.to_parquet(
741 bag_file,
742 engine="fastparquet",
743 append=True,
744 index=False
745 )
746 logging.debug(
747 "Fixed mode: Wrote a batch of %d rows to %s",
748 len(batch_rows),
749 bag_file
750 )
751 batch_rows = []
752 del df_batch
753 gc.collect()
754
755 # Write any remaining rows.
756 if batch_rows:
757 df_batch = pd.DataFrame(batch_rows)
758 df_batch.to_parquet(
759 bag_file,
760 engine="fastparquet",
761 append=True,
762 index=False
763 )
764 logging.debug(
765 "Wrote the final batch of %d rows to %s",
766 len(batch_rows),
767 bag_file
768 )
769 del df_batch
770 gc.collect()
771
772 logging.info("Created %d bags for split: %s", bag_count, split)
773
774 else:
775 # Alternating mode: alternate between labels 0 and 1.
776 indices_0 = np.where(df_np[:, 1] == 0)[0]
777 indices_1 = np.where(df_np[:, 1] == 1)[0]
778 np.random.shuffle(indices_0)
779 np.random.shuffle(indices_1)
780 turn = 0 # 0: label 0, 1: label 1.
781
782 while len(indices_0) > 0 or len(indices_1) > 0:
783 current_bag_size = np.random.randint(bag_min, bag_max + 1) \
784 if bag_min != bag_max else bag_min
785
786 if turn == 0:
787 if len(indices_0) > 0:
788 num_to_select = min(current_bag_size, len(indices_0))
789 selected = indices_0[:num_to_select]
790 indices_0 = indices_0[num_to_select:]
791 else:
792 if len(indices_1) == 0:
793 break
794 num_to_select = min(current_bag_size, len(indices_1))
795 selected = indices_1[:num_to_select]
796 indices_1 = indices_1[num_to_select:]
797 else:
798 if len(indices_1) > 0:
799 num_to_select = min(current_bag_size, len(indices_1))
800 selected = indices_1[:num_to_select]
801 indices_1 = indices_1[num_to_select:]
802 else:
803 if len(indices_0) == 0:
804 break
805 num_to_select = min(current_bag_size, len(indices_0))
806 selected = indices_0[:num_to_select]
807 indices_0 = indices_0[num_to_select:]
808
809 bag_data = df_np[selected]
810 if bag_data.shape[0] == 0:
811 break
812
813 # Aggregate embeddings.
814 vec_col_indices = [
815 df.columns.get_loc(col) for col in vector_columns
816 ]
817 embeddings = bag_data[:, vec_col_indices].astype(np.float32)
818 aggregated_embedding = aggregate_embeddings(
819 embeddings,
820 config.pooling_method,
821 config.use_gpu
822 )
823
824 # Set bag label and metadata.
825 bag_label = int(np.any(bag_data[:, 1] == 1))
826 bsize = bag_data.shape[0]
827 samples = np.unique(bag_data[:, 0])
828 merged_sample_name = ",".join(map(str, samples))
829
830 # Create row for the bag.
831 row = {
832 "sample_name": merged_sample_name,
833 "bag_label": bag_label,
834 "split": split,
835 "bag_size": bsize
836 }
837 for j, val in enumerate(aggregated_embedding):
838 row[f"vector_{j}"] = val
839
840 batch_rows.append(row)
841 bag_count += 1
842 turn = 1 - turn
843
844 # Write batch to file if batch_size is reached.
845 if len(batch_rows) >= batch_size:
846 df_batch = pd.DataFrame(batch_rows)
847 df_batch.to_parquet(
848 bag_file,
849 engine="fastparquet",
850 append=(bag_count > len(batch_rows)),
851 index=False
852 )
853 logging.debug(
854 "Alternating mode: Wrote a batch of %d rows to %s",
855 len(batch_rows),
856 bag_file
857 )
858 batch_rows = []
859 del df_batch
860 gc.collect()
861
862 # Write any remaining rows.
863 if batch_rows:
864 df_batch = pd.DataFrame(batch_rows)
865 df_batch.to_parquet(
866 bag_file,
867 engine="fastparquet",
868 append=(bag_count > len(batch_rows)),
869 index=False
870 )
871 logging.debug(
872 "Wrote the final batch of %d rows to %s",
873 len(batch_rows),
874 bag_file
875 )
876 del df_batch
877 gc.collect()
878
879 logging.info("Created %d bags for split: %s", bag_count, split)
880
881
882 def bag_random(df, split, bag_file, configs, batch_size=500):
883 """
884 Processes the provided DataFrame by randomly selecting instances
885 to create bags.
886 """
887 logging.info("Processing bag randomly for split %s", split)
888
889 # Identify vector columns (exclude non-vector columns).
890 vector_columns = [
891 col for col in df.columns
892 if col not in ["sample_name", "label", "split"]
893 ]
894
895 df_np = df.to_numpy()
896
897 # Create an array of all row indices and shuffle them.
898 indices = np.arange(df.shape[0])
899 np.random.shuffle(indices)
900
901 bag_count = 0
902 batch_rows = []
903
904 # Determine bag size parameters.
905 if len(configs.bag_size) == 1:
906 bag_min = bag_max = configs.bag_size[0]
907 else:
908 bag_min, bag_max = configs.bag_size
909
910 pos = 0
911 total_rows = len(indices)
912
913 # Process until all indices have been used.
914 while pos < total_rows:
915 # Ensuring we do not exceed remaining rows.
916 current_bag_size = (np.random.randint(bag_min, bag_max + 1)
917 if bag_min != bag_max else bag_min)
918 current_bag_size = min(current_bag_size, total_rows - pos)
919
920 # Select the indices for this bag.
921 selected = indices[pos: pos + current_bag_size]
922 pos += current_bag_size
923
924 # Extract the bag data.
925 bag_data = df_np[selected]
926 if bag_data.shape[0] == 0:
927 break
928
929 # Identify the positions of the vector columns using the column names.
930 vec_col_indices = [df.columns.get_loc(col) for col in vector_columns]
931 embeddings = bag_data[:, vec_col_indices].astype(np.float32)
932 aggregated_embedding = aggregate_embeddings(
933 embeddings,
934 configs.pooling_method,
935 configs.use_gpu
936 )
937
938 # Determine bag_label: 1 if any instance in this bag has label == 1.
939 bag_label = int(np.any(bag_data[:, 1] == 1))
940
941 # Merge all sample names from the bag (unique names, comma-separated).
942 samples = np.unique(bag_data[:, 0])
943 merged_sample_name = ",".join(map(str, samples))
944
945 # Use the provided split value.
946 bag_split = split
947 bsize = bag_data.shape[0]
948
949 # Build the output row with header fields:
950 # sample_name, bag_label, split, bag_size, then embeddings.
951 row = {
952 "sample_name": merged_sample_name,
953 "bag_label": bag_label,
954 "split": bag_split,
955 "bag_size": bsize
956 }
957 for j, val in enumerate(aggregated_embedding):
958 row[f"vector_{j}"] = val
959
960 batch_rows.append(row)
961 bag_count += 1
962
963 # Write out rows in batches.
964 if len(batch_rows) >= batch_size:
965 df_batch = pd.DataFrame(batch_rows)
966 # For the first batch,
967 # append=False (header written),
968 # then append=True on subsequent batches.
969 df_batch.to_parquet(
970 bag_file,
971 engine="fastparquet",
972 append=(bag_count > len(batch_rows)),
973 index=False
974 )
975 logging.debug(
976 "Wrote a batch of %d rows to %s",
977 len(batch_rows),
978 bag_file
979 )
980 batch_rows = []
981 del df_batch
982 gc.collect()
983
984 # Write any remaining rows.
985 if batch_rows:
986 df_batch = pd.DataFrame(batch_rows)
987 df_batch.to_parquet(
988 bag_file,
989 engine="fastparquet",
990 append=(bag_count > len(batch_rows)),
991 index=False
992 )
993 logging.debug(
994 "Wrote the final batch of %d rows to %s",
995 len(batch_rows),
996 bag_file
997 )
998 del df_batch
999 gc.collect()
1000
1001 logging.info("Created %d bags for split: %s", bag_count, split)
1002
1003
1004 def imbalance_adjustment(bag_file, split, configs, df):
1005 """
1006 Verifies if the number of bags per label in bag_file is
1007 within imbalance_cap.
1008 If not, generates additional bags for the minority label.
1009
1010 Args:
1011 bag_file (str): Path to the Parquet file containing bags.
1012 split (int): The current split (0=train, 1=val, 2=test).
1013 configs (object): Configuration with imbalance_cap, by_sample, etc.
1014 df (pd.DataFrame): Original DataFrame for generating additional bags.
1015 """
1016 # Read the bag file and count bags per label
1017 bags_df = pd.read_parquet(bag_file)
1018 n0 = (bags_df["bag_label"] == 0).sum()
1019 n1 = (bags_df["bag_label"] == 1).sum()
1020 total = n0 + n1
1021
1022 if total == 0:
1023 logging.warning("No bags found in %s for split %s", bag_file, split)
1024 return
1025
1026 # Calculate imbalance as a percentage
1027 imbalance = abs(n0 - n1) / total * 100
1028 logging.info(
1029 "Split %s: %d bags (label 0: %d, label 1: %d), imbalance %.2f%%",
1030 split, total, n0, n1, imbalance
1031 )
1032
1033 if imbalance > configs.imbalance_cap:
1034 # Identify minority label
1035 min_label = 0 if n0 < n1 else 1
1036 n_min = n0 if min_label == 0 else n1
1037 n_maj = n1 if min_label == 0 else n0
1038
1039 # Calculate how many bags are needed to balance (aim for equality)
1040 num_needed = n_maj - n_min
1041 logging.info(
1042 "Imbalance %.2f%% exceeds cap %.2f%% in split %s, \
1043 need %d bags for label %d",
1044 imbalance,
1045 configs.imbalance_cap,
1046 split,
1047 num_needed,
1048 min_label
1049 )
1050
1051 # Generate additional bags based on the bag creation method
1052 if configs.by_sample is not None and split in configs.by_sample:
1053 bag_by_sample(
1054 df,
1055 split,
1056 bag_file,
1057 configs,
1058 fixed_target_bags=(min_label, num_needed)
1059 )
1060 else:
1061 bag_in_turns(
1062 df,
1063 split,
1064 bag_file,
1065 configs,
1066 fixed_target_bags=(min_label, num_needed)
1067 )
1068
1069 # Verify the new balance (optional, for logging)
1070 updated_bags_df = pd.read_parquet(bag_file)
1071 new_n0 = (updated_bags_df["bag_label"] == 0).sum()
1072 new_n1 = (updated_bags_df["bag_label"] == 1).sum()
1073 new_total = new_n0 + new_n1
1074 new_imbalance = abs(new_n0 - new_n1) / new_total * 100
1075 logging.info(
1076 "After adjustment, split %s: %d bags (label 0: %d, label 1: %d), \
1077 imbalance %.2f%%",
1078 split,
1079 new_total,
1080 new_n0,
1081 new_n1,
1082 new_imbalance
1083 )
1084 else:
1085 logging.info(
1086 "Imbalance %.2f%% within cap %.2f%% for split %s, \
1087 no adjustment needed",
1088 imbalance,
1089 configs.imbalance_cap,
1090 split
1091 )
1092
1093
1094 def truncate_bag(bag_file, split):
1095 """
1096 Truncates the bags in the bag_file to balance the counts of label 0
1097 and label 1,
1098 ensuring that the file is never left empty (at least one bag remains).
1099
1100 Args:
1101 bag_file (str): Path to the Parquet file containing the bags.
1102 split (int): The current split (0=train, 1=val, 2=test)
1103 for logging purposes.
1104
1105 Returns:
1106 None: Overwrites the bag_file with the truncated bags,
1107 ensuring at least one bag remains.
1108 """
1109 logging.info("Truncating bags for split %s in file: %s", split, bag_file)
1110
1111 # Step 1: Read the bag file to get the total number of bags
1112 try:
1113 bags_df = pd.read_parquet(bag_file)
1114 except Exception as e:
1115 logging.error("Failed to read bag file %s: %s", bag_file, e)
1116 return
1117
1118 total_bags = len(bags_df)
1119 if total_bags == 0:
1120 logging.warning("No bags found in %s for split %s", bag_file, split)
1121 return
1122
1123 # Step 2: Count bags with label 0 and label 1
1124 n0 = (bags_df["bag_label"] == 0).sum()
1125 n1 = (bags_df["bag_label"] == 1).sum()
1126 logging.info(
1127 "Split %s: Total bags %d (label 0: %d, label 1: %d)",
1128 split,
1129 total_bags,
1130 n0,
1131 n1
1132 )
1133
1134 # Determine the minority count and majority label
1135 min_count = min(n0, n1)
1136 majority_label = 0 if n0 > n1 else 1
1137
1138 if n0 == n1:
1139 logging.info(
1140 "Bags already balanced for split %s, no truncation needed",
1141 split
1142 )
1143 return
1144
1145 # Step 3: Adjust min_count to ensure at least one bag remains
1146 if min_count == 0:
1147 logging.warning(
1148 "Minority label has 0 bags in split %s, keeping 1 bag from \
1149 majority label %d to avoid empty file",
1150 split,
1151 majority_label
1152 )
1153 min_count = 1 # Ensure at least one bag is kept
1154
1155 # Step 4: Truncate excess bags from the majority label
1156 logging.info(
1157 "Truncating %d bags from label %d to match %d bags per label",
1158 max(0, (n0 if majority_label == 0 else n1) - min_count),
1159 majority_label,
1160 min_count
1161 )
1162
1163 # Shuffle the majority label bags to randomly select which to keep
1164 majority_bags = bags_df[
1165 bags_df["bag_label"] == majority_label
1166 ].sample(frac=1, random_state=None)
1167
1168 minority_bags = bags_df[bags_df["bag_label"] != majority_label]
1169
1170 # Keep only min_count bags from the majority label
1171 majority_bags_truncated = majority_bags.iloc[:min_count]
1172
1173 # Combine the truncated majority and minority bags
1174 truncated_bags_df = pd.concat(
1175 [majority_bags_truncated,
1176 minority_bags],
1177 ignore_index=True
1178 )
1179
1180 # Verify that the resulting DataFrame is not empty
1181 if len(truncated_bags_df) == 0:
1182 logging.error(
1183 "Unexpected empty DataFrame after truncation for split %s, \
1184 this should not happen",
1185 split
1186 )
1187 return
1188
1189 # Step 5: Overwrite the bag file with the truncated bags
1190 try:
1191 truncated_bags_df.to_parquet(
1192 bag_file,
1193 engine="fastparquet",
1194 index=False
1195 )
1196 logging.info(
1197 "Overwrote %s with %d balanced bags (label 0: %d, label 1: %d)",
1198 bag_file,
1199 len(truncated_bags_df),
1200 (truncated_bags_df["bag_label"] == 0).sum(),
1201 (truncated_bags_df["bag_label"] == 1).sum()
1202 )
1203 except Exception as e:
1204 logging.error("Failed to overwrite bag file %s: %s", bag_file, e)
1205
1206
1207 def columns_into_string(bag_file):
1208 """
1209 Reads the bag file (Parquet) from the given path, identifies
1210 the vector columns (i.e., columns not among 'sample_name', 'bag_label', 'split',
1211 and 'bag_size'), concatenates these vector values (as strings) into a single
1212 whitespace-separated string stored in a new column
1213 "embeddings", drops the individual vector columns, and writes the modified
1214 DataFrame back to the same Parquet file.
1215
1216 The final output format is:
1217 "sample_name", "bag_label", "split", "bag_size", "embeddings"
1218 where "embeddings" is a string like: "0.1 0.2 0.3"
1219 """
1220 logging.info(
1221 "Converting vector columns into string for bag file: %s",
1222 bag_file
1223 )
1224
1225 try:
1226 df = pd.read_parquet(bag_file, engine="fastparquet")
1227 except Exception as e:
1228 logging.error("Error reading bag file %s: %s", bag_file, e)
1229 return
1230
1231 # Define non-vector columns.
1232 non_vector = ["sample_name", "bag_label", "split", "bag_size"]
1233
1234 # Identify vector columns.
1235 vector_columns = [col for col in df.columns if col not in non_vector]
1236 logging.info("Identified vector columns: %s", vector_columns)
1237
1238 # Create new 'embeddings' column by converting vector columns to str,
1239 # joining them with whitespace; any quoting needed is applied by the CSV writer later.
1240 # Use apply() to ensure the result is a Series with one string per row.
1241 df["embeddings"] = df[vector_columns].astype(str).apply(
1242 lambda x: " ".join(x), axis=1
1243 )
1244 # Drop the original vector columns.
1245 df.drop(columns=vector_columns, inplace=True)
1246
1247 try:
1248 # Write the modified DataFrame back to the same bag file.
1249 df.to_parquet(bag_file, engine="fastparquet", index=False)
1250 logging.info(
1251 "Conversion complete. Final columns: %s",
1252 df.columns.tolist()
1253 )
1254 except Exception as e:
1255 logging.error("Error writing updated bag file %s: %s", bag_file, e)
1256
1257
1258 def processing_bag(configs, bag_file, temp_file, split):
1259 """
1260 Processes a single split and writes bag results
1261 directly to the bag output Parquet file.
1262 """
1263 logging.info("Processing split %s using file: %s", split, temp_file)
1264 df = pd.read_parquet(temp_file, engine="fastparquet")
1265
1266 if configs.by_sample is not None and split in configs.by_sample:
1267 bag_by_sample(df, split, bag_file, configs)
1268 elif configs.balance_enforced:
1269 bag_in_turns(df, split, bag_file, configs)
1270 else:
1271 bag_random(df, split, bag_file, configs)
1272
1273 # Free df if imbalance_adjustment is not needed
1274 if configs.imbalance_cap is None:
1275 del df
1276 gc.collect()
1277
1278 if configs.imbalance_cap is not None:
1279 imbalance_adjustment(bag_file, split, configs, df)
1280 del df
1281 gc.collect()
1282 elif configs.truncate_bags:
1283 truncate_bag(bag_file, split)
1284
1285 if configs.ludwig_format:
1286 columns_into_string(bag_file)
1287
1288 return bag_file
1289
1290
1291 def write_final_csv(output_csv, bag_file_paths):
1292 """
1293 Merges all Parquet files into a single CSV file,
1294 processing one file at a time to minimize memory usage.
1295
1296 Args:
1297 output_csv (str): Path to the output CSV file specified
1298 in config.output_csv.
1299 bag_file_paths (list): List of paths to the Parquet files
1300 for each split.
1301
1302 Returns:
1303 str: Path to the output CSV file.
1304 """
1305 logging.info("Merging Parquet files into final CSV: %s", output_csv)
1306
1307 first_file = True # Flag to determine if we need to write the header
1308 total_rows_written = 0
1309
1310 # Process each Parquet file sequentially
1311 for bag_file in bag_file_paths:
1312 try:
1313 # Skip empty or invalid files
1314 if os.path.getsize(bag_file) == 0:
1315 logging.warning(
1316 "Parquet file %s is empty (zero size), skipping",
1317 bag_file
1318 )
1319 continue
1320
1321 # Load the Parquet file into a DataFrame
1322 df = pd.read_parquet(bag_file, engine="fastparquet")
1323 if df.empty:
1324 logging.warning("Parquet file %s is empty, skipping", bag_file)
1325 continue
1326
1327 logging.info("Loaded %d rows from Parquet file: %s, columns: %s",
1328 len(df), bag_file, list(df.columns))
1329
1330 # Write the DataFrame to the CSV file
1331 # - For the first file, write with header (mode='w')
1332 # - For subsequent files, append without header (mode='a')
1333 mode = 'w' if first_file else 'a'
1334 header = first_file # Write header only for the first file
1335 df.to_csv(output_csv, mode=mode, header=header, index=False)
1336 total_rows_written += len(df)
1337
1338 logging.info(
1339 "Wrote %d rows from %s to CSV, total rows written: %d",
1340 len(df), bag_file, total_rows_written
1341 )
1342
1343 # Clear memory
1344 del df
1345 gc.collect()
1346
1347 first_file = False
1348
1349 except Exception as e:
1350 logging.error("Failed to process Parquet file %s: %s", bag_file, e)
1351 continue
1352
1353 # Check if any rows were written
1354 if total_rows_written == 0:
1355 logging.error(
1356 "No valid data loaded from Parquet files, cannot create CSV"
1357 )
1358 raise ValueError("No data available to write to CSV")
1359
1360 logging.info(
1361 "Successfully wrote %d rows to final CSV: %s",
1362 total_rows_written,
1363 output_csv
1364 )
1365 return output_csv
1366
1367
1368 def process_splits(configs, embedding_files, bag_files):
1369 """Processes splits in parallel and returns all bags."""
1370 splits = [0, 1, 2] # Consistent with setup_temp_files()
1371
1372 # Filter non-empty split files
1373 valid_info = []
1374 for split in splits:
1375 temp_file = embedding_files[split]
1376 bag_file = bag_files[split]
1377 if os.path.getsize(temp_file) > 0: # Check if file has content
1378 valid_info.append((configs, bag_file, temp_file, split))
1379 else:
1380 logging.info("Skipping empty split file: %s", temp_file)
1381
1382 if not valid_info:
1383 logging.warning("No non-empty split files to process")
1384 return []
1385
1386 # Process splits in parallel and collect bag file paths
1387 bag_file_paths = []
1388 with mp.Pool(processes=mp.cpu_count()) as pool:
1389 logging.info("Starting multiprocessing")
1390 bag_file_paths = pool.starmap(processing_bag, valid_info)
1391 logging.info("Multiprocessing is done")
1392
1393 # Write the final CSV by merging the Parquet files
1394 output_file = write_final_csv(configs.output_csv, bag_file_paths)
1395 return output_file
1396
1397
1398 def cleanup_temp_files(split_files, bag_outputs):
1399 """Cleans up temporary Parquet files."""
1400 for temp_file in split_files.values():
1401 try:
1402 os.remove(temp_file)
1403 logging.info("Cleaned up temp file: %s", temp_file)
1404 except Exception as e:
1405 logging.error("Error removing %s: %s", temp_file, e)
1406 for bag_output in bag_outputs.values():
1407 try:
1408 os.remove(bag_output)
1409 logging.info("Cleaned up temp bag file: %s", bag_output)
1410 except Exception as e:
1411 logging.error("Error removing %s: %s", bag_output, e)
1412
1413
1414 if __name__ == "__main__":
1415 mp.set_start_method('spawn', force=True)
1416 logging.basicConfig(
1417 level=logging.DEBUG,
1418 format='%(asctime)s - %(levelname)s - %(message)s'
1419 )
1420
1421 parser = argparse.ArgumentParser(
1422 description="Create bags from embeddings and metadata"
1423 )
1424 parser.add_argument(
1425 "--embeddings_csv", type=str, required=True,
1426 help="Path to embeddings CSV"
1427 )
1428 parser.add_argument(
1429 "--metadata_csv", type=str, required=True,
1430 help="Path to metadata CSV"
1431 )
1432 parser.add_argument(
1433 "--split_proportions", type=str, default='0.7,0.1,0.2',
1434 help="Proportions for train, val, test splits"
1435 )
1436 parser.add_argument(
1437 "--dataleak", action="store_true",
1438 help="Prevents data leakage"
1439 )
1440 parser.add_argument(
1441 "--balance_enforced", action="store_true",
1442 help="Enforce balanced bagging"
1443 )
1444 parser.add_argument(
1445 "--bag_size", type=str, required=True,
1446 help="Bag size (e.g., '4' or '3-5')"
1447 )
1448 parser.add_argument(
1449 "--pooling_method", type=str, required=True,
1450 help="Pooling method"
1451 )
1452 parser.add_argument(
1453 "--by_sample", type=str, default=None,
1454 help="Splits to bag by sample"
1455 )
1456 parser.add_argument(
1457 "--repeats", type=int, default=1,
1458 help="Number of bagging repeats"
1459 )
1460 parser.add_argument(
1461 "--ludwig_format", action="store_true",
1462 help="Output in Ludwig format"
1463 )
1464 parser.add_argument(
1465 "--output_csv", type=str, required=True,
1466 help="Path to output CSV"
1467 )
1468 parser.add_argument(
1469 "--random_seed", type=int, default=42,
1470 help="Random seed"
1471 )
1472 parser.add_argument(
1473 "--imbalance_cap", type=int, default=None,
1474 help="Max imbalance percentage"
1475 )
1476 parser.add_argument(
1477 "--truncate_bags", action="store_true",
1478 help="Truncate bags for balance"
1479 )
1480 parser.add_argument(
1481 "--use_gpu", action="store_true",
1482 help="Use GPU for pooling"
1483 )
1484 args = parser.parse_args()
1485
1486 config = BaggingConfig(args)
1487 logging.info("Starting bagging with args: %s", config)
1488
1489 set_random_seed(config)
1490
1491 metadata_csv = load_metadata(config.metadata_csv)
1492 if config.prevent_leakage:
1493 metadata_csv = split_dataset(metadata_csv, config)
1494
1495 split_temp_files, split_bag_outputs = setup_temp_files()
1496
1497 try:
1498 logging.info("Writing embeddings to split temp Parquet files")
1499 distribute_embeddings(config, metadata_csv, split_temp_files)
1500
1501 logging.info("Processing embeddings for each split")
1502 bags = process_splits(config, split_temp_files, split_bag_outputs)
1503 logging.info("Bags processed. File generated: %s", bags)
1504
1505 finally:
1506 cleanup_temp_files(split_temp_files, split_bag_outputs)