Mercurial > repos > goeckslab > bagging_tool
view mil_bag.py @ 0:e6e9ea0703ef draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
author | goeckslab |
---|---|
date | Thu, 19 Jun 2025 23:31:55 +0000 |
parents | |
children |
line wrap: on
line source
""" A script for creating bags of instances from embeddings and metadata for Multiple Instance Learning (MIL) tasks. Processes embedding and metadata CSV files to generate bags of instances, saved as a single CSV file. Supports bagging strategies (by sample, in turns, or random), pooling methods, and options for balancing, preventing data leakage, and Ludwig formatting. Handles large datasets efficiently using temporary Parquet files, sequential processing, and multiprocessing. Dependencies: - gc: For manual garbage collection to manage memory. - argparse: For parsing command-line arguments. - logging: For logging progress and errors. - multiprocessing (mp): For parallel processing. - os: For file operations and temporary file management. - tempfile: For creating temporary files. - numpy (np): For numerical operations and array. - pandas (pd): For data manipulation and I/O (CSV, Parquet). - torch: For tensor operations (attention pooling). - torch.nn: For NN components (attention pooling). - fastparquet: For reading and writing Parquet files. Key Features: - Multiple bagging: by sample (`bag_by_sample`), in turns (`bag_in_turns`), or random (`bag_random`). - Various pooling methods (e.g., max, mean, attention). - Prevents data leakage by splitting at sample level. - Balances bags by label imbalance or truncating. - Outputs in Ludwig format (whitespace-separated vectors). - Efficient large dataset processing (temp Parquet, sequential CSV write). - GPU acceleration for certain pooling (e.g., attention). Usage: Run the script from the command line with arguments: ```bash python ludwig_mil_temp.py --embeddings_csv <path_to_embeddings.csv> --metadata_csv <path_to_metadata.csv> --bag_size <bag_size> --pooling_method <method> --output_csv <output.csv> [--split_proportions <train,val,test>] [--dataleak] [--balance_enforced] [--by_sample <splits>] [--repeats <num>] [--ludwig_format] [--random_seed <seed>] [--imbalance_cap <percentage>] [--truncate_bags] [--use_gpu] """ import argparse import gc import logging import multiprocessing as mp import os import tempfile import numpy as np import pandas as pd import torch import torch.nn as nn def parse_bag_size(bag_size_str): """Parses bag size string into a range or single value.""" try: if '-' in bag_size_str: start, end = map(int, bag_size_str.split('-')) return list(range(start, end + 1)) return [int(bag_size_str)] except ValueError: logging.error("Invalid bag_size format: %s", bag_size_str) raise def parse_by_sample(value): """Parses by_sample string into a set of split values.""" try: value = str(value) splits = [int(x) for x in value.split(",")] valid_splits = {0, 1, 2} if not all(x in valid_splits for x in splits): logging.warning("Invalid splits in by_sample: %s", splits) return None return splits except (ValueError, AttributeError): logging.warning("By_Sample not used") return None class BaggingConfig: """Configuration class for bagging parameters.""" def __init__(self, params): self.embeddings_csv = params.embeddings_csv self.metadata_csv = params.metadata_csv self.split_proportions = params.split_proportions self.prevent_leakage = params.dataleak self.balance_enforced = params.balance_enforced self.bag_size = parse_bag_size(params.bag_size) self.pooling_method = params.pooling_method self.by_sample = parse_by_sample(params.by_sample) self.repeats = params.repeats self.ludwig_format = params.ludwig_format self.output_csv = params.output_csv self.random_seed = params.random_seed self.imbalance_cap = params.imbalance_cap self.truncate_bags = params.truncate_bags self.use_gpu = params.use_gpu def __str__(self): """String representation of the config for logging.""" return ( f"embeddings_csv={self.embeddings_csv}, " f"metadata_csv={self.metadata_csv}, " f"split_proportions={self.split_proportions}, " f"prevent_leakage={self.prevent_leakage}, " f"balance_enforced={self.balance_enforced}, " f"bag_size={self.bag_size}, " f"pooling_method={self.pooling_method}, " f"by_sample={self.by_sample}, " f"repeats={self.repeats}, " f"ludwig_format={self.ludwig_format}, " f"output_csv={self.output_csv}, " f"random_seed={self.random_seed}, " f"imbalance_cap={self.imbalance_cap}, " f"truncate_bags={self.truncate_bags}, " f"use_gpu={self.use_gpu}" ) def set_random_seed(configs): """Sets random seeds for reproducibility.""" np.random.seed(configs.random_seed) torch.manual_seed(configs.random_seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(configs.random_seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False logging.info("Random seed set to %d", configs.random_seed) def validate_metadata(metadata): """Validates metadata for required columns.""" required_cols = {"sample_name", "label"} if not required_cols.issubset(metadata.columns): missing = required_cols - set(metadata.columns) raise ValueError(f"Metadata missing columns: {missing}") return metadata def load_metadata(file_path): """Loads metadata from a CSV file.""" metadata = pd.read_csv(file_path) validate_metadata(metadata) logging.info("Metadata loaded with %d samples, cols: %s", len(metadata), list(metadata.columns)) logging.info("Unique samples: %d, labels: %d", metadata["sample_name"].nunique(), metadata["label"].nunique()) return metadata def convert_proportions(proportion_string): """Converts a string of split proportions into a list of floats.""" proportion_list = [float(p) for p in proportion_string.split(",")] print(proportion_list) if len(proportion_list) == 2: proportion_list = [proportion_list[0], 0.0, proportion_list[1]] for proportion in proportion_list: if proportion < 0 or proportion > 1: raise ValueError("Each proportion must be between 0 and 1") if abs(sum(proportion_list) - 1.0) > 1e-6: raise ValueError("Proportions must sum to approximately 1.0") return proportion_list def calculate_split_counts(total_samples, proportions): """Calculates sample counts for each split.""" counts = [int(p * total_samples) for p in proportions] calculated_total = sum(counts) if calculated_total < total_samples: counts[-1] += total_samples - calculated_total elif calculated_total > total_samples: counts[0] -= calculated_total - total_samples return counts def assign_split_labels(proportions, sample_count): """Assigns split labels based on proportions.""" proportion_values = convert_proportions(proportions) train_fraction, val_fraction, test_fraction = proportion_values if val_fraction == 0 and test_fraction == 0: labels = np.zeros(sample_count, dtype=int) elif val_fraction == 0: train_size = int(train_fraction * sample_count) test_size = sample_count - train_size labels = np.array([0] * train_size + [2] * test_size) else: split_counts = calculate_split_counts(sample_count, proportion_values) labels = np.concatenate([ np.zeros(split_counts[0], dtype=int), np.ones(split_counts[1], dtype=int), 2 * np.ones(split_counts[2], dtype=int) ]) return labels def split_dataset(metadata, configs): """Splits dataset into train, val, test sets if prevent_leakage is True.""" if configs.prevent_leakage: logging.info("No data leakage allowed") unique_samples = metadata["sample_name"].unique() sample_count = len(unique_samples) split_labels = assign_split_labels(configs.split_proportions, sample_count) shuffled_samples = np.random.permutation(unique_samples) label_series = pd.Series(split_labels, index=shuffled_samples) metadata["split"] = metadata["sample_name"].map(label_series) train_count = (metadata["split"] == 0).sum() val_count = (metadata["split"] == 1).sum() test_count = (metadata["split"] == 2).sum() logging.info("Dataset split: train %d, val %d, test %d", train_count, val_count, test_count) else: logging.info("Data leakage allowed setup") return metadata def assign_chunk_splits(chunk, split_counts, current_counts): """Assigns split labels to a chunk of embeddings.""" chunk_size = len(chunk) remaining = { 0: split_counts[0] - current_counts[0], 1: split_counts[1] - current_counts[1], 2: split_counts[2] - current_counts[2] } available_splits = [s for s, count in remaining.items() if count > 0] if not available_splits: return chunk, current_counts total_remaining = sum(remaining.values()) assign_count = min(chunk_size, total_remaining) if assign_count == 0: return chunk, current_counts weights = [remaining[s] / total_remaining for s in available_splits] splits = np.random.choice(available_splits, size=assign_count, p=weights) chunk["split"] = pd.Series(splits, index=chunk.index[:assign_count]) chunk["split"] = chunk["split"].fillna(0).astype(int) for split in available_splits: current_counts[split] += np.sum(splits == split) return chunk, current_counts def setup_temp_files(): """Sets up temporary Parquet files for splits and bag outputs.""" splits = [0, 1, 2] split_files = {} for split in splits: fd, path = tempfile.mkstemp(prefix=f"split_{split}_", suffix=".parquet", dir=os.getcwd()) os.close(fd) # Explicitly close the file descriptor split_files[split] = path bag_outputs = {} for split in splits: fd, path = tempfile.mkstemp(prefix=f"MIL_bags_{split}_", suffix=".parquet", dir=os.getcwd()) os.close(fd) # Explicitly close the file descriptor bag_outputs[split] = path return split_files, bag_outputs def distribute_embeddings(configs, metadata, split_files): embeddings_path = configs.embeddings_csv proportion_string = configs.split_proportions prevent_leakage = configs.prevent_leakage logging.info("Distributing embeddings from %s to Parquet files", embeddings_path) buffer_size = 50000 merged_header = None non_sample_columns = None if not prevent_leakage: logging.warning( "Counting rows in %s; may be slow for large files", embeddings_path ) total_rows = sum(1 for _ in open(embeddings_path)) - 1 proportions = convert_proportions(proportion_string) split_counts = calculate_split_counts(total_rows, proportions) current_counts = {0: 0, 1: 0, 2: 0} else: sample_to_split = dict(zip(metadata["sample_name"], metadata["split"])) sample_to_label = dict(zip(metadata["sample_name"], metadata["label"])) first_write = {split: True for split in split_files} try: first_header_read = True for chunk in pd.read_csv(embeddings_path, chunksize=buffer_size): # Modify 'sample_name' to remove part after the last underscore chunk['sample_name'] = chunk['sample_name'].apply(lambda x: x.rsplit('_', 1)[0]) if first_header_read: orig_header = list(chunk.columns) non_sample_columns = [ col for col in orig_header if col != "sample_name" ] merged_header = ["sample_name", "label"] + non_sample_columns logging.info("Merged header: %s", merged_header) first_header_read = False if prevent_leakage: chunk["split"] = chunk["sample_name"].map(sample_to_split) chunk["label"] = chunk["sample_name"].map(sample_to_label) else: chunk, current_counts = assign_chunk_splits(chunk, split_counts, current_counts) chunk = chunk.merge(metadata[["sample_name", "label"]], on="sample_name", how="left") chunk = chunk.dropna(subset=["split", "label"]) for split in split_files: split_chunk = chunk[chunk["split"] == split] if not split_chunk.empty: temp_file = split_files[split] split_chunk[merged_header].to_parquet( temp_file, engine="fastparquet", append=not first_write[split], index=False ) first_write[split] = False del chunk gc.collect() except Exception as e: logging.error("Error distributing embeddings to Parquet: %s", e) raise def aggregate_embeddings(embeddings, pooling_method, use_gpu=False): # Convert embeddings to a float32 array explicitly. embeddings = np.asarray(embeddings, dtype=np.float32) if embeddings.ndim == 1: embeddings = embeddings.reshape(1, -1) elif embeddings.ndim == 0: embeddings = embeddings.reshape(1, 1) logging.debug("Aggregating embeddings with shape: %s", embeddings.shape) if pooling_method == "max_pooling": result = np.max(embeddings, axis=0) elif pooling_method == "mean_pooling": result = np.mean(embeddings, axis=0) elif pooling_method == "sum_pooling": result = np.sum(embeddings, axis=0) elif pooling_method == "min_pooling": result = np.min(embeddings, axis=0) elif pooling_method == "median_pooling": result = np.median(embeddings, axis=0) elif pooling_method == "l2_norm_pooling": norm = np.linalg.norm(embeddings, axis=1, keepdims=True) if norm.any(): result = np.mean(embeddings / (norm + 1e-8), axis=0) else: result = np.mean(embeddings, axis=0) elif pooling_method == "geometric_mean_pooling": clipped = np.clip(embeddings, 1e-10, None) result = np.exp(np.mean(np.log(clipped), axis=0)) elif pooling_method == "first_embedding": result = embeddings[0] elif pooling_method == "last_embedding": result = embeddings[-1] elif pooling_method == "attention_pooling": device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu' tensor = torch.tensor(embeddings, dtype=torch.float32).to(device) with torch.no_grad(): linear = nn.Linear(tensor.shape[1], 1).to(device) weights = nn.Softmax(dim=0)(linear(tensor)) result = torch.sum(weights * tensor, dim=0).cpu().detach().numpy() else: raise ValueError(f"Unknown pooling method: {pooling_method}") logging.debug("Aggregated embedding shape: %s", result.shape) return result def bag_by_sample(df, split, bag_file, config, batch_size=1000, fixed_target_bags=None): """ Processes the provided DataFrame by grouping rows by sample, constructs bags from each sample group using the configured bag_size, and writes the bag rows directly to bag_file (a Parquet file) in batches. Args: df (pd.DataFrame): The DataFrame containing the data. split (str): The split identifier (e.g., 'train', 'val'). bag_file (str): The path to the Parquet file to write the bags. config (object): Configuration object with bag_size, pooling_method... batch_size (int, optional): The number of rows to write in each batch. fixed_target_bags (tuple, optional): (target_label, num_bags) to generate bags only for target_label. Output row format: sample_name, bag_label, split, bag_size, vector_0, vector_1, vector_N """ log_msg = f"Processing by sample for split: {split}" if fixed_target_bags: log_msg += f" with fixed target {fixed_target_bags}" logging.info(log_msg) batch_rows = [] bag_count = 0 vector_columns = [ col for col in df.columns if col not in ["sample_name", "label", "split"] ] if fixed_target_bags is not None: target_label, target_needed = fixed_target_bags target_samples = list( df[df["label"] == target_label]["sample_name"].unique() ) df = df[df["sample_name"].isin(target_samples)] if df.empty: logging.warning( "No samples available for target label %d in split %s", target_label, split ) return available_samples = target_samples.copy() np.random.shuffle(available_samples) while bag_count < target_needed: if len(available_samples) == 0: available_samples = target_samples.copy() np.random.shuffle(available_samples) logging.info( "Reusing samples for target label %d in split %s", target_label, split ) sample_name = available_samples.pop() group = df[df["sample_name"] == sample_name] embeddings = group[vector_columns].values num_instances = len(group) current_bag_size = config.bag_size[0] \ if len(config.bag_size) == 1 else \ np.random.randint(config.bag_size[0], config.bag_size[1] + 1) current_bag_size = min(current_bag_size, num_instances) selected = group.sample(n=current_bag_size, replace=True) bag_embeddings = selected[vector_columns].values aggregated_embedding = aggregate_embeddings( bag_embeddings, config.pooling_method, config.use_gpu ) bag_label = int(any(selected["label"] == 1)) if bag_label != target_label: logging.warning( "Generated bag for target %d but got label %d", target_label, bag_label ) continue row = { "sample_name": sample_name, "bag_label": bag_label, "split": split, "bag_size": current_bag_size } for j, val in enumerate(aggregated_embedding): row[f"vector_{j}"] = val batch_rows.append(row) bag_count += 1 if len(batch_rows) >= batch_size: df_batch = pd.DataFrame(batch_rows) # Check if the file has data to determine append mode append_mode = os.path.getsize(bag_file) > 0 df_batch.to_parquet( bag_file, engine="fastparquet", append=append_mode, index=False ) logging.debug( "Fixed mode: Wrote batch of %d rows to %s", len(batch_rows), bag_file ) batch_rows = [] del df_batch gc.collect() else: # Standard mode: process all samples groups = df.groupby("sample_name") for sample_name, group in groups: embeddings = group[vector_columns].values labels = group["label"].values num_instances = len(group) current_bag_size = config.bag_size[0] \ if len(config.bag_size) == 1 else \ np.random.randint( config.bag_size[0], config.bag_size[1] + 1 ) num_bags = ( num_instances + current_bag_size - 1 ) // current_bag_size logging.info( "Sample %s: %d instances, creating %d bags (bag size %d)", sample_name, num_instances, num_bags, current_bag_size ) for i in range(num_bags): start_idx = i * current_bag_size end_idx = min(start_idx + current_bag_size, num_instances) bag_embeddings = embeddings[start_idx:end_idx] bag_labels = labels[start_idx:end_idx] aggregated_embedding = aggregate_embeddings( bag_embeddings, config.pooling_method, config.use_gpu ) bag_label = int(any(bag_labels == 1)) row = { "sample_name": sample_name, "bag_label": bag_label, "split": split, "bag_size": end_idx - start_idx } for j, val in enumerate(aggregated_embedding): row[f"vector_{j}"] = val batch_rows.append(row) bag_count += 1 if len(batch_rows) >= batch_size: df_batch = pd.DataFrame(batch_rows) # Check if the file has data to determine append mode append_mode = os.path.getsize(bag_file) > 0 df_batch.to_parquet( bag_file, engine="fastparquet", append=append_mode, index=False ) logging.debug( "Wrote batch of %d rows to %s", len(batch_rows), bag_file ) batch_rows = [] del df_batch gc.collect() # Write any remaining rows if batch_rows: df_batch = pd.DataFrame(batch_rows) append_mode = os.path.getsize(bag_file) > 0 df_batch.to_parquet( bag_file, engine="fastparquet", append=append_mode, index=False ) logging.debug( "Wrote final batch of %d rows to %s", len(batch_rows), bag_file ) del df_batch gc.collect() logging.info("Created %d bags for split: %s", bag_count, split) def bag_in_turns(df, split, bag_file, config, batch_size=500, fixed_target_bags=None, allow_reuse=True): """ Generate bags of instances from a DataFrame, with optional fixed-target mode, data reuse, and enhanced diversity. Parameters: - df (pd.DataFrame): Input DataFrame with columns including 'sample_name', 'label', 'split', and embedding vectors. - split (str): Dataset split (e.g., 'train', 'test'). - bag_file (str): Path to save the output Parquet file. - config (object): Configuration object with attributes 'bag_size', 'pooling_method', and 'use_gpu'. - batch_size (int): Number of bags to process before writing to file (default: 500). - fixed_target_bags (tuple): Optional (label, num_bags) to generate bags for a specific label (e.g., (0, 100)). - allow_reuse (bool): Allow resampling instances with replacement if True (default: True). Returns: - None: Saves bags to the specified Parquet file. """ logging.info( "Processing bag in turns for split %s%s", split, (" with fixed target " + str(fixed_target_bags)) if fixed_target_bags is not None else "" ) # Identify embedding columns (exclude non-vector columns). vector_columns = [ col for col in df.columns if col not in ["sample_name", "label", "split"] ] # Convert the DataFrame to a NumPy array for faster processing. df_np = df.to_numpy() # Determine bag size range from config. if len(config.bag_size) == 1: bag_min = bag_max = config.bag_size[0] else: bag_min, bag_max = config.bag_size batch_rows = [] bag_count = 0 if fixed_target_bags is not None: # Fixed-target mode: generate bags for a specific label. target, target_needed = fixed_target_bags # e.g., (0, 100) if target == 0: # Optimize for target label 0: remove all label 1 instances indices = np.where(df_np[:, 1] == 0)[0] logging.info( "Fixed mode: target label 0, using only label 0 instances, \ total available %d rows", len(indices) ) else: # For target label 1, use all instances to allow mixing indices = np.arange(len(df_np)) logging.info( "Fixed mode: target label 1, using all instances, \ total available %d rows", len(indices) ) total_available = len(indices) while bag_count < target_needed: current_bag_size = np.random.randint(bag_min, bag_max + 1) \ if bag_min != bag_max else bag_min if total_available < current_bag_size and not allow_reuse: logging.warning( "Not enough instances (%d) for bag size %d and \ target label %d", total_available, current_bag_size, target ) break # Sample instances selected = np.random.choice( indices, size=current_bag_size, replace=allow_reuse ) bag_data = df_np[selected] if target == 1: # For positive bags, ensure at least one instance has label 1 if not np.any(bag_data[:, 1] == 1): continue # Skip if no positive instance bag_label = 1 else: # For negative bags, all instances are label 0 due to filtering bag_label = 0 # Aggregate embeddings. vec_col_indices = [ df.columns.get_loc(col) for col in vector_columns ] embeddings = bag_data[:, vec_col_indices].astype(np.float32) aggregated_embedding = aggregate_embeddings( embeddings, config.pooling_method, config.use_gpu ) # Set bag metadata. bsize = bag_data.shape[0] samples = np.unique(bag_data[:, 0]) merged_sample_name = ",".join(map(str, samples)) # Create row for the bag. row = { "sample_name": merged_sample_name, "bag_label": bag_label, "split": split, "bag_size": bsize } for j, val in enumerate(aggregated_embedding): row[f"vector_{j}"] = val batch_rows.append(row) bag_count += 1 if len(batch_rows) >= batch_size: df_batch = pd.DataFrame(batch_rows) df_batch.to_parquet( bag_file, engine="fastparquet", append=True, index=False ) logging.debug( "Fixed mode: Wrote a batch of %d rows to %s", len(batch_rows), bag_file ) batch_rows = [] del df_batch gc.collect() # Write any remaining rows. if batch_rows: df_batch = pd.DataFrame(batch_rows) df_batch.to_parquet( bag_file, engine="fastparquet", append=True, index=False ) logging.debug( "Wrote the final batch of %d rows to %s", len(batch_rows), bag_file ) del df_batch gc.collect() logging.info("Created %d bags for split: %s", bag_count, split) else: # Alternating mode: alternate between labels 0 and 1. indices_0 = np.where(df_np[:, 1] == 0)[0] indices_1 = np.where(df_np[:, 1] == 1)[0] np.random.shuffle(indices_0) np.random.shuffle(indices_1) turn = 0 # 0: label 0, 1: label 1. while len(indices_0) > 0 or len(indices_1) > 0: current_bag_size = np.random.randint(bag_min, bag_max + 1) \ if bag_min != bag_max else bag_min if turn == 0: if len(indices_0) > 0: num_to_select = min(current_bag_size, len(indices_0)) selected = indices_0[:num_to_select] indices_0 = indices_0[num_to_select:] else: if len(indices_1) == 0: break num_to_select = min(current_bag_size, len(indices_1)) selected = indices_1[:num_to_select] indices_1 = indices_1[num_to_select:] else: if len(indices_1) > 0: num_to_select = min(current_bag_size, len(indices_1)) selected = indices_1[:num_to_select] indices_1 = indices_1[num_to_select:] else: if len(indices_0) == 0: break num_to_select = min(current_bag_size, len(indices_0)) selected = indices_0[:num_to_select] indices_0 = indices_0[num_to_select:] bag_data = df_np[selected] if bag_data.shape[0] == 0: break # Aggregate embeddings. vec_col_indices = [ df.columns.get_loc(col) for col in vector_columns ] embeddings = bag_data[:, vec_col_indices].astype(np.float32) aggregated_embedding = aggregate_embeddings( embeddings, config.pooling_method, config.use_gpu ) # Set bag label and metadata. bag_label = int(np.any(bag_data[:, 1] == 1)) bsize = bag_data.shape[0] samples = np.unique(bag_data[:, 0]) merged_sample_name = ",".join(map(str, samples)) # Create row for the bag. row = { "sample_name": merged_sample_name, "bag_label": bag_label, "split": split, "bag_size": bsize } for j, val in enumerate(aggregated_embedding): row[f"vector_{j}"] = val batch_rows.append(row) bag_count += 1 turn = 1 - turn # Write batch to file if batch_size is reached. if len(batch_rows) >= batch_size: df_batch = pd.DataFrame(batch_rows) df_batch.to_parquet( bag_file, engine="fastparquet", append=(bag_count > len(batch_rows)), index=False ) logging.debug( "Alternating mode: Wrote a batch of %d rows to %s", len(batch_rows), bag_file ) batch_rows = [] del df_batch gc.collect() # Write any remaining rows. if batch_rows: df_batch = pd.DataFrame(batch_rows) df_batch.to_parquet( bag_file, engine="fastparquet", append=(bag_count > len(batch_rows)), index=False ) logging.debug( "Wrote the final batch of %d rows to %s", len(batch_rows), bag_file ) del df_batch gc.collect() logging.info("Created %d bags for split: %s", bag_count, split) def bag_random(df, split, bag_file, configs, batch_size=500): """ Processes the provided DataFrame by randomly selecting instances to create bags. """ logging.info("Processing bag randomly for split %s", split) # Identify vector columns (exclude non-vector columns). vector_columns = [ col for col in df.columns if col not in ["sample_name", "label", "split"] ] df_np = df.to_numpy() # Create an array of all row indices and shuffle them. indices = np.arange(df.shape[0]) np.random.shuffle(indices) bag_count = 0 batch_rows = [] # Determine bag size parameters. if len(configs.bag_size) == 1: bag_min = bag_max = configs.bag_size[0] else: bag_min, bag_max = configs.bag_size pos = 0 total_rows = len(indices) # Process until all indices have been used. while pos < total_rows: # Ensuring we do not exceed remaining rows. current_bag_size = (np.random.randint(bag_min, bag_max + 1) if bag_min != bag_max else bag_min) current_bag_size = min(current_bag_size, total_rows - pos) # Select the indices for this bag. selected = indices[pos: pos + current_bag_size] pos += current_bag_size # Extract the bag data. bag_data = df_np[selected] if bag_data.shape[0] == 0: break # Identify the positions of the vector columns using the column names. vec_col_indices = [df.columns.get_loc(col) for col in vector_columns] embeddings = bag_data[:, vec_col_indices].astype(np.float32) aggregated_embedding = aggregate_embeddings( embeddings, configs.pooling_method, configs.use_gpu ) # Determine bag_label: 1 if any instance in this bag has label == 1. bag_label = int(np.any(bag_data[:, 1] == 1)) # Merge all sample names from the bag (unique names, comma-separated). samples = np.unique(bag_data[:, 0]) merged_sample_name = ",".join(map(str, samples)) # Use the provided split value. bag_split = split bsize = bag_data.shape[0] # Build the output row with header fields: # sample_name, bag_label, split, bag_size, then embeddings. row = { "sample_name": merged_sample_name, "bag_label": bag_label, "split": bag_split, "bag_size": bsize } for j, val in enumerate(aggregated_embedding): row[f"vector_{j}"] = val batch_rows.append(row) bag_count += 1 # Write out rows in batches. if len(batch_rows) >= batch_size: df_batch = pd.DataFrame(batch_rows) # For the first batch, # append=False (header written), # then append=True on subsequent batches. df_batch.to_parquet( bag_file, engine="fastparquet", append=(bag_count > len(batch_rows)), index=False ) logging.debug( "Wrote a batch of %d rows to %s", len(batch_rows), bag_file ) batch_rows = [] del df_batch gc.collect() # Write any remaining rows. if batch_rows: df_batch = pd.DataFrame(batch_rows) df_batch.to_parquet( bag_file, engine="fastparquet", append=(bag_count > len(batch_rows)), index=False ) logging.debug( "Wrote the final batch of %d rows to %s", len(batch_rows), bag_file ) del df_batch gc.collect() logging.info("Created %d bags for split: %s", bag_count, split) def imbalance_adjustment(bag_file, split, configs, df): """ Verifies if the number of bags per label in bag_file is within imbalance_cap. If not, generates additional bags for the minority label. Args: bag_file (str): Path to the Parquet file containing bags. split (str): The current split (e.g., 'train', 'val'). config (object): Configuration with imbalance_cap, by_sample, etc. df (pd.DataFrame): Original DataFrame for generating additional bags. """ # Read the bag file and count bags per label bags_df = pd.read_parquet(bag_file) n0 = (bags_df["bag_label"] == 0).sum() n1 = (bags_df["bag_label"] == 1).sum() total = n0 + n1 if total == 0: logging.warning("No bags found in %s for split %s", bag_file, split) return # Calculate imbalance as a percentage imbalance = abs(n0 - n1) / total * 100 logging.info( "Split %s: %d bags (label 0: %d, label 1: %d), imbalance %.2f%%", split, total, n0, n1, imbalance ) if imbalance > configs.imbalance_cap: # Identify minority label min_label = 0 if n0 < n1 else 1 n_min = n0 if min_label == 0 else n1 n_maj = n1 if min_label == 0 else n0 # Calculate how many bags are needed to balance (aim for equality) num_needed = n_maj - n_min logging.info( "Imbalance %.2f%% exceeds cap %.2f%% in split %s, \ need %d bags for label %d", imbalance, configs.imbalance_cap, split, num_needed, min_label ) # Generate additional bags based on the bag creation method if split in configs.by_sample: bag_by_sample( df, split, bag_file, configs, fixed_target_bags=(min_label, num_needed) ) else: bag_in_turns( df, split, bag_file, configs, fixed_target_bags=(min_label, num_needed) ) # Verify the new balance (optional, for logging) updated_bags_df = pd.read_parquet(bag_file) new_n0 = (updated_bags_df["bag_label"] == 0).sum() new_n1 = (updated_bags_df["bag_label"] == 1).sum() new_total = new_n0 + new_n1 new_imbalance = abs(new_n0 - new_n1) / new_total * 100 logging.info( "After adjustment, split %s: %d bags (label 0: %d, label 1: %d), \ imbalance %.2f%%", split, new_total, new_n0, new_n1, new_imbalance ) else: logging.info( "Imbalance %.2f%% within cap %.2f%% for split %s, \ no adjustment needed", imbalance, configs.imbalance_cap, split ) def truncate_bag(bag_file, split): """ Truncates the bags in the bag_file to balance the counts of label 0 and label 1, ensuring that the file is never left empty (at least one bag remains). Args: bag_file (str): Path to the Parquet file containing the bags. split (str): The current split (e.g., 'train', 'val') for logging purposes. Returns: None: Overwrites the bag_file with the truncated bags, ensuring at least one bag remains. """ logging.info("Truncating bags for split %s in file: %s", split, bag_file) # Step 1: Read the bag file to get the total number of bags try: bags_df = pd.read_parquet(bag_file) except Exception as e: logging.error("Failed to read bag file %s: %s", bag_file, e) return total_bags = len(bags_df) if total_bags == 0: logging.warning("No bags found in %s for split %s", bag_file, split) return # Step 2: Count bags with label 0 and label 1 n0 = (bags_df["bag_label"] == 0).sum() n1 = (bags_df["bag_label"] == 1).sum() logging.info( "Split %s: Total bags %d (label 0: %d, label 1: %d)", split, total_bags, n0, n1 ) # Determine the minority count and majority label min_count = min(n0, n1) majority_label = 0 if n0 > n1 else 1 if n0 == n1: logging.info( "Bags already balanced for split %s, no truncation needed", split ) return # Step 3: Adjust min_count to ensure at least one bag remains if min_count == 0: logging.warning( "Minority label has 0 bags in split %s, keeping 1 bag from \ majority label %d to avoid empty file", split, majority_label ) min_count = 1 # Ensure at least one bag is kept # Step 4: Truncate excess bags from the majority label logging.info( "Truncating %d bags from label %d to match %d bags per label", max(0, (n0 if majority_label == 0 else n1) - min_count), majority_label, min_count ) # Shuffle the majority label bags to randomly select which to keep majority_bags = bags_df[ bags_df["bag_label"] == majority_label ].sample(frac=1, random_state=None) minority_bags = bags_df[bags_df["bag_label"] != majority_label] # Keep only min_count bags from the majority label majority_bags_truncated = majority_bags.iloc[:min_count] # Combine the truncated majority and minority bags truncated_bags_df = pd.concat( [majority_bags_truncated, minority_bags], ignore_index=True ) # Verify that the resulting DataFrame is not empty if len(truncated_bags_df) == 0: logging.error( "Unexpected empty DataFrame after truncation for split %s, \ this should not happen", split ) return # Step 5: Overwrite the bag file with the truncated bags try: truncated_bags_df.to_parquet( bag_file, engine="fastparquet", index=False ) logging.info( "Overwrote %s with %d balanced bags (label 0: %d, label 1: %d)", bag_file, len(truncated_bags_df), (truncated_bags_df["bag_label"] == 0).sum(), (truncated_bags_df["bag_label"] == 1).sum() ) except Exception as e: logging.error("Failed to overwrite bag file %s: %s", bag_file, e) def columns_into_string(bag_file): """ Reads the bag file (Parquet) from the given path, identifies the vector columns (i.e., columns not among 'sample_name', 'bag_label', 'split', and 'bag_size'), concatenates these vector values (as strings) into a single whitespace-separated string wrapped in double quotes, stored in a new column "embeddings", drops the individual vector columns, and writes the modified DataFrame back to the same Parquet file. The final output format is: "sample_name", "bag_label", "split", "bag_size", "embeddings" where "embeddings" is a string like: "0.1 0.2 0.3" """ logging.info( "Converting vector columns into string for bag file: %s", bag_file ) try: df = pd.read_parquet(bag_file, engine="fastparquet") except Exception as e: logging.error("Error reading bag file %s: %s", bag_file, e) return # Define non-vector columns. non_vector = ["sample_name", "bag_label", "split", "bag_size"] # Identify vector columns. vector_columns = [col for col in df.columns if col not in non_vector] logging.info("Identified vector columns: %s", vector_columns) # Create new 'embeddings' column by converting vector columns to str, # joining them with whitespace, and wrapping the result in double quotes. # Use apply() to ensure the result is a Series with one string per row. df["embeddings"] = df[vector_columns].astype(str).apply( lambda x: " ".join(x), axis=1 ) # Drop the original vector columns. df.drop(columns=vector_columns, inplace=True) try: # Write the modified DataFrame back to the same bag file. df.to_parquet(bag_file, engine="fastparquet", index=False) logging.info( "Conversion complete. Final columns: %s", df.columns.tolist() ) except Exception as e: logging.error("Error writing updated bag file %s: %s", bag_file, e) def processing_bag(configs, bag_file, temp_file, split): """ Processes a single split and writes bag results directly to the bag output Parquet file. """ logging.info("Processing split %s using file: %s", split, temp_file) df = pd.read_parquet(temp_file, engine="fastparquet") if configs.by_sample is not None and split in configs.by_sample: bag_by_sample(df, split, bag_file, configs) elif configs.balance_enforced: bag_in_turns(df, split, bag_file, configs) else: bag_random(df, split, bag_file, configs) # Free df if imbalance_adjustment is not needed if configs.imbalance_cap is None: del df gc.collect() if configs.imbalance_cap is not None: imbalance_adjustment(bag_file, split, configs, df) del df gc.collect() elif configs.truncate_bags: truncate_bag(bag_file, split) if configs.ludwig_format: columns_into_string(bag_file) return bag_file def write_final_csv(output_csv, bag_file_paths): """ Merges all Parquet files into a single CSV file, processing one file at a time to minimize memory usage. Args: output_csv (str): Path to the output CSV file specified in config.output_csv. bag_file_paths (list): List of paths to the Parquet files for each split. Returns: str: Path to the output CSV file. """ logging.info("Merging Parquet files into final CSV: %s", output_csv) first_file = True # Flag to determine if we need to write the header total_rows_written = 0 # Process each Parquet file sequentially for bag_file in bag_file_paths: try: # Skip empty or invalid files if os.path.getsize(bag_file) == 0: logging.warning( "Parquet file %s is empty (zero size), skipping", bag_file ) continue # Load the Parquet file into a DataFrame df = pd.read_parquet(bag_file, engine="fastparquet") if df.empty: logging.warning("Parquet file %s is empty, skipping", bag_file) continue logging.info("Loaded %d rows from Parquet file: %s, columns: %s", len(df), bag_file, list(df.columns)) # Write the DataFrame to the CSV file # - For the first file, write with header (mode='w') # - For subsequent files, append without header (mode='a') mode = 'w' if first_file else 'a' header = first_file # Write header only for the first file df.to_csv(output_csv, mode=mode, header=header, index=False) total_rows_written += len(df) logging.info( "Wrote %d rows from %s to CSV, total rows written: %d", len(df), bag_file, total_rows_written ) # Clear memory del df gc.collect() first_file = False except Exception as e: logging.error("Failed to process Parquet file %s: %s", bag_file, e) continue # Check if any rows were written if total_rows_written == 0: logging.error( "No valid data loaded from Parquet files, cannot create CSV" ) raise ValueError("No data available to write to CSV") logging.info( "Successfully wrote %d rows to final CSV: %s", total_rows_written, output_csv ) return output_csv def process_splits(configs, embedding_files, bag_files): """Processes splits in parallel and returns all bags.""" splits = [0, 1, 2] # Consistent with setup_temp_files() # Filter non-empty split files valid_info = [] for split in splits: temp_file = embedding_files[split] bag_file = bag_files[split] if os.path.getsize(temp_file) > 0: # Check if file has content valid_info.append((configs, bag_file, temp_file, split)) else: logging.info("Skipping empty split file: %s", temp_file) if not valid_info: logging.warning("No non-empty split files to process") return [] # Process splits in parallel and collect bag file paths bag_file_paths = [] with mp.Pool(processes=mp.cpu_count()) as pool: logging.info("Starting multiprocessing") bag_file_paths = pool.starmap(processing_bag, valid_info) logging.info("Multiprocessing is done") # Write the final CSV by merging the Parquet files output_file = write_final_csv(configs.output_csv, bag_file_paths) return output_file def cleanup_temp_files(split_files, bag_outputs): """Cleans up temporary Parquet files.""" for temp_file in split_files.values(): try: os.remove(temp_file) logging.info("Cleaned up temp file: %s", temp_file) except Exception as e: logging.error("Error removing %s: %s", temp_file, e) for bag_output in bag_outputs.values(): try: os.remove(bag_output) logging.info("Cleaned up temp bag file: %s", bag_output) except Exception as e: logging.error("Error removing %s: %s", bag_output, e) if __name__ == "__main__": mp.set_start_method('spawn', force=True) logging.basicConfig( level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s' ) parser = argparse.ArgumentParser( description="Create bags from embeddings and metadata" ) parser.add_argument( "--embeddings_csv", type=str, required=True, help="Path to embeddings CSV" ) parser.add_argument( "--metadata_csv", type=str, required=True, help="Path to metadata CSV" ) parser.add_argument( "--split_proportions", type=str, default='0.7,0.1,0.2', help="Proportions for train, val, test splits" ) parser.add_argument( "--dataleak", action="store_true", help="Prevents data leakage" ) parser.add_argument( "--balance_enforced", action="store_true", help="Enforce balanced bagging" ) parser.add_argument( "--bag_size", type=str, required=True, help="Bag size (e.g., '4' or '3-5')" ) parser.add_argument( "--pooling_method", type=str, required=True, help="Pooling method" ) parser.add_argument( "--by_sample", type=str, default=None, help="Splits to bag by sample" ) parser.add_argument( "--repeats", type=int, default=1, help="Number of bagging repeats" ) parser.add_argument( "--ludwig_format", action="store_true", help="Output in Ludwig format" ) parser.add_argument( "--output_csv", type=str, required=True, help="Path to output CSV" ) parser.add_argument( "--random_seed", type=int, default=42, help="Random seed" ) parser.add_argument( "--imbalance_cap", type=int, default=None, help="Max imbalance percentage" ) parser.add_argument( "--truncate_bags", action="store_true", help="Truncate bags for balance" ) parser.add_argument( "--use_gpu", action="store_true", help="Use GPU for pooling" ) args = parser.parse_args() config = BaggingConfig(args) logging.info("Starting bagging with args: %s", config) set_random_seed(config) metadata_csv = load_metadata(config.metadata_csv) if config.prevent_leakage: metadata_csv = split_dataset(metadata_csv, config) split_temp_files, split_bag_outputs = setup_temp_files() try: logging.info("Writing embeddings to split temp Parquet files") distribute_embeddings(config, metadata_csv, split_temp_files) logging.info("Processing embeddings for each split") bags = process_splits(config, split_temp_files, split_bag_outputs) logging.info("Bags processed. File generated: %s", bags) finally: cleanup_temp_files(split_temp_files, split_bag_outputs)