# goeckslab/bagging_tool: mil_bag.py @ 0:e6e9ea0703ef (draft, default, tip)
# Changeset: planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
# Author: goeckslab
# Date: Thu, 19 Jun 2025 23:31:55 +0000
"""
A script for creating bags of instances from embeddings
and metadata for Multiple Instance Learning (MIL) tasks.

Processes embedding and metadata CSV files to generate
bags of instances, saved as a single CSV file. Supports
several bagging strategies (by sample, in turns, or random),
pooling methods, and options for balancing, preventing
data leakage, and Ludwig formatting. Handles large
datasets efficiently using temporary Parquet files,
sequential CSV writing, and multiprocessing.

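A minimal sketch of two of these strategies follows, assuming a
plain list of per-instance sample names and a fixed bag size. The
helper names are hypothetical and this is not the implementation in
this module: a shorter trailing bag is simply kept, and "in turns"
bagging is not sketched because its exact rule is not described here.

```python
import numpy as np


def chunk_into_bags(indices, bag_size):
    # Split an index array into consecutive bags of bag_size instances;
    # a shorter trailing bag is kept (an assumption of this sketch).
    return [indices[i:i + bag_size] for i in range(0, len(indices), bag_size)]


def bag_by_sample_sketch(sample_names, bag_size):
    # Hypothetical "by sample" bagging: each bag holds instances from
    # exactly one sample, so a bag can inherit that sample's label.
    names = np.asarray(sample_names)
    bags = []
    for name in sorted(set(sample_names)):
        bags.extend(chunk_into_bags(np.flatnonzero(names == name), bag_size))
    return bags


def bag_random_sketch(n_instances, bag_size, seed=0):
    # Hypothetical "random" bagging: shuffle all instance indices and
    # cut them into bags, ignoring sample boundaries.
    rng = np.random.default_rng(seed)
    return chunk_into_bags(rng.permutation(n_instances), bag_size)


samples = ["s1", "s1", "s1", "s2", "s2"]
print([bag.tolist() for bag in bag_by_sample_sketch(samples, 2)])
# [[0, 1], [2], [3, 4]]
```

Keeping every bag inside one sample means a bag can take its label
directly from that sample; random bags mix samples, so they would need
an extra rule (for example, the usual MIL assumption that a bag is
positive if any of its instances is positive) to assign a bag label.
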
Dependencies:
- gc: For manual garbage collection to manage memory.
- argparse: For parsing command-line arguments.
- logging: For logging progress and errors.
- multiprocessing (mp): For parallel processing.
- os: For file operations and temporary file management.
- tempfile: For creating temporary files.
- numpy (np): For numerical operations and arrays.
- pandas (pd): For data manipulation and I/O (CSV, Parquet).
- torch: For tensor operations (attention pooling).
- torch.nn: For neural network components (attention pooling).
- fastparquet: For reading and writing Parquet files.

Key Features:
- Multiple bagging strategies: by sample (`bag_by_sample`),
  in turns (`bag_in_turns`), or random (`bag_random`).
- Various pooling methods (e.g., max, mean, attention).
- Prevents data leakage by splitting at the sample level.
- Balances bags by capping label imbalance (`--imbalance_cap`)
  or by truncating bags (`--truncate_bags`).
- Outputs in Ludwig format (whitespace-separated vectors).
- Efficient large-dataset processing (temporary Parquet
  files, sequential CSV writing).
- GPU acceleration for certain pooling methods (e.g.,
  attention).

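The pooling step and the Ludwig vector format can be pictured with
the sketch below, assuming each bag is a 2-D array of instance
embeddings (one row per instance). `pool_bag` and `to_ludwig_vector`
are hypothetical names. The attention branch uses the attention-MIL
weighting a_k = softmax(w^T tanh(V x_k)) with random, untrained
parameters, so it only illustrates the arithmetic; the torch-based
attention pooling in this module may be set up differently.

```python
import numpy as np


def pool_bag(instances, method="mean", seed=0):
    # instances: (n_instances, dim) embeddings that make up one bag.
    x = np.asarray(instances, dtype=float)
    if method == "max":
        return x.max(axis=0)
    if method == "mean":
        return x.mean(axis=0)
    if method == "attention":
        # Attention-MIL weights a_k = softmax(w^T tanh(V x_k)); V and w
        # are random and untrained here, purely for illustration.
        rng = np.random.default_rng(seed)
        v = rng.standard_normal((x.shape[1], 8))
        w = rng.standard_normal(8)
        scores = np.tanh(x @ v) @ w
        weights = np.exp(scores - scores.max())  # numerically stable softmax
        weights /= weights.sum()
        return weights @ x  # weighted average of the instance embeddings
    raise ValueError(f"unknown pooling method: {method}")


def to_ludwig_vector(vector):
    # Ludwig vector features are one string of whitespace-separated
    # numbers; repr() gives the shortest string that round-trips exactly.
    return " ".join(repr(float(value)) for value in vector)


bag = [[1.0, 2.0], [3.0, 0.0]]
print(to_ludwig_vector(pool_bag(bag, "max")))   # 3.0 2.0
print(to_ludwig_vector(pool_bag(bag, "mean")))  # 2.0 1.0
```

Writing full-precision values avoids silently rounding embeddings
when the output CSV is read back for training.
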
Usage:
Run the script from the command line with arguments (wrapped here
for readability; enter them as one command):

```bash
python mil_bag.py --embeddings_csv <path_to_embeddings.csv>
    --metadata_csv <path_to_metadata.csv> --bag_size <bag_size>
    --pooling_method <method> --output_csv <output.csv>
    [--split_proportions <train,val,test>] [--dataleak]
    [--balance_enforced] [--by_sample <splits>] [--repeats <num>]
    [--ludwig_format] [--random_seed <seed>]
    [--imbalance_cap <percentage>] [--truncate_bags] [--use_gpu]
```

`--bag_size` takes a single size (e.g. `4`) or an inclusive range
(e.g. `3-5`); `--by_sample` takes a comma-separated subset of the
split codes `0,1,2`.

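Splitting at the sample level means every instance of a sample lands
in the same split, which is presumably what `--dataleak` enforces. The
hypothetical helper below, not this module's code, shows one way to do
that from `--split_proportions`, assuming the split codes follow
Ludwig's convention 0 = train, 1 = validation, 2 = test (the same
values `--by_sample` accepts).

```python
import numpy as np


def split_samples(sample_names, proportions=(0.7, 0.1, 0.2), seed=42):
    # Hypothetical sample-level split: whole samples, not individual
    # instances, are assigned to splits, so no sample's instances can
    # appear in more than one split. Codes: 0=train, 1=val, 2=test.
    unique = np.array(sorted(set(sample_names)))
    np.random.default_rng(seed).shuffle(unique)
    # Cumulative cut points, e.g. (0.7, 0.1, 0.2) over 10 samples -> [7, 8].
    cuts = np.round(np.cumsum(proportions)[:-1] * len(unique)).astype(int)
    groups = np.split(unique, cuts)
    return {str(name): code for code, group in enumerate(groups)
            for name in group}


names = [f"s{i}" for i in range(10) for _ in range(3)]  # 3 instances each
mapping = split_samples(names)
split_column = [mapping[name] for name in names]  # per-instance split code
print(list(mapping.values()).count(0))  # 7
```
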
"""

import argparse
import gc
import logging
import multiprocessing as mp
import os
import tempfile

import numpy as np
import pandas as pd
import torch
import torch.nn as nn


def parse_bag_size(bag_size_str):
    """Parses a bag size string ("N" or "MIN-MAX") into a list of sizes."""
    try:
        if '-' in bag_size_str:
            start, end = map(int, bag_size_str.split('-'))
            if start > end:
                # An inverted range would otherwise yield no sizes at all.
                raise ValueError("bag_size range start exceeds end")
            return list(range(start, end + 1))
        return [int(bag_size_str)]
    except ValueError:
        logging.error("Invalid bag_size format: %s", bag_size_str)
        raise


def parse_by_sample(value):
    """Parses a by_sample string (e.g. "0,2") into a list of splits."""
    try:
        value = str(value)
        splits = [int(x) for x in value.split(",")]
        valid_splits = {0, 1, 2}
        if not all(x in valid_splits for x in splits):
            logging.warning("Invalid splits in by_sample: %s", splits)
            return None
        return splits
    except (ValueError, AttributeError):
        logging.warning("by_sample not used")
        return None


class BaggingConfig:
    """Configuration class for bagging parameters."""

    def __init__(self, params):
        self.embeddings_csv = params.embeddings_csv
        self.metadata_csv = params.metadata_csv
        self.split_proportions = params.split_proportions
        self.prevent_leakage = params.dataleak
        self.balance_enforced = params.balance_enforced
        self.bag_size = parse_bag_size(params.bag_size)
        self.pooling_method = params.pooling_method
        self.by_sample = parse_by_sample(params.by_sample)
        self.repeats = params.repeats
        self.ludwig_format = params.ludwig_format
        self.output_csv = params.output_csv
        self.random_seed = params.random_seed
        self.imbalance_cap = params.imbalance_cap
        self.truncate_bags = params.truncate_bags
        self.use_gpu = params.use_gpu

    def __str__(self):
        """String representation of the config for logging."""
        return (
            f"embeddings_csv={self.embeddings_csv}, "
            f"metadata_csv={self.metadata_csv}, "
            f"split_proportions={self.split_proportions}, "
            f"prevent_leakage={self.prevent_leakage}, "
            f"balance_enforced={self.balance_enforced}, "
            f"bag_size={self.bag_size}, "
            f"pooling_method={self.pooling_method}, "
            f"by_sample={self.by_sample}, "
            f"repeats={self.repeats}, "
            f"ludwig_format={self.ludwig_format}, "
            f"output_csv={self.output_csv}, "
            f"random_seed={self.random_seed}, "
            f"imbalance_cap={self.imbalance_cap}, "
            f"truncate_bags={self.truncate_bags}, "
            f"use_gpu={self.use_gpu}"
        )


def set_random_seed(configs):
    """Sets random seeds for reproducibility."""
    np.random.seed(configs.random_seed)
    torch.manual_seed(configs.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(configs.random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    logging.info("Random seed set to %d", configs.random_seed)


def validate_metadata(metadata):
    """Validates metadata for required columns."""
    required_cols = {"sample_name", "label"}
    if not required_cols.issubset(metadata.columns):
        missing = required_cols - set(metadata.columns)
        raise ValueError(f"Metadata missing columns: {missing}")
    return metadata


151 def load_metadata(file_path): |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
152 """Loads metadata from a CSV file.""" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
153 metadata = pd.read_csv(file_path) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
154 validate_metadata(metadata) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
155 logging.info("Metadata loaded with %d samples, cols: %s", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
156 len(metadata), list(metadata.columns)) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
157 logging.info("Unique samples: %d, labels: %d", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
158 metadata["sample_name"].nunique(), |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
159 metadata["label"].nunique()) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
160 return metadata |


def convert_proportions(proportion_string):
    """Converts a string of split proportions into a list of floats."""
    proportion_list = [float(p) for p in proportion_string.split(",")]
    logging.debug("Parsed split proportions: %s", proportion_list)
    if len(proportion_list) == 2:
        # Two values are read as train,test; validation gets 0.0.
        proportion_list = [proportion_list[0], 0.0, proportion_list[1]]
    if len(proportion_list) != 3:
        # Callers unpack exactly three values (train, val, test).
        raise ValueError("Expected 2 or 3 comma-separated proportions")

    for proportion in proportion_list:
        if proportion < 0 or proportion > 1:
            raise ValueError("Each proportion must be between 0 and 1")

    if abs(sum(proportion_list) - 1.0) > 1e-6:
        raise ValueError("Proportions must sum to approximately 1.0")

    return proportion_list


def calculate_split_counts(total_samples, proportions):
    """Calculates sample counts for each split."""
    counts = [int(p * total_samples) for p in proportions]
    calculated_total = sum(counts)
    if calculated_total < total_samples:
        counts[-1] += total_samples - calculated_total
    elif calculated_total > total_samples:
        counts[0] -= calculated_total - total_samples
    return counts
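

# Illustrative sketch (hypothetical helper, not part of the original tool):
# a self-contained worked example of the truncate-then-patch arithmetic in
# calculate_split_counts (shortfall branch only). With 10 samples and
# proportions (0.5, 0.25, 0.25), int() truncation gives [5, 2, 2], which sums
# to 9, so the one leftover sample goes to the last (test) split: [5, 2, 3].
# The whole remainder always lands in the last split, so with 7 samples the
# same proportions give [3, 1, 1] -> [3, 1, 3].
def _example_split_counts(total_samples=10, proportions=(0.5, 0.25, 0.25)):
    # Truncate each share toward zero, e.g. 2.5 -> 2.
    counts = [int(p * total_samples) for p in proportions]
    # Hand the shortfall to the last split, as calculate_split_counts does.
    counts[-1] += total_samples - sum(counts)
    return counts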


def assign_split_labels(proportions, sample_count):
    """Assigns split labels based on proportions."""
    proportion_values = convert_proportions(proportions)
    train_fraction, val_fraction, test_fraction = proportion_values

    if val_fraction == 0 and test_fraction == 0:
        labels = np.zeros(sample_count, dtype=int)
    elif val_fraction == 0:
        train_size = int(train_fraction * sample_count)
        test_size = sample_count - train_size
        labels = np.array([0] * train_size + [2] * test_size)
    else:
        split_counts = calculate_split_counts(sample_count, proportion_values)
        labels = np.concatenate([
            np.zeros(split_counts[0], dtype=int),
            np.ones(split_counts[1], dtype=int),
            2 * np.ones(split_counts[2], dtype=int)
        ])
    return labels
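

# Illustrative sketch (hypothetical helper, not part of the original tool):
# the label vector built by assign_split_labels is each split id
# (0 = train, 1 = val, 2 = test) repeated by its count, so np.repeat gives the
# same layout in one call; for counts (5, 2, 3) both produce
# [0, 0, 0, 0, 0, 1, 1, 2, 2, 2]. The labels come out ordered (all train
# first); split_dataset randomises the assignment by permuting the sample
# names, not the labels.
def _example_split_label_layout(split_counts=(5, 2, 3)):
    import numpy as np

    # Equivalent to concatenating np.zeros / np.ones / 2 * np.ones blocks.
    return np.repeat([0, 1, 2], split_counts)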


def split_dataset(metadata, configs):
    """Splits dataset into train, val, test sets if prevent_leakage is True."""
    if configs.prevent_leakage:
        logging.info("No data leakage allowed")
        unique_samples = metadata["sample_name"].unique()
        sample_count = len(unique_samples)
        split_labels = assign_split_labels(configs.split_proportions,
                                           sample_count)
        shuffled_samples = np.random.permutation(unique_samples)
        label_series = pd.Series(split_labels, index=shuffled_samples)
        metadata["split"] = metadata["sample_name"].map(label_series)
        train_count = (metadata["split"] == 0).sum()
        val_count = (metadata["split"] == 1).sum()
        test_count = (metadata["split"] == 2).sum()
        logging.info("Dataset split: train %d, val %d, test %d",
                     train_count, val_count, test_count)
    else:
        logging.info("Data leakage allowed; skipping sample-level split")
    return metadata
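

# Illustrative sketch (hypothetical helper, not part of the original tool):
# why split_dataset assigns splits per sample rather than per row. Shuffling
# the unique sample names and mapping each name to a single split id puts
# every instance (row) of a sample in the same split, so no sample can appear
# in both train and test. The toy data, the hand-picked split ids and the
# seeded numpy Generator (instead of the global np.random state) are
# assumptions of this sketch.
def _example_sample_level_split(seed=0):
    import numpy as np
    import pandas as pd

    metadata = pd.DataFrame({
        "sample_name": ["a", "a", "b", "b", "c", "d", "d", "d"],
        "label": [0, 0, 1, 1, 0, 1, 1, 1],
    })
    rng = np.random.default_rng(seed)
    unique_samples = metadata["sample_name"].unique()
    # One split id per unique sample: two train, one val, one test.
    split_labels = np.array([0, 0, 1, 2])
    shuffled = rng.permutation(unique_samples)
    label_series = pd.Series(split_labels, index=shuffled)
    # Every row inherits the split id of its sample name.
    metadata["split"] = metadata["sample_name"].map(label_series)
    return metadata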


def assign_chunk_splits(chunk, split_counts, current_counts):
    """Assigns split labels to a chunk of embeddings."""
    chunk_size = len(chunk)
    remaining = {
        0: split_counts[0] - current_counts[0],
        1: split_counts[1] - current_counts[1],
        2: split_counts[2] - current_counts[2]
    }
    available_splits = [s for s, count in remaining.items() if count > 0]
    if not available_splits:
        return chunk, current_counts

    # Sum only the positive quotas: a split that overshot its target has a
    # negative remainder, which would make the weights below sum to more
    # than 1 and cause np.random.choice to raise a ValueError.
    total_remaining = sum(remaining[s] for s in available_splits)
    assign_count = min(chunk_size, total_remaining)
    if assign_count == 0:
        return chunk, current_counts

    weights = [remaining[s] / total_remaining for s in available_splits]
    splits = np.random.choice(available_splits, size=assign_count, p=weights)
    chunk["split"] = pd.Series(splits, index=chunk.index[:assign_count])
    # Rows beyond assign_count receive no draw and default to split 0 (train).
    chunk["split"] = chunk["split"].fillna(0).astype(int)

    for split in available_splits:
        current_counts[split] += np.sum(splits == split)

    return chunk, current_counts
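

# Illustrative sketch (hypothetical helper, not part of the original tool):
# the quota-weighted draw at the heart of assign_chunk_splits. Each split
# that still has room is chosen with probability proportional to its
# remaining quota, so splits further from their target fill faster. Weights
# are normalised over the available splits only, which keeps them summing to
# 1 even when a split has overshot its quota. The seeded numpy Generator is
# an assumption made for reproducibility; the tool uses np.random.choice.
def _example_quota_weighted_draw(remaining, size, seed=0):
    import numpy as np

    available = [s for s, count in remaining.items() if count > 0]
    if not available:
        return np.array([], dtype=int)
    total = sum(remaining[s] for s in available)
    weights = [remaining[s] / total for s in available]
    rng = np.random.default_rng(seed)
    # Never draw more rows than the remaining quota allows.
    return rng.choice(available, size=min(size, total), p=weights)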


def setup_temp_files():
    """Sets up temporary Parquet files for splits and bag outputs."""
    splits = [0, 1, 2]
    split_files = {}
    for split in splits:
        fd, path = tempfile.mkstemp(prefix=f"split_{split}_",
                                    suffix=".parquet",
                                    dir=os.getcwd())
        os.close(fd)  # Explicitly close the file descriptor
        split_files[split] = path

    bag_outputs = {}
    for split in splits:
        fd, path = tempfile.mkstemp(prefix=f"MIL_bags_{split}_",
                                    suffix=".parquet",
                                    dir=os.getcwd())
        os.close(fd)  # Explicitly close the file descriptor
        bag_outputs[split] = path

    return split_files, bag_outputs
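

# Illustrative sketch (hypothetical helper, not part of the original tool):
# files created with tempfile.mkstemp are not deleted automatically, so the
# paths returned by setup_temp_files have to be removed explicitly once the
# bags are written. Whether and where the tool itself cleans up is not shown
# in this excerpt.
def _example_remove_temp_files(*path_maps):
    import os

    for path_map in path_maps:
        for path in path_map.values():
            # Tolerate files that were already removed.
            if os.path.exists(path):
                os.remove(path)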


def distribute_embeddings(configs, metadata, split_files):
    """Streams CSV embeddings in chunks into the per-split Parquet files."""
    embeddings_path = configs.embeddings_csv
    proportion_string = configs.split_proportions
    prevent_leakage = configs.prevent_leakage

    logging.info("Distributing embeddings from %s to Parquet files",
                 embeddings_path)
    buffer_size = 50000
    merged_header = None
    non_sample_columns = None

    if not prevent_leakage:
        logging.warning(
            "Counting rows in %s; may be slow for large files",
            embeddings_path
        )
        # Counts physical lines minus the header; assumes no quoted field
        # spans multiple lines. The with-block closes the file afterwards.
        with open(embeddings_path) as handle:
            total_rows = sum(1 for _ in handle) - 1
        proportions = convert_proportions(proportion_string)
        split_counts = calculate_split_counts(total_rows, proportions)
        current_counts = {0: 0, 1: 0, 2: 0}
    else:
        sample_to_split = dict(zip(metadata["sample_name"], metadata["split"]))
        sample_to_label = dict(zip(metadata["sample_name"], metadata["label"]))

    first_write = {split: True for split in split_files}

    try:
        first_header_read = True
        for chunk in pd.read_csv(embeddings_path, chunksize=buffer_size):
            # Drop the suffix after the last underscore in 'sample_name'
            chunk['sample_name'] = chunk['sample_name'].apply(lambda x: x.rsplit('_', 1)[0])

            if first_header_read:
                orig_header = list(chunk.columns)
                non_sample_columns = [
                    col for col in orig_header if col != "sample_name"
                ]
                merged_header = ["sample_name", "label"] + non_sample_columns
                logging.info("Merged header: %s", merged_header)
goeckslab
parents:
diff
changeset
|
322 first_header_read = False |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
323 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
324 if prevent_leakage: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
325 chunk["split"] = chunk["sample_name"].map(sample_to_split) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
326 chunk["label"] = chunk["sample_name"].map(sample_to_label) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
327 else: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
328 chunk, current_counts = assign_chunk_splits(chunk, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
329 split_counts, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
330 current_counts) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
331 chunk = chunk.merge(metadata[["sample_name", "label"]], |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
332 on="sample_name", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
333 how="left") |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
334 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
335 chunk = chunk.dropna(subset=["split", "label"]) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
336 for split in split_files: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
337 split_chunk = chunk[chunk["split"] == split] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
338 if not split_chunk.empty: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
339 temp_file = split_files[split] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
340 split_chunk[merged_header].to_parquet( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
341 temp_file, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
342 engine="fastparquet", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
343 append=not first_write[split], |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
344 index=False |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
345 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
346 first_write[split] = False |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
347 del chunk |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
348 gc.collect() |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
349 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
350 except Exception as e: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
351 logging.error("Error distributing embeddings to Parquet: %s", e) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
352 raise |
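The `prevent_leakage` branch above streams the embeddings CSV in chunks, strips the suffix after the last underscore from `sample_name`, looks up each sample's split and label, drops rows with no metadata match, and routes the remaining rows by split. A minimal in-memory sketch of that routing follows, assuming instance names such as `s1_0`; `route_chunks` is a hypothetical helper that collects per-split DataFrames instead of appending to Parquet through fastparquet, and it leaves out the `assign_chunk_splits` path entirely.

```python
import io

import pandas as pd


def route_chunks(csv_text, metadata, buffer_size=2):
    """Route embedding rows to per-split DataFrames, reading in chunks.

    Hypothetical helper: an in-memory stand-in for the Parquet step,
    covering only the prevent_leakage (metadata lookup) branch.
    """
    sample_to_split = dict(zip(metadata["sample_name"], metadata["split"]))
    sample_to_label = dict(zip(metadata["sample_name"], metadata["label"]))
    header = None
    parts = {}
    for chunk in pd.read_csv(io.StringIO(csv_text), chunksize=buffer_size):
        # "s1_0" -> "s1": drop everything after the last underscore
        chunk["sample_name"] = chunk["sample_name"].apply(
            lambda x: x.rsplit("_", 1)[0])
        if header is None:
            # Same column order as merged_header: sample_name, label, features
            header = ["sample_name", "label"] + [
                c for c in chunk.columns if c != "sample_name"]
        chunk["split"] = chunk["sample_name"].map(sample_to_split)
        chunk["label"] = chunk["sample_name"].map(sample_to_label)
        # Rows whose sample is absent from the metadata get NaN and are dropped
        chunk = chunk.dropna(subset=["split", "label"])
        for split, part in chunk.groupby("split"):
            parts.setdefault(split, []).append(part[header])
    return {s: pd.concat(p, ignore_index=True) for s, p in parts.items()}


metadata = pd.DataFrame({"sample_name": ["s1", "s2"],
                         "split": ["train", "test"],
                         "label": [0, 1]})
csv_text = ("sample_name,f0,f1\n"
            "s1_0,0.1,0.2\n"
            "s1_1,0.3,0.4\n"
            "s2_0,0.5,0.6\n"
            "s3_0,0.7,0.8\n")  # s3 has no metadata row, so it is dropped
out = route_chunks(csv_text, metadata, buffer_size=2)
```

The original keeps memory bounded by writing each chunk out immediately (`append=not first_write[split]`); the sketch accumulates the chunks only so the result is easy to inspect.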


def aggregate_embeddings(embeddings, pooling_method, use_gpu=False):
    """
    Pool a bag of instance embeddings (one row per instance) into a single
    vector using pooling_method. use_gpu only affects attention_pooling.
    """
    # Convert embeddings to a float32 array explicitly.
    embeddings = np.asarray(embeddings, dtype=np.float32)

    if embeddings.ndim == 1:
        embeddings = embeddings.reshape(1, -1)
    elif embeddings.ndim == 0:
        embeddings = embeddings.reshape(1, 1)

    logging.debug("Aggregating embeddings with shape: %s", embeddings.shape)

    if pooling_method == "max_pooling":
        result = np.max(embeddings, axis=0)
    elif pooling_method == "mean_pooling":
        result = np.mean(embeddings, axis=0)
    elif pooling_method == "sum_pooling":
        result = np.sum(embeddings, axis=0)
    elif pooling_method == "min_pooling":
        result = np.min(embeddings, axis=0)
    elif pooling_method == "median_pooling":
        result = np.median(embeddings, axis=0)
    elif pooling_method == "l2_norm_pooling":
        # Mean of the L2-normalized instance vectors; falls back to a plain
        # mean when every instance vector is all zeros.
        norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
        if norm.any():
            result = np.mean(embeddings / (norm + 1e-8), axis=0)
        else:
            result = np.mean(embeddings, axis=0)
    elif pooling_method == "geometric_mean_pooling":
        # Values below 1e-10 (including negatives) are clipped before the log,
        # so this is only meaningful for non-negative embeddings.
        clipped = np.clip(embeddings, 1e-10, None)
        result = np.exp(np.mean(np.log(clipped), axis=0))
    elif pooling_method == "first_embedding":
        result = embeddings[0]
    elif pooling_method == "last_embedding":
        result = embeddings[-1]
    elif pooling_method == "attention_pooling":
        device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'
        tensor = torch.tensor(embeddings, dtype=torch.float32).to(device)
        with torch.no_grad():
            # The scoring layer is a freshly initialized, untrained nn.Linear,
            # so the attention weights are random on every call.
            linear = nn.Linear(tensor.shape[1], 1).to(device)
            weights = nn.Softmax(dim=0)(linear(tensor))
            result = torch.sum(weights * tensor, dim=0).cpu().detach().numpy()
    else:
        raise ValueError(f"Unknown pooling method: {pooling_method}")

    logging.debug("Aggregated embedding shape: %s", result.shape)
    return result
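The less obvious pooling rules in `aggregate_embeddings` can be checked on a two-instance bag with plain NumPy. This is a sketch, not the module's code: the helper names are hypothetical, and `attention_pool` takes an explicit scoring vector `w` so its result is reproducible, whereas `aggregate_embeddings` builds a new, untrained `nn.Linear` on each call.

```python
import numpy as np


def l2_norm_pool(bag, eps=1e-8):
    """Mean of the L2-normalized instance vectors (like l2_norm_pooling)."""
    norm = np.linalg.norm(bag, axis=1, keepdims=True)
    return np.mean(bag / (norm + eps), axis=0)


def geometric_mean_pool(bag, floor=1e-10):
    """Per-feature exp(mean(log x)); values below `floor` are clipped first."""
    return np.exp(np.mean(np.log(np.clip(bag, floor, None)), axis=0))


def attention_pool(bag, w, b=0.0):
    """Softmax-weighted sum of instances, scored by a fixed linear map w.x + b.

    Hypothetical: w and b are supplied by the caller here rather than drawn
    from a randomly initialized layer.
    """
    scores = bag @ w + b                # one score per instance
    scores = scores - scores.max()      # subtract the max for numerical stability
    weights = np.exp(scores) / np.exp(scores).sum()
    return weights @ bag                # weighted sum over instances


bag = np.array([[3.0, 4.0], [0.0, 2.0]], dtype=np.float32)
print(l2_norm_pool(bag))                    # roughly [0.3, 0.9]
print(geometric_mean_pool(bag))             # first feature collapses toward 0 because of the 0.0 entry
print(attention_pool(bag, w=np.zeros(2)))   # equal weights, so the plain mean [1.5, 3.0]
```

A single zero in a feature drives its geometric mean toward zero, and a zero scoring vector reduces attention pooling to mean pooling, which makes both rules easy to sanity-check.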
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
401 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
402 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
403 def bag_by_sample(df, split, bag_file, config, batch_size=1000, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
404 fixed_target_bags=None): |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
405 """ |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
406 Processes the provided DataFrame by grouping rows by sample, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
407 constructs bags from each sample group using the configured bag_size, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
408 and writes the bag rows directly to bag_file (a Parquet file) in batches. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
409 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
410 Args: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
411 df (pd.DataFrame): The DataFrame containing the data. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
412 split (str): The split identifier (e.g., 'train', 'val'). |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
413 bag_file (str): The path to the Parquet file to write the bags. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
414 config (object): Configuration object with bag_size, pooling_method... |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
415 batch_size (int, optional): The number of rows to write in each batch. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
416 fixed_target_bags (tuple, optional): (target_label, num_bags) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
417 to generate bags only for target_label. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
418 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
419 Output row format: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
420 sample_name, bag_label, split, bag_size, vector_0, vector_1, vector_N |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
421 """ |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
422 log_msg = f"Processing by sample for split: {split}" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
423 if fixed_target_bags: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
424 log_msg += f" with fixed target {fixed_target_bags}" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
425 logging.info(log_msg) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
426 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
427 batch_rows = [] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
428 bag_count = 0 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
429 vector_columns = [ |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
430 col for col in df.columns |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
431 if col not in ["sample_name", "label", "split"] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
432 ] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
433 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
434 if fixed_target_bags is not None: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
435 target_label, target_needed = fixed_target_bags |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
436 target_samples = list( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
437 df[df["label"] == target_label]["sample_name"].unique() |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
438 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
439 df = df[df["sample_name"].isin(target_samples)] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
440 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
441 if df.empty: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
442 logging.warning( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
443 "No samples available for target label %d in split %s", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
444 target_label, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
445 split |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
446 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
447 return |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
448 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
449 available_samples = target_samples.copy() |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
450 np.random.shuffle(available_samples) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
451 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
452 while bag_count < target_needed: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
453 if len(available_samples) == 0: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
454 available_samples = target_samples.copy() |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
455 np.random.shuffle(available_samples) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
456 logging.info( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
457 "Reusing samples for target label %d in split %s", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
458 target_label, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
459 split |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
                )

            sample_name = available_samples.pop()
            group = df[df["sample_name"] == sample_name]
            embeddings = group[vector_columns].values
            num_instances = len(group)

            current_bag_size = config.bag_size[0] \
                if len(config.bag_size) == 1 else \
                np.random.randint(config.bag_size[0], config.bag_size[1] + 1)
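            # Worked example (illustrative values): with config.bag_size = [3, 5]
            # the draw above is np.random.randint(3, 6), which returns 3, 4 or 5
            # because randint's upper bound is exclusive. The cap on the next
            # line then limits a sample with only 4 instances to a bag of at
            # most 4.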
            current_bag_size = min(current_bag_size, num_instances)

            # Sample with replacement: the same instance may appear more
            # than once in a bag.
            selected = group.sample(n=current_bag_size, replace=True)
            bag_embeddings = selected[vector_columns].values

            aggregated_embedding = aggregate_embeddings(
                bag_embeddings,
                config.pooling_method,
                config.use_gpu
            )

            bag_label = int(any(selected["label"] == 1))
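            # Worked example of this labelling rule (the standard MIL
            # assumption: a bag is positive if any of its instances is
            # positive): instance labels [0, 0, 1] give bag_label = 1, while
            # [0, 0, 0] give bag_label = 0. A bag drawn for target_label = 0
            # that picks up any positive instance is therefore skipped below.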
            if bag_label != target_label:
                logging.warning(
                    "Generated bag for target %d but got label %d",
                    target_label, bag_label
                )
                continue

            row = {
                "sample_name": sample_name,
                "bag_label": bag_label,
                "split": split,
                "bag_size": current_bag_size
            }
            for j, val in enumerate(aggregated_embedding):
                row[f"vector_{j}"] = val

            batch_rows.append(row)
            bag_count += 1

            if len(batch_rows) >= batch_size:
                df_batch = pd.DataFrame(batch_rows)
                # Append if earlier batches were already written to
                # bag_file; otherwise write a fresh file.
                append_mode = os.path.getsize(bag_file) > 0
                df_batch.to_parquet(
                    bag_file,
                    engine="fastparquet",
                    append=append_mode,
                    index=False
                )
                logging.debug(
                    "Fixed mode: Wrote batch of %d rows to %s",
                    len(batch_rows),
                    bag_file
                )
                batch_rows = []
                del df_batch
                gc.collect()

    else:
        # Standard mode: split every sample into consecutive bags
        groups = df.groupby("sample_name")
        for sample_name, group in groups:
            embeddings = group[vector_columns].values
            labels = group["label"].values
            num_instances = len(group)

            current_bag_size = config.bag_size[0] \
                if len(config.bag_size) == 1 else \
                np.random.randint(
                    config.bag_size[0],
                    config.bag_size[1] + 1
                )
            num_bags = (
                num_instances + current_bag_size - 1
            ) // current_bag_size
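            # Worked example of the ceiling division above: 10 instances with
            # current_bag_size = 4 give num_bags = (10 + 4 - 1) // 4 = 3; the
            # loop below then slices rows [0:4], [4:8] and [8:10], so the last
            # bag holds only 2 instances and its "bag_size" is recorded as 2.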
            logging.info(
                "Sample %s: %d instances, creating %d bags (bag size %d)",
                sample_name,
                num_instances,
                num_bags,
                current_bag_size
            )

            for i in range(num_bags):
                start_idx = i * current_bag_size
                end_idx = min(start_idx + current_bag_size, num_instances)
                bag_embeddings = embeddings[start_idx:end_idx]
                bag_labels = labels[start_idx:end_idx]

                aggregated_embedding = aggregate_embeddings(
                    bag_embeddings,
                    config.pooling_method,
                    config.use_gpu
                )
                bag_label = int(any(bag_labels == 1))

                row = {
                    "sample_name": sample_name,
                    "bag_label": bag_label,
                    "split": split,
                    "bag_size": end_idx - start_idx
                }
                for j, val in enumerate(aggregated_embedding):
                    row[f"vector_{j}"] = val

                batch_rows.append(row)
                bag_count += 1

                if len(batch_rows) >= batch_size:
                    df_batch = pd.DataFrame(batch_rows)
                    # Append if earlier batches were already written to
                    # bag_file; otherwise write a fresh file.
                    append_mode = os.path.getsize(bag_file) > 0
                    df_batch.to_parquet(
                        bag_file,
                        engine="fastparquet",
                        append=append_mode,
                        index=False
                    )
                    logging.debug(
                        "Wrote batch of %d rows to %s",
                        len(batch_rows),
                        bag_file
                    )
                    batch_rows = []
                    del df_batch
                    gc.collect()

    # Write any remaining rows
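    # Worked example (illustrative numbers): with batch_size = 500 and
    # 1,234 bags generated in total, the batching code above flushes two
    # full batches of 500 rows, leaving 234 rows that the block below
    # writes as the last batch.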
    if batch_rows:
        df_batch = pd.DataFrame(batch_rows)
        append_mode = os.path.getsize(bag_file) > 0
        df_batch.to_parquet(
            bag_file,
            engine="fastparquet",
            append=append_mode,
            index=False
        )
        logging.debug(
            "Wrote final batch of %d rows to %s",
            len(batch_rows),
            bag_file
        )
        del df_batch
        gc.collect()

    logging.info("Created %d bags for split: %s", bag_count, split)


def bag_in_turns(df, split, bag_file, config, batch_size=500,
                 fixed_target_bags=None, allow_reuse=True):
    """
    Generate bags of instances from a DataFrame, with optional
    fixed-target mode, data reuse, and enhanced diversity.

    Parameters:
    - df (pd.DataFrame): Input DataFrame with columns including
      'sample_name', 'label', 'split', and embedding vectors.
    - split (str): Dataset split (e.g., 'train', 'test').
    - bag_file (str): Path to save the output Parquet file.
    - config (object): Configuration object with attributes
      'bag_size', 'pooling_method', and 'use_gpu'.
    - batch_size (int): Number of bags to accumulate in memory before
      writing them to the file (default: 500).
    - fixed_target_bags (tuple): Optional (label, num_bags) to
      generate bags for a specific label (e.g., (0, 100)).
    - allow_reuse (bool): If True, allow instances to be resampled
      with replacement (default: True).

    Returns:
    - None: Saves bags to the specified Parquet file.
    """
    logging.info(
        "Processing bags in turns for split %s%s",
        split,
        (" with fixed target " + str(fixed_target_bags))
        if fixed_target_bags is not None else ""
    )

    # Identify embedding columns (exclude non-vector columns).
    vector_columns = [
        col for col in df.columns
        if col not in ["sample_name", "label", "split"]
goeckslab
parents:
diff
changeset
|
644 ] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
645 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
646 # Convert the DataFrame to a NumPy array for faster processing. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
647 df_np = df.to_numpy() |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
648 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
649 # Determine bag size range from config. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
650 if len(config.bag_size) == 1: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
651 bag_min = bag_max = config.bag_size[0] |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
652 else: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
653 bag_min, bag_max = config.bag_size |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
654 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
655 batch_rows = [] |
    bag_count = 0

    if fixed_target_bags is not None:
        # Fixed-target mode: generate bags for a specific label.
        target, target_needed = fixed_target_bags  # e.g., (0, 100)
        if target == 0:
            # For target label 0, keep only label 0 instances
            # (column 1 holds the instance label).
            indices = np.where(df_np[:, 1] == 0)[0]
            logging.info(
                "Fixed mode: target label 0, using only label 0 instances, "
                "total available %d rows",
                len(indices)
            )
        else:
            # For target label 1, use all instances so bags can mix labels.
            indices = np.arange(len(df_np))
            logging.info(
                "Fixed mode: target label 1, using all instances, "
                "total available %d rows",
                len(indices)
            )

        total_available = len(indices)

        # Guard: np.random.choice raises on an empty pool, and for target 1
        # the positive-bag check below would loop forever if the pool held
        # no label 1 instance.
        if total_available == 0 or (
            target == 1 and not np.any(df_np[indices, 1] == 1)
        ):
            logging.warning(
                "No usable instances for target label %d in split %s; "
                "no bags will be created",
                target,
                split
            )
            target_needed = 0

        while bag_count < target_needed:
            # randint's upper bound is exclusive, hence bag_max + 1.
            current_bag_size = (
                np.random.randint(bag_min, bag_max + 1)
                if bag_min != bag_max else bag_min
            )

            if total_available < current_bag_size and not allow_reuse:
                logging.warning(
                    "Not enough instances (%d) for bag size %d and "
                    "target label %d",
                    total_available,
                    current_bag_size,
                    target
                )
                break

            # Sample instances.
            selected = np.random.choice(
                indices,
                size=current_bag_size,
                replace=allow_reuse
            )
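            # Illustration (hypothetical values): with indices = np.arange(5)
            # and size=3, replace=False yields three distinct indices, while
            # replace=True may repeat one (for example [1, 1, 3]). In both
            # cases the pool `indices` is left unchanged, so the same instance
            # can appear again in a later bag even when allow_reuse is False.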
            bag_data = df_np[selected]

            if target == 1:
                # For positive bags, ensure at least one instance has label 1.
                if not np.any(bag_data[:, 1] == 1):
                    continue  # Skip if no positive instance
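                # Illustration (hypothetical numbers): under the standard MIL
                # assumption a bag is positive iff it holds at least one
                # positive instance, which the check above enforces by
                # rejection. If a fraction p of the pool is label 1 and k
                # instances are drawn with replacement, a bag passes with
                # probability 1 - (1 - p)**k; e.g. p = 0.1, k = 10 gives about
                # 0.65, so roughly one draw in three is discarded and retried.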
                bag_label = 1
            else:
                # For negative bags, all instances are label 0 due to filtering.
                bag_label = 0

            # Aggregate embeddings.
            vec_col_indices = [
                df.columns.get_loc(col) for col in vector_columns
            ]
            embeddings = bag_data[:, vec_col_indices].astype(np.float32)
            aggregated_embedding = aggregate_embeddings(
                embeddings,
                config.pooling_method,
                config.use_gpu
            )

            # Set bag metadata.
            bsize = bag_data.shape[0]
            samples = np.unique(bag_data[:, 0])
            merged_sample_name = ",".join(map(str, samples))

            # Create row for the bag.
            row = {
                "sample_name": merged_sample_name,
                "bag_label": bag_label,
                "split": split,
                "bag_size": bsize
            }
            for j, val in enumerate(aggregated_embedding):
                row[f"vector_{j}"] = val

            batch_rows.append(row)
            bag_count += 1

            if len(batch_rows) >= batch_size:
                df_batch = pd.DataFrame(batch_rows)
                df_batch.to_parquet(
                    bag_file,
                    engine="fastparquet",
                    append=True,
                    index=False
                )
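                # Note (assumption about the caller): with engine="fastparquet",
                # append=True adds row groups to an existing dataset, and the
                # fastparquet documentation states that the file must already
                # exist with a matching schema. bag_file is therefore presumably
                # created before this function is called; that step is not
                # visible in this excerpt.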
                logging.debug(
                    "Fixed mode: Wrote a batch of %d rows to %s",
                    len(batch_rows),
                    bag_file
                )
                batch_rows = []
                del df_batch
                gc.collect()

        # Write any remaining rows.
        if batch_rows:
            df_batch = pd.DataFrame(batch_rows)
            df_batch.to_parquet(
                bag_file,
                engine="fastparquet",
                append=True,
                index=False
            )
            logging.debug(
                "Wrote the final batch of %d rows to %s",
                len(batch_rows),
                bag_file
            )
            del df_batch
            gc.collect()

        logging.info("Created %d bags for split: %s", bag_count, split)

    else:
        # Alternating mode: alternate between labels 0 and 1.
        indices_0 = np.where(df_np[:, 1] == 0)[0]
        indices_1 = np.where(df_np[:, 1] == 1)[0]
        np.random.shuffle(indices_0)
        np.random.shuffle(indices_1)
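        # Shuffling once and then slicing from the front consumes each
        # index at most once (sampling without replacement). Illustration
        # with hypothetical numbers: 7 label 0 rows and a fixed bag size of
        # 3 are used up in chunks of 3, 3 and 1, because each take is
        # min(current_bag_size, len(indices_0)).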
        turn = 0  # 0: label 0, 1: label 1.

        while len(indices_0) > 0 or len(indices_1) > 0:
            current_bag_size = (
                np.random.randint(bag_min, bag_max + 1)
                if bag_min != bag_max else bag_min
            )

            if turn == 0:
                if len(indices_0) > 0:
                    num_to_select = min(current_bag_size, len(indices_0))
                    selected = indices_0[:num_to_select]
                    indices_0 = indices_0[num_to_select:]
                else:
                    if len(indices_1) == 0:
                        break
                    num_to_select = min(current_bag_size, len(indices_1))
                    selected = indices_1[:num_to_select]
                    indices_1 = indices_1[num_to_select:]
            else:
                if len(indices_1) > 0:
                    num_to_select = min(current_bag_size, len(indices_1))
                    selected = indices_1[:num_to_select]
                    indices_1 = indices_1[num_to_select:]
                else:
                    if len(indices_0) == 0:
                        break
                    num_to_select = min(current_bag_size, len(indices_0))
                    selected = indices_0[:num_to_select]
                    indices_0 = indices_0[num_to_select:]
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
808 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
        bag_data = df_np[selected]
        if bag_data.shape[0] == 0:
            break

        # Aggregate embeddings.
        vec_col_indices = [
            df.columns.get_loc(col) for col in vector_columns
        ]
        embeddings = bag_data[:, vec_col_indices].astype(np.float32)
        aggregated_embedding = aggregate_embeddings(
            embeddings,
            config.pooling_method,
            config.use_gpu
        )

        # Set bag label and metadata.
        bag_label = int(np.any(bag_data[:, 1] == 1))
        bsize = bag_data.shape[0]
        samples = np.unique(bag_data[:, 0])
        merged_sample_name = ",".join(map(str, samples))

        # Create row for the bag.
        row = {
            "sample_name": merged_sample_name,
            "bag_label": bag_label,
            "split": split,
            "bag_size": bsize
        }
        for j, val in enumerate(aggregated_embedding):
            row[f"vector_{j}"] = val

        batch_rows.append(row)
        bag_count += 1
        turn = 1 - turn

        # Write batch to file if batch_size is reached.
        if len(batch_rows) >= batch_size:
            df_batch = pd.DataFrame(batch_rows)
            df_batch.to_parquet(
                bag_file,
                engine="fastparquet",
                append=(bag_count > len(batch_rows)),
                index=False
            )
            logging.debug(
                "Alternating mode: Wrote a batch of %d rows to %s",
                len(batch_rows),
                bag_file
            )
            batch_rows = []
            del df_batch
            gc.collect()

    # Write any remaining rows.
    if batch_rows:
        df_batch = pd.DataFrame(batch_rows)
        df_batch.to_parquet(
            bag_file,
            engine="fastparquet",
            append=(bag_count > len(batch_rows)),
            index=False
        )
        logging.debug(
            "Wrote the final batch of %d rows to %s",
            len(batch_rows),
            bag_file
        )
        del df_batch
        gc.collect()

    logging.info("Created %d bags for split: %s", bag_count, split)


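# Illustrative sketch (hypothetical helper; not part of the tool and not
# called anywhere). It shows only the index bookkeeping of the alternating
# strategy above and leaves out embedding pooling and Parquet output. It
# assumes a fixed bag size, whereas the real loop works with
# current_bag_size. On each turn the helper takes up to `bag_size` indices
# from that turn's label pool, falls back to the other pool when that pool
# is empty, and stops once both pools are exhausted.
def _sketch_alternating_bags(indices_0, indices_1, bag_size):
    """Return a list of bags (lists of row indices) built in alternating turns."""
    if bag_size < 1:
        raise ValueError("bag_size must be at least 1")
    pools = (list(indices_0), list(indices_1))
    bags = []
    turn = 0
    while pools[0] or pools[1]:
        # Prefer the pool for the current turn; otherwise use the other one.
        pool = pools[turn] if pools[turn] else pools[1 - turn]
        bags.append(pool[:bag_size])
        del pool[:bag_size]  # consume the selected indices in place
        turn = 1 - turn
    return bags


# Example: _sketch_alternating_bags([0, 1, 2, 3, 4], [10, 11], 2)
# returns [[0, 1], [10, 11], [2, 3], [4]].

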
def bag_random(df, split, bag_file, configs, batch_size=500):
    """
    Processes the provided DataFrame by shuffling all row indices and
    slicing them into consecutive bags, so that every instance lands in
    exactly one bag. Bag sizes are drawn uniformly from configs.bag_size
    (a single size or an inclusive [min, max] range); the last bag may be
    smaller.
    """
    logging.info("Processing bag randomly for split %s", split)

    # Identify vector columns (exclude non-vector columns).
    vector_columns = [
        col for col in df.columns
        if col not in ["sample_name", "label", "split"]
    ]

    df_np = df.to_numpy()

    # Create an array of all row indices and shuffle them.
    indices = np.arange(df.shape[0])
    np.random.shuffle(indices)

    bag_count = 0
    batch_rows = []

    # Determine bag size parameters.
    if len(configs.bag_size) == 1:
        bag_min = bag_max = configs.bag_size[0]
    else:
        bag_min, bag_max = configs.bag_size

    pos = 0
    total_rows = len(indices)

    # Process until all indices have been used.
    while pos < total_rows:
        # Draw a bag size, capped at the number of remaining rows.
        current_bag_size = (np.random.randint(bag_min, bag_max + 1)
                            if bag_min != bag_max else bag_min)
        current_bag_size = min(current_bag_size, total_rows - pos)

        # Select the indices for this bag.
        selected = indices[pos: pos + current_bag_size]
        pos += current_bag_size

        # Extract the bag data.
        bag_data = df_np[selected]
        if bag_data.shape[0] == 0:
            break

        # Identify the positions of the vector columns using the column names.
        vec_col_indices = [df.columns.get_loc(col) for col in vector_columns]
        embeddings = bag_data[:, vec_col_indices].astype(np.float32)
        aggregated_embedding = aggregate_embeddings(
            embeddings,
            configs.pooling_method,
            configs.use_gpu
        )

        # Determine bag_label: 1 if any instance in this bag has label == 1.
        bag_label = int(np.any(bag_data[:, 1] == 1))

        # Merge all sample names from the bag (unique names, comma-separated).
        samples = np.unique(bag_data[:, 0])
        merged_sample_name = ",".join(map(str, samples))

        # Use the provided split value.
        bag_split = split
        bsize = bag_data.shape[0]

        # Build the output row with header fields:
        # sample_name, bag_label, split, bag_size, then embeddings.
        row = {
            "sample_name": merged_sample_name,
            "bag_label": bag_label,
            "split": bag_split,
            "bag_size": bsize
        }
        for j, val in enumerate(aggregated_embedding):
            row[f"vector_{j}"] = val

        batch_rows.append(row)
        bag_count += 1

963 # Write out rows in batches. |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
964 if len(batch_rows) >= batch_size: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
965 df_batch = pd.DataFrame(batch_rows) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
            # The first batch is written with append=False, which creates
            # (or overwrites) the Parquet file; subsequent batches use
            # append=True to add row groups to the existing file.
            df_batch.to_parquet(
                bag_file,
                engine="fastparquet",
                append=(bag_count > len(batch_rows)),
                index=False
            )
            logging.debug(
                "Wrote a batch of %d rows to %s",
                len(batch_rows),
                bag_file
            )
            batch_rows = []
            del df_batch
            gc.collect()

    # Write any remaining rows.
    if batch_rows:
        df_batch = pd.DataFrame(batch_rows)
        df_batch.to_parquet(
            bag_file,
            engine="fastparquet",
            append=(bag_count > len(batch_rows)),
            index=False
        )
        logging.debug(
            "Wrote the final batch of %d rows to %s",
            len(batch_rows),
            bag_file
        )
        del df_batch
        gc.collect()

    logging.info("Created %d bags for split: %s", bag_count, split)


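The batching loop above chooses between creating and appending to the Parquet file by comparing the running bag count with the size of the batch being flushed. A batch is written with append=False only when every bag counted so far is still in that batch, meaning nothing has been written yet. The dependency-free sketch below reproduces that rule. `plan_batch_writes` is a hypothetical helper. Instead of calling `to_parquet`, it records a (rows_in_batch, append_flag) pair for each write.

```python
from typing import List, Tuple


def plan_batch_writes(num_bags: int, batch_size: int) -> List[Tuple[int, bool]]:
    """Hypothetical helper that simulates the batched write loop.

    Returns one (rows_in_batch, append_flag) pair per write, using the
    same rule as the script: append=(bag_count > len(batch_rows)).
    """
    writes = []
    batch_rows = []
    bag_count = 0
    for i in range(num_bags):
        batch_rows.append(i)  # stands in for one bag row
        bag_count += 1
        # Flush a full batch.
        if len(batch_rows) >= batch_size:
            writes.append((len(batch_rows), bag_count > len(batch_rows)))
            batch_rows = []
    # Flush any remaining rows.
    if batch_rows:
        writes.append((len(batch_rows), bag_count > len(batch_rows)))
    return writes


print(plan_batch_writes(num_bags=7, batch_size=3))
# [(3, False), (3, True), (1, True)]
print(plan_batch_writes(num_bags=2, batch_size=3))
# [(2, False)]
```

Only the first write creates the file, whether that write is a full batch or the final partial one. The rule assumes bag_count starts at zero for each file being written.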
def imbalance_adjustment(bag_file, split, configs, df):
    """
    Checks whether the label imbalance of the bags in bag_file (the
    difference between the label-0 and label-1 bag counts, as a percentage
    of all bags) is within configs.imbalance_cap.
    If not, generates additional bags for the minority label.

    Args:
        bag_file (str): Path to the Parquet file containing bags.
        split (str): The current split (e.g., 'train', 'val').
        configs (object): Configuration with imbalance_cap, by_sample, etc.
        df (pd.DataFrame): Original DataFrame for generating additional bags.
    """
    # Read the bag file and count bags per label
    bags_df = pd.read_parquet(bag_file)
    n0 = (bags_df["bag_label"] == 0).sum()
    n1 = (bags_df["bag_label"] == 1).sum()
    total = n0 + n1

    if total == 0:
        logging.warning("No bags found in %s for split %s", bag_file, split)
        return

    # Calculate imbalance as a percentage
    imbalance = abs(n0 - n1) / total * 100
    logging.info(
        "Split %s: %d bags (label 0: %d, label 1: %d), imbalance %.2f%%",
        split, total, n0, n1, imbalance
    )

    if imbalance > configs.imbalance_cap:
        # Identify minority label
        min_label = 0 if n0 < n1 else 1
        n_min = n0 if min_label == 0 else n1
        n_maj = n1 if min_label == 0 else n0

        # Calculate how many bags are needed to balance (aim for equality)
        num_needed = n_maj - n_min
        logging.info(
            "Imbalance %.2f%% exceeds cap %.2f%% in split %s, "
            "need %d bags for label %d",
            imbalance,
            configs.imbalance_cap,
            split,
            num_needed,
            min_label
        )

        # Generate additional bags based on the bag creation method
        if split in configs.by_sample:
            bag_by_sample(
                df,
                split,
                bag_file,
                configs,
                fixed_target_bags=(min_label, num_needed)
            )
        else:
            bag_in_turns(
                df,
                split,
                bag_file,
                configs,
                fixed_target_bags=(min_label, num_needed)
            )

        # Verify the new balance (optional, for logging)
        updated_bags_df = pd.read_parquet(bag_file)
        new_n0 = (updated_bags_df["bag_label"] == 0).sum()
        new_n1 = (updated_bags_df["bag_label"] == 1).sum()
        new_total = new_n0 + new_n1
        new_imbalance = abs(new_n0 - new_n1) / new_total * 100
        logging.info(
            "After adjustment, split %s: %d bags (label 0: %d, label 1: %d), "
            "imbalance %.2f%%",
            split,
            new_total,
            new_n0,
            new_n1,
            new_imbalance
        )
    else:
        logging.info(
            "Imbalance %.2f%% within cap %.2f%% for split %s, "
            "no adjustment needed",
            imbalance,
            configs.imbalance_cap,
            split
        )


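imbalance_adjustment measures imbalance as |n0 - n1| / (n0 + n1) * 100. When that value exceeds configs.imbalance_cap, it asks the bag generator for n_maj - n_min extra bags of the minority label. The sketch below reproduces only that arithmetic and does not generate any bags. `minority_top_up` is a hypothetical helper. It returns the pair the script passes as `fixed_target_bags`, or None when no adjustment would be made.

```python
from typing import Optional, Tuple


def minority_top_up(n0: int, n1: int, imbalance_cap: float) -> Optional[Tuple[int, int]]:
    """Hypothetical helper that mirrors the arithmetic in imbalance_adjustment.

    Returns (min_label, num_needed) when the percentage imbalance exceeds
    imbalance_cap, and None otherwise (including when there are no bags).
    """
    total = n0 + n1
    if total == 0:
        return None
    # Imbalance as a percentage of all bags.
    imbalance = abs(n0 - n1) / total * 100
    if imbalance <= imbalance_cap:
        return None
    min_label = 0 if n0 < n1 else 1
    n_min, n_maj = (n0, n1) if min_label == 0 else (n1, n0)
    # Aim for equality: add enough minority bags to match the majority.
    return min_label, n_maj - n_min


# 30 vs 70 bags -> 40% imbalance, above a 10% cap: label 0 needs 40 more bags.
print(minority_top_up(30, 70, imbalance_cap=10.0))  # (0, 40)
# 48 vs 52 bags -> 4% imbalance, within the cap.
print(minority_top_up(48, 52, imbalance_cap=10.0))  # None
```

The target is exact equality. If the generator produces all num_needed bags, the "After adjustment" log line should therefore report an imbalance close to 0%.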
def truncate_bag(bag_file, split):
    """
    Truncates the bags in bag_file to balance the counts of label 0 and
    label 1, ensuring that the file is never left empty (at least one bag
    remains).

    Args:
        bag_file (str): Path to the Parquet file containing the bags.
        split (str): The current split (e.g., 'train', 'val'), used for
            logging.

    Returns:
        None: Overwrites bag_file with the truncated bags, ensuring at
            least one bag remains.
    """
1109 logging.info("Truncating bags for split %s in file: %s", split, bag_file) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1110 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1111 # Step 1: Read the bag file to get the total number of bags |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1112 try: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1113 bags_df = pd.read_parquet(bag_file) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1114 except Exception as e: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1115 logging.error("Failed to read bag file %s: %s", bag_file, e) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1116 return |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1117 |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1118 total_bags = len(bags_df) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1119 if total_bags == 0: |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1120 logging.warning("No bags found in %s for split %s", bag_file, split) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1121 return |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1122 |
    # Step 2: Count bags with label 0 and label 1
    n0 = (bags_df["bag_label"] == 0).sum()
    n1 = (bags_df["bag_label"] == 1).sum()
    logging.info(
        "Split %s: Total bags %d (label 0: %d, label 1: %d)",
        split,
        total_bags,
        n0,
        n1
    )

    # Determine the minority count and majority label
    min_count = min(n0, n1)
    majority_label = 0 if n0 > n1 else 1

    if n0 == n1:
        logging.info(
            "Bags already balanced for split %s, no truncation needed",
            split
        )
        return

    # Step 3: Adjust min_count to ensure at least one bag remains
    if min_count == 0:
        logging.warning(
            "Minority label has 0 bags in split %s, keeping 1 bag from "
            "majority label %d to avoid empty file",
            split,
            majority_label
        )
        min_count = 1  # Ensure at least one bag is kept

    # Step 4: Truncate excess bags from the majority label
    logging.info(
        "Truncating %d bags from label %d to match %d bags per label",
        max(0, (n0 if majority_label == 0 else n1) - min_count),
        majority_label,
        min_count
    )

    # Shuffle the majority label bags to randomly select which to keep
    majority_bags = bags_df[
        bags_df["bag_label"] == majority_label
    ].sample(frac=1, random_state=None)

    minority_bags = bags_df[bags_df["bag_label"] != majority_label]

    # Keep only min_count bags from the majority label
    majority_bags_truncated = majority_bags.iloc[:min_count]

    # Combine the truncated majority and minority bags
    truncated_bags_df = pd.concat(
        [majority_bags_truncated,
         minority_bags],
        ignore_index=True
    )

    # Verify that the resulting DataFrame is not empty
    if len(truncated_bags_df) == 0:
        logging.error(
            "Unexpected empty DataFrame after truncation for split %s, "
            "this should not happen",
            split
        )
        return

    # Step 5: Overwrite the bag file with the truncated bags
    try:
        truncated_bags_df.to_parquet(
            bag_file,
            engine="fastparquet",
            index=False
        )
        logging.info(
            "Overwrote %s with %d balanced bags (label 0: %d, label 1: %d)",
            bag_file,
            len(truncated_bags_df),
            (truncated_bags_df["bag_label"] == 0).sum(),
            (truncated_bags_df["bag_label"] == 1).sum()
        )
    except Exception as e:
        logging.error("Failed to overwrite bag file %s: %s", bag_file, e)


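# Illustrative sketch (hypothetical helper, not part of the original tool):
# this is the downsampling rule that truncate_bag applies, written over a plain
# list of bag labels so it can be tried without pandas or Parquet files. The
# name _sketch_balance_by_truncation and the `seed` parameter are assumptions.
# truncate_bag itself shuffles with random_state=None, so the majority bags it
# keeps can differ between runs.
def _sketch_balance_by_truncation(labels, seed=None):
    """Return the sorted indices of the bags to keep.

    Follows the same steps as truncate_bag: leave empty or already balanced
    input untouched, keep every non-majority bag, and keep min(n0, n1)
    randomly chosen majority bags (at least one, so the result is never
    empty).
    """
    import random  # local import keeps this sketch self-contained

    n0 = sum(1 for lab in labels if lab == 0)
    n1 = sum(1 for lab in labels if lab == 1)
    if not labels or n0 == n1:
        return list(range(len(labels)))

    majority_label = 0 if n0 > n1 else 1
    keep = max(min(n0, n1), 1)
    majority = [i for i, lab in enumerate(labels) if lab == majority_label]
    minority = [i for i, lab in enumerate(labels) if lab != majority_label]

    # Shuffle, then keep the first `keep` majority bags (like .sample + .iloc).
    random.Random(seed).shuffle(majority)
    return sorted(majority[:keep] + minority)


# Usage: _sketch_balance_by_truncation([0, 0, 0, 0, 1, 1], seed=0) returns the
# two label-1 indices (4 and 5) plus two randomly chosen label-0 indices.

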
def columns_into_string(bag_file):
    """
    Reads the bag file (Parquet) from the given path, identifies
    the vector columns (i.e., columns not among 'sample_name', 'bag_label', 'split',
    and 'bag_size'), concatenates these vector values (as strings) into a single
    space-separated string stored in a new column "embeddings", drops the
    individual vector columns, and writes the modified DataFrame back to the
    same Parquet file.

    The final output format is:
    "sample_name", "bag_label", "split", "bag_size", "embeddings"
    where "embeddings" is a string like: "0.1 0.2 0.3"
    """
    logging.info(
        "Converting vector columns into string for bag file: %s",
        bag_file
    )

    try:
        df = pd.read_parquet(bag_file, engine="fastparquet")
    except Exception as e:
        logging.error("Error reading bag file %s: %s", bag_file, e)
        return

    # Define non-vector columns.
    non_vector = ["sample_name", "bag_label", "split", "bag_size"]

    # Identify vector columns.
    vector_columns = [col for col in df.columns if col not in non_vector]
    logging.info("Identified vector columns: %s", vector_columns)

    # Create new 'embeddings' column by converting vector columns to str
    # and joining them with single spaces (no quote characters are added).
    # Use apply() to ensure the result is a Series with one string per row.
    df["embeddings"] = df[vector_columns].astype(str).apply(
        lambda x: " ".join(x), axis=1
    )
    # Drop the original vector columns.
    df.drop(columns=vector_columns, inplace=True)

    try:
        # Write the modified DataFrame back to the same bag file.
        df.to_parquet(bag_file, engine="fastparquet", index=False)
        logging.info(
            "Conversion complete. Final columns: %s",
            df.columns.tolist()
        )
    except Exception as e:
        logging.error("Error writing updated bag file %s: %s", bag_file, e)


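# Illustrative sketch (hypothetical helper, not part of the original tool):
# this is the per-row transformation that columns_into_string performs,
# applied to a list of dicts instead of a Parquet-backed DataFrame so it runs
# without pandas. The helper name is an assumption, and so are the vector
# column names in the usage note (v0, v1, v2). The non-vector column names
# are the ones columns_into_string uses.
def _sketch_join_vector_columns(
    rows, non_vector=("sample_name", "bag_label", "split", "bag_size")
):
    """Return new rows whose vector columns become one "embeddings" string."""
    converted = []
    for row in rows:
        # Keep the metadata columns unchanged, in their original order.
        new_row = {k: v for k, v in row.items() if k in non_vector}
        # Stringify the remaining (vector) values and join them with spaces.
        new_row["embeddings"] = " ".join(
            str(v) for k, v in row.items() if k not in non_vector
        )
        converted.append(new_row)
    return converted


# Usage: a row {"sample_name": "s1", "bag_label": 1, "split": 0,
# "bag_size": 3, "v0": 0.1, "v1": 0.2, "v2": 0.3} becomes
# {"sample_name": "s1", "bag_label": 1, "split": 0, "bag_size": 3,
#  "embeddings": "0.1 0.2 0.3"}.

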
def processing_bag(configs, bag_file, temp_file, split):
    """
    Processes a single split and writes the bag results directly to the
    bag output Parquet file, then returns bag_file.

    Bags are built by sample (if the split is listed in configs.by_sample),
    in turns (if configs.balance_enforced is set), or randomly. The bag file
    is then adjusted with imbalance_adjustment (if configs.imbalance_cap is
    set) or truncate_bag (if configs.truncate_bags is set). Finally, it is
    converted with columns_into_string if configs.ludwig_format is set.
    """
    logging.info("Processing split %s using file: %s", split, temp_file)
    df = pd.read_parquet(temp_file, engine="fastparquet")

    if configs.by_sample is not None and split in configs.by_sample:
        bag_by_sample(df, split, bag_file, configs)
    elif configs.balance_enforced:
        bag_in_turns(df, split, bag_file, configs)
    else:
        bag_random(df, split, bag_file, configs)

    # Free df if imbalance_adjustment is not needed
    if configs.imbalance_cap is None:
        del df
        gc.collect()

    if configs.imbalance_cap is not None:
        imbalance_adjustment(bag_file, split, configs, df)
        del df
        gc.collect()
    elif configs.truncate_bags:
        truncate_bag(bag_file, split)

    if configs.ludwig_format:
        columns_into_string(bag_file)

    return bag_file


e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
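The branches at the end of `processing_bag` give `imbalance_cap` precedence. When it is set, `imbalance_adjustment` runs and `truncate_bags` is ignored, because truncation sits in the `elif` of the same test. Ludwig string formatting is applied afterwards in either case. The pure function below is a hypothetical sketch that only records which steps would run, which makes the precedence easy to check. The function name and step labels are illustrative and are not part of the script.

```python
def postprocessing_steps(imbalance_cap, truncate_bags, ludwig_format):
    """Return, in order, the post-processing steps one split's bag file would get.

    Hypothetical helper mirroring the if/elif/if structure of processing_bag.
    """
    steps = []
    if imbalance_cap is not None:
        # Any non-None cap (even 0) triggers imbalance adjustment...
        steps.append("imbalance_adjustment")
    elif truncate_bags:
        # ...so truncation only happens when no cap was given.
        steps.append("truncate_bag")
    if ludwig_format:
        # Column-to-string conversion is independent and always runs last.
        steps.append("columns_into_string")
    return steps
```

Note that the test is `is not None` rather than truthiness, so a cap of `0` still selects imbalance adjustment.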
def write_final_csv(output_csv, bag_file_paths):
    """
    Merges all Parquet files into a single CSV file,
    processing one file at a time to minimize memory usage.

    Args:
        output_csv (str): Path to the output CSV file specified
            in configs.output_csv.
        bag_file_paths (list): List of paths to the Parquet files
            for each split.

    Returns:
        str: Path to the output CSV file.

    Raises:
        ValueError: If no rows could be loaded from any Parquet file.
    """
    logging.info("Merging Parquet files into final CSV: %s", output_csv)

    first_file = True  # Flag to determine if we need to write the header
    total_rows_written = 0

    # Process each Parquet file sequentially
    for bag_file in bag_file_paths:
        try:
            # Skip empty or invalid files
            if os.path.getsize(bag_file) == 0:
                logging.warning(
                    "Parquet file %s is empty (zero size), skipping",
                    bag_file
                )
                continue

            # Load the Parquet file into a DataFrame
            df = pd.read_parquet(bag_file, engine="fastparquet")
            if df.empty:
                logging.warning("Parquet file %s is empty, skipping", bag_file)
                continue

            logging.info("Loaded %d rows from Parquet file: %s, columns: %s",
                         len(df), bag_file, list(df.columns))

            # Write the DataFrame to the CSV file
            # - For the first non-empty file, write with header (mode='w')
            # - For subsequent files, append without header (mode='a')
            mode = 'w' if first_file else 'a'
            header = first_file  # Write header only for the first file
            df.to_csv(output_csv, mode=mode, header=header, index=False)
            # Clear the flag right away so a later error in this iteration
            # cannot make the next file overwrite rows already written
            first_file = False
            total_rows_written += len(df)

            logging.info(
                "Wrote %d rows from %s to CSV, total rows written: %d",
                len(df), bag_file, total_rows_written
            )

            # Clear memory
            del df
            gc.collect()

        except Exception as e:
            logging.error("Failed to process Parquet file %s: %s", bag_file, e)
            continue

    # Check if any rows were written
    if total_rows_written == 0:
        logging.error(
            "No valid data loaded from Parquet files, cannot create CSV"
        )
        raise ValueError("No data available to write to CSV")

    logging.info(
        "Successfully wrote %d rows to final CSV: %s",
        total_rows_written,
        output_csv
    )
    return output_csv


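`write_final_csv` relies on one flag to write the CSV header exactly once. The first non-empty file truncates the output and writes the column names, and every later file is appended without them. The sketch below reproduces that pattern with only the standard library, so it runs without pandas or a Parquet engine. The in-memory `(header, rows)` tables are a hypothetical stand-in for the per-split Parquet files. The column-mismatch check is an extra guard that the original does not perform: with `to_csv(mode='a', header=False)`, a file whose columns are ordered differently would be appended silently and misaligned with the header.

```python
import csv


def merge_tables_to_csv(output_csv, tables):
    """Write several (header, rows) tables into one CSV, emitting the header once.

    Hypothetical sketch: `tables` stands in for the per-split Parquet files.
    Empty tables are skipped; returns the total number of data rows written.
    """
    header_written = False
    expected_header = None
    total_rows = 0
    for header, rows in tables:
        if not rows:
            continue  # mirrors the "empty, skipping" branch
        if expected_header is None:
            expected_header = list(header)
        elif list(header) != expected_header:
            # Extra guard (not in the original): refuse to append misaligned columns.
            raise ValueError(f"column mismatch: {header} != {expected_header}")
        # First non-empty table truncates the file ("w"); later ones append ("a").
        mode = "a" if header_written else "w"
        with open(output_csv, mode, newline="") as fh:
            writer = csv.writer(fh)
            if not header_written:
                writer.writerow(header)
                header_written = True
            writer.writerows(rows)
        total_rows += len(rows)
    if total_rows == 0:
        raise ValueError("No data available to write to CSV")
    return total_rows
```

Opening the file once per table, as here and in the original, keeps only one table in memory at a time at the cost of a few extra file opens.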
def process_splits(configs, embedding_files, bag_files):
    """
    Builds bags for every non-empty split in parallel, merges the
    resulting Parquet files into configs.output_csv, and returns the
    path to that CSV (or an empty list if all split files are empty).
    """
    splits = [0, 1, 2]  # Consistent with setup_temp_files()

    # Filter non-empty split files
    valid_info = []
    for split in splits:
        temp_file = embedding_files[split]
        bag_file = bag_files[split]
        if os.path.getsize(temp_file) > 0:  # Check if file has content
            valid_info.append((configs, bag_file, temp_file, split))
        else:
            logging.info("Skipping empty split file: %s", temp_file)

    if not valid_info:
        logging.warning("No non-empty split files to process")
        return []

    # Process splits in parallel and collect bag file paths. There is at
    # most one task per split, so never start more workers than tasks.
    n_workers = min(len(valid_info), mp.cpu_count())
    with mp.Pool(processes=n_workers) as pool:
        logging.info("Starting multiprocessing")
        bag_file_paths = pool.starmap(processing_bag, valid_info)
        logging.info("Multiprocessing is done")

    # Write the final CSV by merging the Parquet files
    output_file = write_final_csv(configs.output_csv, bag_file_paths)
    return output_file


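`process_splits` first drops zero-byte split files. It then hands the remaining `(configs, bag_file, temp_file, split)` tuples to `Pool.starmap`, which unpacks each tuple into the worker's positional arguments and returns the results in input order. A minimal sketch of the same filter-then-starmap pattern follows. `build_bag` is a hypothetical stand-in for `processing_bag`. The sketch uses `multiprocessing.pool.ThreadPool`, which has the same `starmap` API as `mp.Pool`, so it runs in any interpreter without pickling or start-method concerns. For CPU-bound bagging, a process pool such as `mp.Pool` remains the appropriate choice.

```python
import os
from multiprocessing.pool import ThreadPool


def build_bag(label, path, split):
    # Hypothetical stand-in for processing_bag(): just report what it received.
    return f"{label}:{split}:{os.path.basename(path)}"


def run_nonempty_splits(paths_by_split, label):
    """Run build_bag() on every non-empty split file; results come back in split order."""
    tasks = [
        (label, path, split)
        for split, path in sorted(paths_by_split.items())
        if os.path.getsize(path) > 0  # skip zero-byte split files
    ]
    if not tasks:
        return []
    # Never start more workers than there are tasks.
    workers = min(len(tasks), os.cpu_count() or 1)
    with ThreadPool(processes=workers) as pool:
        # starmap unpacks each tuple into build_bag's arguments and
        # preserves the order of `tasks` in the returned list.
        return pool.starmap(build_bag, tasks)
```

Capping the worker count is most useful under the `spawn` start method chosen in `__main__`, because each extra worker process starts a fresh interpreter and re-imports the module.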
def cleanup_temp_files(split_files, bag_outputs):
    """Cleans up temporary Parquet files."""
    for temp_file in split_files.values():
        try:
            os.remove(temp_file)
            logging.info("Cleaned up temp file: %s", temp_file)
        except Exception as e:
            logging.error("Error removing %s: %s", temp_file, e)
    for bag_output in bag_outputs.values():
        try:
            os.remove(bag_output)
            logging.info("Cleaned up temp bag file: %s", bag_output)
        except Exception as e:
            logging.error("Error removing %s: %s", bag_output, e)


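The argument parser that follows takes `--bag_size` as a string that is either a single size ('4') or a range ('3-5'). It takes `--split_proportions` as a comma-separated string such as the default '0.7,0.1,0.2'. The helpers below are a hypothetical sketch of how those strings could be validated and converted. The names `parse_bag_size` and `parse_split_proportions`, the inclusive reading of the range, and the "three values summing to 1" check are all assumptions, not the script's own parsing code.

```python
def parse_bag_size(text):
    """Hypothetical: turn '4' into (4, 4) and '3-5' into (3, 5) (inclusive bounds)."""
    parts = text.strip().split("-")
    if len(parts) == 1:
        low = high = int(parts[0])
    elif len(parts) == 2:
        low, high = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"invalid bag size: {text!r}")
    if low < 1 or low > high:
        raise ValueError(f"invalid bag size range: {text!r}")
    return low, high


def parse_split_proportions(text, tol=1e-6):
    """Hypothetical: turn '0.7,0.1,0.2' into [0.7, 0.1, 0.2] for train/val/test."""
    values = [float(p) for p in text.split(",")]
    if len(values) != 3 or any(v < 0 for v in values):
        raise ValueError(f"expected three non-negative proportions: {text!r}")
    # Compare with a tolerance: 0.7 + 0.1 + 0.2 is not exactly 1.0 in floating point.
    if abs(sum(values) - 1.0) > tol:
        raise ValueError(f"proportions must sum to 1: {text!r}")
    return values
```

Returning a `(low, high)` pair for both forms lets downstream code treat a fixed bag size as a degenerate range.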
if __name__ == "__main__":
    mp.set_start_method('spawn', force=True)
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    parser = argparse.ArgumentParser(
        description="Create bags from embeddings and metadata"
    )
    parser.add_argument(
        "--embeddings_csv", type=str, required=True,
        help="Path to embeddings CSV"
    )
    parser.add_argument(
        "--metadata_csv", type=str, required=True,
        help="Path to metadata CSV"
    )
    parser.add_argument(
        "--split_proportions", type=str, default='0.7,0.1,0.2',
        help="Comma-separated proportions for train, val, test splits"
    )
    parser.add_argument(
        "--dataleak", action="store_true",
        help="Prevent data leakage"
    )
    parser.add_argument(
        "--balance_enforced", action="store_true",
        help="Enforce balanced bagging"
    )
    parser.add_argument(
        "--bag_size", type=str, required=True,
        help="Bag size (e.g., '4' or '3-5')"
    )
    parser.add_argument(
        "--pooling_method", type=str, required=True,
        help="Pooling method"
    )
    parser.add_argument(
        "--by_sample", type=str, default=None,
changeset
|
1454 help="Splits to bag by sample" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1455 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1456 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1457 "--repeats", type=int, default=1, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1458 help="Number of bagging repeats" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1459 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1460 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1461 "--ludwig_format", action="store_true", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1462 help="Output in Ludwig format" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1463 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1464 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1465 "--output_csv", type=str, required=True, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1466 help="Path to output CSV" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1467 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1468 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1469 "--random_seed", type=int, default=42, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1470 help="Random seed" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1471 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1472 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1473 "--imbalance_cap", type=int, default=None, |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1474 help="Max imbalance percentage" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1475 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1476 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1477 "--truncate_bags", action="store_true", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1478 help="Truncate bags for balance" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1479 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1480 parser.add_argument( |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1481 "--use_gpu", action="store_true", |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1482 help="Use GPU for pooling" |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1483 ) |
e6e9ea0703ef
planemo upload for repository https://github.com/goeckslab/gleam.git commit 783551569c645073698fce50f1ed9c4605b3e65a
goeckslab
parents:
diff
changeset
|
1484 args = parser.parse_args() |

    config = BaggingConfig(args)
    logging.info("Starting bagging with args: %s", config)

    set_random_seed(config)

    metadata_csv = load_metadata(config.metadata_csv)
    if config.prevent_leakage:
        metadata_csv = split_dataset(metadata_csv, config)

    split_temp_files, split_bag_outputs = setup_temp_files()

    try:
        logging.info("Writing embeddings to split temp Parquet files")
        distribute_embeddings(config, metadata_csv, split_temp_files)

        logging.info("Processing embeddings for each split")
        bags = process_splits(config, split_temp_files, split_bag_outputs)
        logging.info("Bags processed. File generated: %s", bags)

    finally:
        cleanup_temp_files(split_temp_files, split_bag_outputs)
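The `--bag_size` option is declared with `type=str` because its help text allows either a fixed size (`'4'`) or a range (`'3-5'`). The sketch below shows one way such a spec could be parsed and a size drawn from it. It assumes the range is inclusive at both ends and that sizes are drawn uniformly; `parse_bag_size` and `draw_bag_size` are hypothetical helpers for illustration, not functions taken from this script.

```python
import random
from typing import Tuple


def parse_bag_size(spec: str) -> Tuple[int, int]:
    """Parse a bag-size spec such as '4' or '3-5' into (min, max).

    Hypothetical helper: assumes a range like '3-5' is inclusive.
    A single number yields identical bounds, e.g. '4' -> (4, 4).
    """
    spec = spec.strip()
    if "-" in spec:
        low_str, high_str = spec.split("-", 1)
        low, high = int(low_str), int(high_str)
    else:
        low = high = int(spec)
    # Reject empty or inverted ranges (e.g. '0' or '5-3').
    if low < 1 or high < low:
        raise ValueError(f"Invalid bag size spec: {spec!r}")
    return low, high


def draw_bag_size(bounds: Tuple[int, int], rng: random.Random) -> int:
    """Draw one bag size uniformly from the inclusive (min, max) range."""
    low, high = bounds
    return rng.randint(low, high)  # randint includes both endpoints


if __name__ == "__main__":
    rng = random.Random(42)  # same value as the script's --random_seed default
    print(parse_bag_size("4"))
    print(parse_bag_size("3-5"))
    print(draw_bag_size(parse_bag_size("3-5"), rng))
```

Keeping the argument as a string and parsing it after `parser.parse_args()` lets one option cover both the fixed-size and variable-size cases; a dedicated argparse `type=` callable would be an equally valid design.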