Mercurial > repos > bgruening > flexynesis
comparison flexynesis_plot.py @ 3:0a8fe19cebeb draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit b2463fb68d0ae54864d87718ee72f5e063aa4587
| author | bgruening | 
|---|---|
| date | Tue, 24 Jun 2025 05:56:31 +0000 | 
| parents | |
| children | 9450286c42ab | 
   comparison
  equal
  deleted
  inserted
  replaced
| 2:2134c3079055 | 3:0a8fe19cebeb | 
|---|---|
| 1 #!/usr/bin/env python | |
| 2 """Generate plots using flexynesis | |
| 3 This script generates dimensionality reduction plots, Kaplan-Meier survival curves, | |
| 4 and Cox proportional hazards models from data processed by flexynesis.""" | |
| 5 | |
| 6 import argparse | |
| 7 import os | |
| 8 from pathlib import Path | |
| 9 | |
| 10 import matplotlib.pyplot as plt | |
| 11 import numpy as np | |
| 12 import pandas as pd | |
| 13 import seaborn as sns | |
| 14 import torch | |
| 15 from flexynesis import ( | |
| 16 build_cox_model, | |
| 17 get_important_features, | |
| 18 plot_dim_reduced, | |
| 19 plot_hazard_ratios, | |
| 20 plot_kaplan_meier_curves, | |
| 21 plot_pr_curves, | |
| 22 plot_roc_curves, | |
| 23 plot_scatter | |
| 24 ) | |
| 25 from scipy.stats import kruskal, mannwhitneyu | |
| 26 | |
| 27 | |
def load_embeddings(embeddings_path):
    """Read a sample-embedding table and return it together with its row labels.

    The first column of the file becomes the index (the sample identifiers).
    ``.csv`` files are read comma-separated; ``.tsv``, ``.txt``, ``.tab`` and
    ``.tabular`` files are read tab-separated (extension match is
    case-insensitive).

    Returns:
        tuple: (DataFrame, list of the DataFrame's index values)

    Raises:
        ValueError: for an unsupported extension or any read error; the
            original exception is chained as ``__cause__``.
    """
    try:
        extension = Path(embeddings_path).suffix.lower()
        if extension in ('.tsv', '.txt', '.tab', '.tabular'):
            table = pd.read_csv(embeddings_path, sep='\t', index_col=0)
        elif extension == '.csv':
            table = pd.read_csv(embeddings_path, index_col=0)
        else:
            raise ValueError(f"Unsupported file extension: {extension}")
        sample_names = table.index.tolist()
        return table, sample_names
    except Exception as err:
        raise ValueError(f"Error loading embeddings from {embeddings_path}: {err}") from err
| 45 | |
| 46 | |
def load_labels(labels_input):
    """Load the long-format prediction table written by flexynesis.

    Parameters:
    - labels_input: path to a ``.csv`` file or a tab-separated file
      (``.tsv``, ``.txt``, ``.tab``, ``.tabular``; case-insensitive).

    Returns:
    - DataFrame containing at least the columns sample_id, variable,
      class_label, probability, known_label and predicted_label.

    Raises:
    - ValueError: for an unsupported extension, a read error, or missing
      required columns (the original exception is chained as ``__cause__``).
    """
    required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label']
    try:
        # Determine file extension
        file_ext = Path(labels_input).suffix.lower()

        if file_ext == '.csv':
            df = pd.read_csv(labels_input)
        elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']:
            df = pd.read_csv(labels_input, sep='\t')
        else:
            # Without this branch `df` is unbound and the user sees an
            # UnboundLocalError instead of the actual problem.
            raise ValueError(f"Unsupported file extension: {file_ext}")

        # The table must carry every column the plotting functions rely on
        missing = [col for col in required_cols if col not in df.columns]
        if missing:
            raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols} (missing: {missing})")
        return df

    except Exception as e:
        raise ValueError(f"Error loading labels from {labels_input}: {e}") from e
| 67 | |
| 68 | |
def load_survival_data(survival_path):
    """Read a survival table whose first column holds the sample identifiers.

    ``.csv`` files are read comma-separated; ``.tsv``/``.txt``/``.tab``/
    ``.tabular`` files are read tab-separated. The first column becomes the
    DataFrame index.

    Raises:
        ValueError: unsupported extension or read failure (cause chained).
    """
    try:
        ext = Path(survival_path).suffix.lower()
        if ext not in ('.csv', '.tsv', '.txt', '.tab', '.tabular'):
            raise ValueError(f"Unsupported file extension: {ext}")
        separator = ',' if ext == '.csv' else '\t'
        return pd.read_csv(survival_path, sep=separator, index_col=0)
    except Exception as exc:
        raise ValueError(f"Error loading survival data from {survival_path}: {exc}") from exc
| 85 | |
| 86 | |
def load_omics(omics_path):
    """Read an omics matrix whose first column holds the feature names.

    The delimiter is chosen from the (lower-cased) extension: comma for
    ``.csv``, tab for ``.tsv``, ``.txt``, ``.tab`` and ``.tabular``. The
    first column becomes the DataFrame index.

    Raises:
        ValueError: unsupported extension or read failure (cause chained).
    """
    try:
        suffix = Path(omics_path).suffix.lower()
        delimiter = {'.csv': ',', '.tsv': '\t', '.txt': '\t', '.tab': '\t', '.tabular': '\t'}.get(suffix)
        if delimiter is None:
            raise ValueError(f"Unsupported file extension: {suffix}")
        return pd.read_csv(omics_path, sep=delimiter, index_col=0)
    except Exception as exc:
        raise ValueError(f"Error loading omics data from {omics_path}: {exc}") from exc
| 103 | |
| 104 | |
def load_model(model_path):
    """Deserialize a flexynesis model written with ``torch.save``.

    NOTE(security): ``weights_only=False`` lets torch unpickle arbitrary
    Python objects, which can execute code during loading. Only load model
    files that come from a trusted source.

    Raises:
        ValueError: if the file cannot be opened or deserialized (cause chained).
    """
    try:
        with open(model_path, 'rb') as handle:
            return torch.load(handle, weights_only=False)
    except Exception as exc:
        raise ValueError(f"Error loading model from {model_path}: {exc}") from exc
| 113 | |
| 114 | |
def match_samples_to_embeddings(sample_names, label_data):
    """Return the rows of ``label_data`` whose 'sample_id' is in ``sample_names``.

    Row order and the original index are kept; ``label_data`` is not modified.
    """
    keep = label_data['sample_id'].isin(sample_names)
    return label_data[keep]
| 119 | |
| 120 | |
def detect_color_type(labels_series):
    """Guess whether labels should be colored as 'categorical' or 'numerical'.

    Missing values are ignored. The result is 'numerical' only when every
    remaining value parses as a number AND there are at least 10 distinct
    values making up at least 10% of the non-missing entries. Everything
    else, including an all-missing series, is 'categorical'.
    """
    values = labels_series.dropna()
    if values.empty:
        # default output if no labels
        return 'categorical'

    try:
        as_numbers = pd.to_numeric(values, errors='coerce')
        if as_numbers.isna().any():
            # At least one entry is not a number
            return 'categorical'
        n_distinct = len(values.unique())
        few_distinct = n_distinct < 10 or n_distinct / len(values) < 0.1
        return 'categorical' if few_distinct else 'numerical'
    except Exception:
        return 'categorical'
| 150 | |
| 151 | |
def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)):
    """
    Plot a heatmap of how labels in one set co-occur with labels in another,
    using pandas crosstab.

    Each row (one value of ``labels1``) is normalized to sum to 1. A cell
    therefore shows the fraction of samples with that ``labels1`` value that
    carry the given ``labels2`` value.

    Parameters:
    - labels1: The first set of labels (heatmap rows, 'Labels Set 1').
    - labels2: The second set of labels (heatmap columns, 'Labels Set 2'),
      paired element-wise with labels1.
    - figsize: Matplotlib figure size in inches.

    Returns:
    - The matplotlib Figure containing the heatmap.
    """
    # Compute the cross-tabulation
    ct = pd.crosstab(pd.Series(labels1, name='Labels Set 1'), pd.Series(labels2, name='Labels Set 2'))
    # Normalize row-wise: divide every row by its own total
    ct_normalized = ct.div(ct.sum(axis=1), axis=0)

    # Plot the heatmap on a fresh figure
    fig = plt.figure(figsize=figsize)
    sns.heatmap(ct_normalized, annot=True, cmap='viridis', linewidths=.5)
    plt.title('Concordance between label groups')

    return fig
| 171 | |
| 172 | |
def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize=4):
    """
    Create a boxplot with jittered points showing the distribution of a
    numeric value (e.g. predicted probabilities) across categories.

    The x axis shows the categories (e.g. true labels) and the y axis shows
    the numeric values. A p-value comparing the groups is annotated in the
    top-left corner:
    - Mann-Whitney U (two-sided) for exactly two groups.
    - Kruskal-Wallis for three or more groups.
    - With fewer than two groups no test is possible, and the annotation
      says so instead of raising.

    Parameters:
    - categorical_x: group label per observation.
    - numerical_y: numeric value per observation.
    - title_x, title_y: axis titles, also used as the DataFrame column names.
    - figsize: figure size in inches.
    - jittersize: marker size of the jittered points.

    Returns:
    - The matplotlib Figure.
    """
    df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y})

    # Compute p-value
    groups = df[title_x].unique()
    if len(groups) == 2:
        group1 = df[df[title_x] == groups[0]][title_y]
        group2 = df[df[title_x] == groups[1]][title_y]
        _, p = mannwhitneyu(group1, group2, alternative='two-sided')
        p_text = f'Mann-Whitney U p = {p:.3e}'
    elif len(groups) > 2:
        group_data = [df[df[title_x] == group][title_y] for group in groups]
        _, p = kruskal(*group_data)
        p_text = f'Kruskal-Wallis p = {p:.3e}'
    else:
        # kruskal() raises for a single group; there is nothing to compare
        p_text = 'p = n/a (fewer than two groups)'

    # Create a boxplot with jittered points
    plt.figure(figsize=figsize)
    sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill=False)
    sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4)

    # Labels and p-value annotation
    plt.xlabel(title_x)
    plt.ylabel(title_y)
    plt.text(
        x=-0.4,
        y=plt.ylim()[1],
        s=p_text,
        verticalalignment='top',
        horizontalalignment='left',
        fontsize=12,
        bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray')
    )

    plt.tight_layout()
    return plt.gcf()
| 212 | |
| 213 | |
def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base):
    """Generate known- and predicted-label dimensionality reduction plots.

    For every variable in ``args.target_variables`` (comma-separated), two
    plots are saved: one colored by 'known_label' and one colored by
    'predicted_label'. Both use ``args.method`` for the projection.

    Parameters:
    - embeddings: DataFrame of sample embeddings indexed by sample id.
    - matched_labels: long-format label table with 'sample_id', 'variable',
      'known_label' and 'predicted_label' columns.
    - args: parsed CLI arguments (target_variables, method, format, dpi).
    - output_dir: pathlib.Path where plots are written.
    - output_name_base: file name prefix.

    Raises:
    - ValueError: if none of the requested variables occur in the labels.
    """
    # Parse target variables
    target_vars = [var.strip() for var in args.target_variables.split(',')]

    print(f"Generating {args.method.upper()} plots for {len(target_vars)} target variable(s): {', '.join(target_vars)}")

    # Check variables
    available_vars = matched_labels['variable'].unique()
    missing_vars = [var for var in target_vars if var not in available_vars]

    if missing_vars:
        print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}")
        print(f"Available variables: {', '.join(available_vars)}")

    # Filter to only process available variables
    valid_vars = [var for var in target_vars if var in available_vars]

    if not valid_vars:
        raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}")

    n_plotted = 0
    for var in valid_vars:
        print(f"\nPlotting variable: {var}")

        # One row per sample, restricted to samples that have an embedding
        var_labels = matched_labels[matched_labels['variable'] == var]
        var_labels = var_labels.drop_duplicates(subset='sample_id')
        var_labels = var_labels[var_labels['sample_id'].isin(embeddings.index)]

        if var_labels.empty:
            print(f"Warning: No data found for variable '{var}', skipping...")
            continue

        # The labels file need not list the same samples, or list them in the
        # same order, as the embeddings file. Select the embedding rows in
        # label order and give the labels a matching positional index so that
        # row i of the matrix and label i refer to the same sample.
        var_embeddings = embeddings.loc[var_labels['sample_id']]
        known_labels = var_labels['known_label'].reset_index(drop=True)
        predicted_labels = var_labels['predicted_label'].reset_index(drop=True)

        # Auto-detect color type
        known_color_type = detect_color_type(known_labels)
        predicted_color_type = detect_color_type(predicted_labels)

        print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}")

        try:
            # Plot 1: Known labels
            print(f" Creating known labels plot for {var}...")
            fig_known = plot_dim_reduced(
                matrix=var_embeddings,
                labels=known_labels,
                method=args.method,
                color_type=known_color_type
            )

            output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}"
            print(f" Saving known labels plot to: {output_path_known.name}")
            fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

            # Plot 2: Predicted labels
            print(f" Creating predicted labels plot for {var}...")
            fig_predicted = plot_dim_reduced(
                matrix=var_embeddings,
                labels=predicted_labels,
                method=args.method,
                color_type=predicted_color_type
            )

            output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}"
            print(f" Saving predicted labels plot to: {output_path_predicted.name}")
            fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight')

            n_plotted += 1
            print(f" ✓ Successfully created plots for variable '{var}'")

        except Exception as e:
            print(f" ✗ Error creating plots for variable '{var}': {e}")
            continue

    # Report how many variables actually produced plots, not how many were requested
    print(f"\nDimensionality reduction plots completed for {n_plotted} of {len(valid_vars)} variable(s)!")
| 288 | |
| 289 | |
def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base):
    """Generate a Kaplan-Meier plot of model-derived risk subtypes.

    Samples are split at the median of the predicted probability of the event
    class (``args.event_value``) for ``args.surv_event_var``. Samples at or
    above the median form 'high_risk'; the rest form 'low_risk'.

    Parameters:
    - survival_data: DataFrame indexed by sample id, containing the
      ``args.surv_time_var`` and ``args.surv_event_var`` columns.
    - label_data: long-format flexynesis predictions (sample_id, variable,
      class_label, probability, ...). It is not modified.
    - args: parsed CLI arguments (surv_time_var, surv_event_var, event_value,
      format, dpi).
    - output_dir: pathlib.Path where the plot is written.
    - output_name_base: file name prefix.

    Raises:
    - ValueError: if a survival column is missing or no samples overlap.
    """
    print("Generating Kaplan-Meier curves of risk subtypes...")

    event_value_str = str(args.event_value)
    event_value_num = pd.to_numeric(event_value_str, errors='coerce')

    def _matches_event_value(values):
        """Boolean mask of entries equal to event_value (string or numeric match)."""
        mask = values.astype(str) == event_value_str
        if pd.notna(event_value_num):
            # Columns parsed as float (e.g. 1.0 because the column contains
            # missing values) never equal '1' as strings, so also compare
            # numerically.
            mask |= pd.to_numeric(values, errors='coerce') == event_value_num
        return mask

    # Reset index and rename the index column to sample_id
    survival_data = survival_data.reset_index()
    if survival_data.columns[0] != 'sample_id':
        survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'})

    # Validate the survival columns before using them
    for col in [args.surv_event_var, args.surv_time_var]:
        if col not in survival_data.columns:
            raise ValueError(f"Column '{col}' not found in survival data")

    # Binary event indicator: 1 where the event column matches event_value
    event_binary_col = f'{args.surv_event_var}_binary'
    survival_data[event_binary_col] = _matches_event_value(survival_data[args.surv_event_var]).astype(int)

    # Keep the predicted probability of the event class for the event
    # variable. Boolean indexing yields a new frame, so the caller's
    # label_data is left untouched.
    event_probs = label_data[
        (label_data['variable'] == args.surv_event_var)
        & _matches_event_value(label_data['class_label'])
    ]

    # Merge survival data with labels
    df_merged = pd.merge(survival_data, event_probs, on='sample_id', how='inner')

    if df_merged.empty:
        raise ValueError("No matching samples found after merging survival and label data.")

    # Risk score = predicted probability of the event class
    risk_scores = df_merged['probability'].values

    # Median split: digitize returns 1 for scores >= median
    quantiles = np.quantile(risk_scores, [0.5])
    groups = np.digitize(risk_scores, quantiles)
    group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]

    fig_known = plot_kaplan_meier_curves(
        durations=df_merged[args.surv_time_var],
        events=df_merged[event_binary_col],
        categorical_variable=group_labels
    )

    output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}"
    print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}")
    fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight')

    print("Kaplan-Meier plot saved successfully!")
| 348 | |
| 349 | |
def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base):
    """Fit a Cox proportional hazards model and save its hazard-ratio plot.

    The model combines two sets of covariates:
    - the clinical variables in ``args.clinical_variables`` (the survival
      time and event columns are always added);
    - the ``args.top_features`` most important model features for
      ``args.surv_event_var``, taken from the train and test omics matrices.

    Parameters:
    - model: trained flexynesis model passed to ``get_important_features``.
    - clinical_train, clinical_test: clinical tables indexed by sample id.
    - omics_train, omics_test: omics tables with features as rows and
      samples as columns.
    - args: parsed CLI arguments (clinical_variables, surv_time_var,
      surv_event_var, event_value, top_features, crossval, n_splits,
      random_state, format, dpi).
    - output_dir: pathlib.Path where the plot is written.
    - output_name_base: file name prefix.

    Raises:
    - ValueError: for missing clinical variables, no usable samples, or any
      failure during feature extraction, model fitting or plotting. For the
      latter, the underlying exception is chained as ``__cause__``.
    """
    print("Generating Cox proportional hazards analysis...")

    # Parse clinical variables
    clinical_vars = [var.strip() for var in args.clinical_variables.split(',')]

    # Validate that survival variables are included
    for var in (args.surv_time_var, args.surv_event_var):
        if var not in clinical_vars:
            clinical_vars.append(var)

    print(f"Using clinical variables: {', '.join(clinical_vars)}")

    # filter datasets for clinical variables
    if not all(var in clinical_train.columns and var in clinical_test.columns for var in clinical_vars):
        raise ValueError(f"Not all clinical variables found in datasets. Available in train dataset: {clinical_train.columns.tolist()}, Available in test dataset: {clinical_test.columns.tolist()}")

    # Drop rows with NaN in clinical variables, then combine train and test
    df_clin_train = clinical_train[clinical_vars].dropna(subset=clinical_vars)
    df_clin_test = clinical_test[clinical_vars].dropna(subset=clinical_vars)
    df_clin = pd.concat([df_clin_train, df_clin_test], axis=0)

    # Get top survival markers
    print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...")
    try:
        imp = get_important_features(model,
                                     var=args.surv_event_var,
                                     top=args.top_features
                                     )['name'].unique().tolist()
        print(f"Top features: {', '.join(imp)}")
    except Exception as e:
        raise ValueError(f"Error getting important features: {e}") from e

    # Extract feature data from omics datasets
    try:
        omics_test = omics_test.loc[omics_test.index.isin(imp)]
        omics_train = omics_train.loc[omics_train.index.isin(imp)]
        # Drop features (rows) with NaN in omics datasets
        omics_test = omics_test.dropna(subset=omics_test.columns)
        omics_train = omics_train.dropna(subset=omics_train.columns)

        # join='inner' keeps only features present in both sets. An outer join
        # would add NaN-filled feature columns for features dropped from just
        # one set.
        df_imp = pd.concat([omics_train, omics_test], axis=1, join='inner')
        df_imp = df_imp.T  # Transpose to have samples as rows

        print(f"Feature data shape: {df_imp.shape}")
    except Exception as e:
        raise ValueError(f"Error extracting feature subset: {e}") from e

    # Combine markers with clinical variables
    df = pd.merge(df_imp, df_clin, left_index=True, right_index=True)
    print(f"Combined data shape: {df.shape}")

    # Remove samples without survival endpoints
    initial_samples = len(df)
    df = df[df[args.surv_event_var].notna()]
    final_samples = len(df)
    print(f"Removed {initial_samples - final_samples} samples without survival data")

    if df.empty:
        raise ValueError("No samples remain after filtering for survival data")

    # Convert survival event column to binary (0/1) based on event_value.
    # The string comparison handles coded values such as '1:DECEASED'. The
    # numeric comparison handles columns parsed as float (1.0), which never
    # equal '1' as strings.
    event_value_str = str(args.event_value)
    event_values = df[args.surv_event_var]
    is_event = event_values.astype(str) == event_value_str
    event_value_num = pd.to_numeric(event_value_str, errors='coerce')
    if pd.notna(event_value_num):
        is_event |= pd.to_numeric(event_values, errors='coerce') == event_value_num
    df[args.surv_event_var] = is_event.astype(int)

    # Build Cox model
    print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}")
    try:
        coxm = build_cox_model(df,
                               duration_col=args.surv_time_var,
                               event_col=args.surv_event_var,
                               crossval=args.crossval,
                               n_splits=args.n_splits,
                               random_state=args.random_state)
        print("Cox model built successfully")
    except Exception as e:
        raise ValueError(f"Error building Cox model: {e}") from e

    # Generate hazard ratios plot
    try:
        print("Generating hazard ratios plot...")
        fig = plot_hazard_ratios(coxm)

        output_path = output_dir / f"{output_name_base}_hazard_ratios.{args.format}"
        print(f"Saving hazard ratios plot to: {output_path.absolute()}")
        fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        print("Cox proportional hazards analysis completed successfully!")

    except Exception as e:
        raise ValueError(f"Error generating hazard ratios plot: {e}") from e
| 452 | |
| 453 | |
def generate_plot_scatter(labels, args, output_dir, output_name_base):
    """Save one known-vs-predicted scatter plot per (regression) target.

    Parameters:
    - labels: long-format flexynesis predictions with 'variable',
      'known_label' and 'predicted_label' columns.
    - args: parsed CLI arguments; uses target_value (comma-separated
      variable names, or empty for every variable), format and dpi.
    - output_dir: pathlib.Path where plots are written.
    - output_name_base: file name prefix; the variable name is appended
      when more than one target is processed.

    Raises:
    - ValueError: if not a single scatter plot could be produced.
    """
    print("Generating scatter plots of known vs predicted labels...")

    # An empty target_value means "every variable present in the labels"
    if not args.target_value:
        target_values = labels['variable'].unique().tolist()
    else:
        target_values = [val.strip() for val in args.target_value.split(',')]

    print(f"Processing target values: {target_values}")

    several_targets = len(target_values) > 1
    n_ok = 0
    n_skipped = 0

    for target in target_values:
        print(f"\nProcessing target value: '{target}'")

        subset = labels[labels['variable'] == target]
        if subset.empty:
            print(f" Warning: No data found for target value '{target}' - skipping")
            n_skipped += 1
            continue

        # Non-numeric entries become NaN; a column with nothing numeric is unusable
        known = pd.to_numeric(subset['known_label'], errors='coerce')
        predicted = pd.to_numeric(subset['predicted_label'], errors='coerce')
        if known.isna().all() or predicted.isna().all():
            print(f"No valid numeric values found for known or predicted labels in '{target}'")
            n_skipped += 1
            continue

        try:
            print(f" Generating scatter plot for '{target}'...")
            fig = plot_scatter(known, predicted)

            # Make the variable name safe for use in a file name
            safe_name = target.replace('/', '_').replace('\\', '_').replace(' ', '_')
            suffix = f"_{safe_name}" if several_targets else ""
            output_path = output_dir / f"{output_name_base}{suffix}.{args.format}"

            print(f" Saving scatter plot to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

            n_ok += 1
            print(f" Scatter plot for '{target}' generated successfully!")
        except Exception as e:
            print(f" Error generating plot for '{target}': {str(e)}")
            n_skipped += 1

    # Summary
    print(" Summary:")
    print(f" Successfully generated: {n_ok} plots")
    print(f" Skipped: {n_skipped} plots")

    if not n_ok:
        raise ValueError("No scatter plots could be generated. Check your data and target values.")

    print("Scatter plot generation completed!")
| 521 | |
| 522 | |
def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base):
    """Save one known-vs-predicted label concordance heatmap per target.

    Parameters:
    - labels: long-format flexynesis predictions with 'sample_id',
      'variable', 'known_label' and 'predicted_label' columns.
    - args: parsed CLI arguments; uses target_value (comma-separated
      variable names, or empty for every variable), format and dpi.
    - output_dir: pathlib.Path where heatmaps are written.
    - output_name_base: file name prefix; the variable name is appended
      when more than one target is processed.

    Errors for individual targets are printed and that target is skipped.
    """
    print("Generating label concordance heatmaps...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        # The table has one row per (sample, class); count each sample once
        target_labels = target_labels.drop_duplicates(subset='sample_id')

        true_values = target_labels['known_label'].tolist()
        predicted_values = target_labels['predicted_label'].tolist()

        fig = None
        try:
            print(f" Generating heatmap for '{target_value}'...")
            fig = plot_label_concordance_heatmap(true_values, predicted_values)

            # Create output filename with target value
            safe_target_name = str(target_value).replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving heatmap to: {output_path.absolute()}")
            fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f" Error generating heatmap for '{target_value}': {str(e)}")
        finally:
            # Release the figure only after it has been written (or failed)
            if fig is not None:
                plt.close(fig)

    print("Label concordance heatmap generated successfully!")
| 570 | |
| 571 | |
def generate_pr_curves(labels, args, output_dir, output_name_base):
    """Generate one precision-recall plot per classification target.

    Parameters:
    - labels: long-format flexynesis predictions with one row per
      (sample_id, class_label) and columns variable, probability and
      known_label.
    - args: parsed CLI arguments; uses target_value (comma-separated
      variable names, or empty for every variable), format and dpi.
    - output_dir: pathlib.Path where plots are written.
    - output_name_base: file name prefix; the variable name is appended
      when more than one target is processed.

    Targets whose probabilities are (almost) all missing are treated as
    regression outputs and skipped. Problems with an individual target are
    printed and that target is skipped.
    """
    print("Generating precision-recall curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f" Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f" Class labels found: {list(prob_columns)}")
        print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print(" Detected regression problem - precision-recall curves not applicable")
            print(f" Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f" Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f" Sample IDs with missing known_label: {list(missing_samples)}")

        # Remove rows with missing known_label
        target_labels = target_labels.dropna(subset=['known_label'])
        if target_labels.empty:
            print(f" Error: No valid known_label data remaining for '{target_value}' - skipping")
            continue

        # 1. Pivot to wide format: one row per sample, one column per class
        try:
            prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')
        except ValueError as e:
            # pivot() raises on duplicate (sample_id, class_label) entries.
            # Skip this target instead of aborting all remaining targets.
            print(f" Error: Could not pivot probabilities for '{target_value}' ({e}) - skipping")
            continue

        print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f" Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f" NaN counts per class: {dict(nan_counts)}")
            print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

        # Drop only rows where ALL probabilities are NaN
        all_nan_rows = prob_df.isna().all(axis=1)
        if all_nan_rows.any():
            print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
            prob_df = prob_df[~all_nan_rows]

        remaining_nans = prob_df.isna().sum().sum()
        if remaining_nans > 0:
            print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0")
            prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f" Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels (one per sample)
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f" Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels; the integer is the
        # index of the class's probability column.
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f" Class mapping: {class_to_int}")

        y_true_mapped = y_true.map(class_to_int)
        if y_true_mapped.isna().any():
            # pandas parses each column independently, so class_label may be
            # int (0/1) while known_label is str ('0'/'1') because that column
            # also holds other variables' labels. Fall back to matching on the
            # string form.
            str_class_to_int = {str(name): i for name, i in class_to_int.items()}
            y_true_mapped = y_true_mapped.fillna(y_true.astype(str).map(str_class_to_int))

        y_true_np = y_true_mapped.to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f" Unique true labels (integers): {set(y_true_np)}")
        print(f" Class labels (columns): {class_names}")
        print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print(" Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true_mapped.isna()])
            print(f" Unmapped labels: {unmapped_labels}")
            print(f" Available classes: {class_names}")
            continue

        # Every label is mapped; the fallback may have produced floats
        y_true_np = y_true_np.astype(int)

        try:
            print(f" Generating precision-recall curve for '{target_value}'...")
            fig = plot_pr_curves(y_true_np, y_probs_np)

            # Create output filename with target value
            safe_target_name = str(target_value).replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f" Saving precision-recall curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')

        except Exception as e:
            print(f" Error generating precision-recall curve for '{target_value}': {str(e)}")
            print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")
            continue

    print("Precision-recall curves generated successfully!")
| 722 | |
| 723 | |
def generate_roc_curves(labels, args, output_dir, output_name_base):
    """Generate ROC curves, one figure per classification target variable.

    Args:
        labels: Long-format labels DataFrame with columns ``variable``,
            ``sample_id``, ``class_label``, ``probability`` and ``known_label``.
            For each target it is pivoted to one row per sample and one
            column per class.
        args: Parsed CLI namespace. Reads ``target_value`` (comma-separated
            variable names; falsy means every variable in ``labels``),
            ``format`` and ``dpi``.
        output_dir: ``pathlib.Path`` directory the figures are saved to.
        output_name_base: Filename stem; the sanitised target name is appended
            when more than one target is processed.

    Targets that are absent, look like regression (fewer than 10% non-NaN
    probabilities) or cannot be aligned are skipped with a message. Errors
    raised while plotting/saving are reported per target and do not stop the
    loop. Errors from ``DataFrame.pivot`` (e.g. ValueError on duplicate
    sample/class rows) propagate to the caller.
    """
    print("Generating ROC curves...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    n_saved = 0
    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        # Without this check an unknown target falls through to a misleading
        # "no valid probability data" error after the pivot.
        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a regression problem (no class probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - ROC curves not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

        # Remove rows with missing known_label
        target_labels = target_labels.dropna(subset=['known_label'])
        if target_labels.empty:
            print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
            continue

        # 1. Pivot to wide format: rows = samples, columns = classes
        prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability')

        print(f"  After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes")
        print(f"  Class columns: {list(prob_df.columns)}")

        # Check for NaN values in probability data
        nan_counts = prob_df.isna().sum()
        if nan_counts.any():
            print(f"  NaN counts per class: {dict(nan_counts)}")
            print(f"  Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}")

            # Drop only rows where ALL probabilities are NaN
            all_nan_rows = prob_df.isna().all(axis=1)
            if all_nan_rows.any():
                print(f"  Dropping {all_nan_rows.sum()} samples with all NaN probabilities")
                prob_df = prob_df[~all_nan_rows]

            remaining_nans = prob_df.isna().sum().sum()
            if remaining_nans > 0:
                print(f"  Warning: {remaining_nans} individual NaN values remain - filling with 0")
                prob_df = prob_df.fillna(0)

        if prob_df.empty:
            print(f"  Error: No valid probability data remaining for '{target_value}' - skipping")
            continue

        # 2. Get true labels (first row per sample)
        true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id')

        # 3. Align indices - only keep samples that exist in both datasets
        common_indices = prob_df.index.intersection(true_labels_df.index)
        if len(common_indices) == 0:
            print(f"  Error: No common sample_ids between probability and true label data for '{target_value}' - skipping")
            continue

        print(f"  Found {len(common_indices)} samples with both probability and true label data")

        # Filter both datasets to common indices
        prob_df_aligned = prob_df.loc[common_indices]
        y_true = true_labels_df.loc[common_indices]['known_label']

        # 4. Final check for NaN values
        if y_true.isna().any():
            print(f"  Error: True labels still contain NaN after alignment for '{target_value}' - skipping")
            continue

        if prob_df_aligned.isna().any().any():
            print(f"  Error: Probability data still contains NaN after alignment for '{target_value}' - skipping")
            continue

        # 5. Convert categorical labels to integer labels; the integer is the
        # column position of the class in the probability matrix.
        class_names = list(prob_df_aligned.columns)
        class_to_int = {class_name: i for i, class_name in enumerate(class_names)}

        print(f"  Class mapping: {class_to_int}")

        # Labels with no matching probability column map to NaN
        y_true_np = y_true.map(class_to_int).to_numpy()
        y_probs_np = prob_df_aligned.to_numpy()

        # .tolist() gives plain Python scalars so the logs stay readable under numpy 2
        uniq_labels, label_counts = np.unique(y_true_np, return_counts=True)
        print(f"  Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}")
        print(f"  Unique true labels (integers): {set(y_true_np.tolist())}")
        print(f"  Class labels (columns): {class_names}")
        print(f"  Label distribution: {dict(zip(uniq_labels.tolist(), label_counts.tolist()))}")

        # Check for any unmapped labels (will be NaN)
        if pd.isna(y_true_np).any():
            print("  Error: Some true labels could not be mapped to class columns")
            unmapped_labels = set(y_true[y_true.map(class_to_int).isna()])
            print(f"  Unmapped labels: {unmapped_labels}")
            print(f"  Available classes: {class_names}")
            continue

        try:
            print(f"  Generating ROC curve for '{target_value}'...")
            fig = plot_roc_curves(y_true_np, y_probs_np)

            # str() guards against non-string variable names (e.g. numeric ids)
            safe_target_name = str(target_value).replace('/', '_').replace('\\', '_').replace(' ', '_')
            if len(target_values) > 1:
                output_filename = f"{output_name_base}_{safe_target_name}.{args.format}"
            else:
                output_filename = f"{output_name_base}.{args.format}"

            output_path = output_dir / output_filename
            print(f"  Saving ROC curve to: {output_path.absolute()}")
            fig.save(output_path, dpi=args.dpi, bbox_inches='tight')
            n_saved += 1

        except Exception as e:
            # Best-effort per target: report and move on to the next one
            print(f"  Error generating ROC curve for '{target_value}': {str(e)}")
            print(f"  Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}")
            print(f"  Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}")

    # Only claim success when at least one curve was actually written
    if n_saved:
        print("ROC curves generated successfully!")
    else:
        print("No ROC curves were generated.")
| 874 | |
| 875 | |
def generate_box_plots(labels, args, output_dir, output_name_base):
    """Generate box plots of predicted class probabilities grouped by true label.

    One figure is written per (target variable, class label) pair. The x axis
    holds ``known_label`` and the y axis holds the predicted probability of
    that class.

    Args:
        labels: Long-format labels DataFrame with columns ``variable``,
            ``sample_id``, ``class_label``, ``probability`` and ``known_label``.
        args: Parsed CLI namespace. Reads ``target_value`` (comma-separated
            variable names; falsy means every variable in ``labels``),
            ``format`` and ``dpi``.
        output_dir: ``pathlib.Path`` directory the figures are saved to.
        output_name_base: Filename stem; target and class names are appended.

    Targets that are absent or look like regression (fewer than 10% non-NaN
    probabilities) are skipped. Plotting/saving errors are reported per class
    and do not stop the loop.
    """
    print("Generating box plots...")

    # Parse target values from comma-separated string
    if args.target_value:
        target_values = [val.strip() for val in args.target_value.split(',')]
    else:
        # If no target values specified, use all unique variables
        target_values = labels['variable'].unique().tolist()

    print(f"Processing target values: {target_values}")

    n_saved = 0
    for target_value in target_values:
        print(f"\nProcessing target value: '{target_value}'")

        # Filter labels for the current target value
        target_labels = labels[labels['variable'] == target_value]

        if target_labels.empty:
            print(f"  Warning: No data found for target value '{target_value}' - skipping")
            continue

        # Check if this is a classification problem (has probabilities)
        prob_columns = target_labels['class_label'].unique()
        non_na_probs = target_labels['probability'].notna().sum()

        print(f"  Class labels found: {list(prob_columns)}")
        print(f"  Non-NaN probabilities: {non_na_probs}/{len(target_labels)}")

        # If most probabilities are NaN, this is likely a regression problem
        if non_na_probs < len(target_labels) * 0.1:  # Less than 10% valid probabilities
            print("  Detected regression problem - box plots not applicable")
            print(f"  Skipping '{target_value}' (use regression evaluation metrics instead)")
            continue

        # Debug: Check data quality
        total_rows = len(target_labels)
        missing_labels = target_labels['known_label'].isna().sum()
        missing_probs = target_labels['probability'].isna().sum()
        unique_samples = target_labels['sample_id'].nunique()
        unique_classes = target_labels['class_label'].nunique()

        print(f"  Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes")
        print(f"  Missing data: {missing_labels} missing known_label, {missing_probs} missing probability")

        if missing_labels > 0:
            print(f"  Warning: Found {missing_labels} missing known_label values")
            missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5]
            print(f"  Sample IDs with missing known_label: {list(missing_samples)}")

        # Remove rows with missing known_label
        target_labels = target_labels.dropna(subset=['known_label'])
        if target_labels.empty:
            print(f"  Error: No valid known_label data remaining for '{target_value}' - skipping")
            continue

        # known_label NaNs are already gone; drop rows without a probability
        clean_data = target_labels.dropna(subset=['probability'])

        if clean_data.empty:
            print("  No valid data after cleaning - skipping")
            continue

        # Loop-invariant for this target; str() guards against non-string names
        safe_target_name = str(target_value).replace('/', '_').replace('\\', '_').replace(' ', '_')

        for class_label in clean_data['class_label'].unique():
            print(f"  Generating box plot for class: {class_label}")

            # Filter for current class
            class_data = clean_data[clean_data['class_label'] == class_label]

            fig = None
            try:
                # Create the box plot
                fig = plot_boxplot(
                    categorical_x=class_data['known_label'],
                    numerical_y=class_data['probability'],
                    title_x='True Label',
                    title_y=f'Predicted Probability ({class_label})',
                )

                # Save the plot
                safe_class_name = str(class_label).replace('/', '_').replace('\\', '_').replace(' ', '_').replace(':', '_')
                output_filename = f"{output_name_base}_{safe_target_name}_{safe_class_name}.{args.format}"
                output_path = output_dir / output_filename

                print(f"  Saving box plot to: {output_path.absolute()}")
                fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight')
                n_saved += 1

            except Exception as e:
                print(f"  Error generating box plot for class '{class_label}': {str(e)}")

            finally:
                # Close even when savefig fails, so figures don't accumulate across the loop
                if fig is not None:
                    plt.close(fig)

    if n_saved:
        print("Box plots generated successfully!")
    else:
        print("No box plots were generated.")
| 972 | |
| 973 | |
def main():
    """Parse arguments, validate them for the chosen plot type and generate the plot.

    Returns:
        int: 0 on success; 1 if validation or loading raised FileNotFoundError,
        ValueError or pandas ParserError (the message is printed). Any other
        exception propagates.
    """
    plot_types = ['dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot']
    # Plot types that only need --labels -> wording used in "--labels is required for ..." errors
    labels_only_plots = {
        'scatter': 'scatter plots',
        'concordance_heatmap': 'concordance heatmap',
        'pr_curve': 'precision-recall curves',
        'roc_curve': 'ROC curves',
        'box_plot': 'box plots',
    }
    # Suffix appended to the labels-file stem when --output_name is not given
    labels_only_suffixes = {
        'scatter': 'scatter',
        'concordance_heatmap': 'concordance',
        'pr_curve': 'pr_curves',
        'roc_curve': 'roc_curves',
        'box_plot': 'box_plot',
    }

    parser = argparse.ArgumentParser(description="Generate plots using flexynesis")

    # Labels file: needed by every plot type except 'cox' (enforced during validation)
    parser.add_argument("--labels", type=str, required=False,
                        help="Path to labels file generated by flexynesis")

    # Plot type
    parser.add_argument("--plot_type", type=str, required=True,
                        choices=plot_types,
                        help="Type of plot to generate: 'dimred' for dimensionality reduction, 'kaplan_meier' for survival analysis, 'cox' for Cox proportional hazards analysis, 'scatter' for scatter plots, 'concordance_heatmap' for label concordance heatmaps, 'pr_curve' for precision-recall curves, 'roc_curve' for ROC curves, or 'box_plot' for box plots.")

    # Arguments for dimensionality reduction
    parser.add_argument("--embeddings", type=str,
                        help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.")
    parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'],
                        help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.")
    parser.add_argument("--target_variables", type=str, required=False,
                        help="Comma-separated list of target variables to plot.")

    # Arguments for Kaplan-Meier
    parser.add_argument("--survival_data", type=str,
                        help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.")
    parser.add_argument("--surv_time_var", type=str, required=False,
                        help="Column name for survival time")
    parser.add_argument("--surv_event_var", type=str, required=False,
                        help="Column name for survival event")
    parser.add_argument("--event_value", type=str, required=False,
                        help="Value in event column that represents an event (e.g., 'DECEASED')")

    # Arguments for Cox analysis
    parser.add_argument("--model", type=str,
                        help="Path to trained flexynesis model (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_train", type=str,
                        help="Path to training dataset (pickle file). Required for cox plots.")
    parser.add_argument("--clinical_test", type=str,
                        help="Path to test dataset (pickle file). Required for cox plots.")
    parser.add_argument("--omics_train", type=str, default=None,
                        help="Path to training omics dataset. Required for cox plots.")
    parser.add_argument("--omics_test", type=str, default=None,
                        help="Path to test omics dataset. Required for cox plots.")
    parser.add_argument("--clinical_variables", type=str,
                        help="Comma-separated list of clinical variables to include in Cox model (e.g., 'AGE,SEX,HISTOLOGICAL_DIAGNOSIS,STUDY')")
    parser.add_argument("--top_features", type=int, default=20,
                        help="Number of top important features to include in Cox model. Default is 20")
    parser.add_argument("--crossval", action='store_true',
                        help="If True, performs K-fold cross-validation and returns average C-index. Default is False")
    parser.add_argument("--n_splits", type=int, default=5,
                        help="Number of folds for cross-validation. Default is 5")
    parser.add_argument("--random_state", type=int, default=42,
                        help="Random seed for reproducibility. Default is 42")

    # Arguments for scatter plot, heatmap, PR curves, ROC curves, and box plots
    parser.add_argument("--target_value", type=str, default=None,
                        help="Comma-separated target variable(s) for scatter, concordance_heatmap, pr_curve, roc_curve and box_plot plots. Defaults to all variables in the labels file.")

    # Common arguments
    parser.add_argument("--output_dir", type=str, default='output',
                        help="Output directory. Default is 'output'")
    parser.add_argument("--output_name", type=str, default=None,
                        help="Output filename base")
    parser.add_argument("--format", type=str, default='jpg', choices=['png', 'pdf', 'svg', 'jpg'],
                        help="Output format for the plot. Default is 'jpg'")
    parser.add_argument("--dpi", type=int, default=300,
                        help="DPI for the output image. Default is 300")

    args = parser.parse_args()

    try:
        # validate plot type (argparse `choices` already enforces this; kept as a safeguard)
        if not args.plot_type:
            raise ValueError("Please specify a plot type using --plot_type")
        if args.plot_type not in plot_types:
            raise ValueError(f"Invalid plot type: {args.plot_type}. Must be one of: 'dimred', 'kaplan_meier', 'cox', 'scatter', 'concordance_heatmap', 'pr_curve', 'roc_curve', 'box_plot'")

        # Validate plot type requirements
        if args.plot_type == 'dimred':
            if not args.embeddings:
                raise ValueError("--embeddings is required when plot_type is 'dimred'")
            if not os.path.isfile(args.embeddings):
                raise FileNotFoundError(f"embeddings file not found: {args.embeddings}")
            if not args.labels:
                raise ValueError("--labels is required for dimensionality reduction plots")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")
            if not args.method:
                raise ValueError("--method is required for dimensionality reduction plots")
            if not args.target_variables:
                raise ValueError("--target_variables is required for dimensionality reduction plots")

        elif args.plot_type == 'kaplan_meier':
            if not args.survival_data:
                raise ValueError("--survival_data is required when plot_type is 'kaplan_meier'")
            if not os.path.isfile(args.survival_data):
                raise FileNotFoundError(f"Survival data file not found: {args.survival_data}")
            if not args.labels:
                raise ValueError("--labels is required for Kaplan-Meier plots")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Kaplan-Meier plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Kaplan-Meier plots")
            if not args.event_value:
                raise ValueError("--event_value is required for Kaplan-Meier plots")

        elif args.plot_type == 'cox':
            if not args.model:
                raise ValueError("--model is required when plot_type is 'cox'")
            if not os.path.isfile(args.model):
                raise FileNotFoundError(f"Model file not found: {args.model}")
            if not args.clinical_train:
                raise ValueError("--clinical_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_train):
                raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}")
            if not args.clinical_test:
                raise ValueError("--clinical_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.clinical_test):
                raise FileNotFoundError(f"Test dataset file not found: {args.clinical_test}")
            if not args.omics_train:
                raise ValueError("--omics_train is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_train):
                raise FileNotFoundError(f"Training omics dataset file not found: {args.omics_train}")
            if not args.omics_test:
                raise ValueError("--omics_test is required when plot_type is 'cox'")
            if not os.path.isfile(args.omics_test):
                raise FileNotFoundError(f"Test omics dataset file not found: {args.omics_test}")
            if not args.surv_time_var:
                raise ValueError("--surv_time_var is required for Cox plots")
            if not args.surv_event_var:
                raise ValueError("--surv_event_var is required for Cox plots")
            if not args.clinical_variables:
                raise ValueError("--clinical_variables is required for Cox plots")
            # argparse `type=int` already guarantees ints; only the range needs checking
            if args.top_features <= 0:
                raise ValueError("--top_features must be a positive integer")
            if not args.event_value:
                raise ValueError("--event_value is required for Cox plots")
            if args.n_splits <= 0:
                raise ValueError("--n_splits must be a positive integer")

        else:
            # Plot types driven only by the labels file
            description = labels_only_plots[args.plot_type]
            if not args.labels:
                raise ValueError(f"--labels is required for {description}")
            if not args.target_value:
                print("--target_value is not specified, using all unique variables from labels")
            if not os.path.isfile(args.labels):
                raise FileNotFoundError(f"Labels file not found: {args.labels}")

        # Validate other arguments
        if args.method not in ['pca', 'umap']:
            raise ValueError("Method must be 'pca' or 'umap'")

        # Create output directory
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Output directory: {output_dir.absolute()}")

        # Generate output filename base
        if args.output_name:
            output_name_base = args.output_name
        elif args.plot_type == 'dimred':
            output_name_base = f"{Path(args.embeddings).stem}_{args.method}"
        elif args.plot_type == 'kaplan_meier':
            output_name_base = f"{Path(args.survival_data).stem}_km"
        elif args.plot_type == 'cox':
            output_name_base = f"{Path(args.model).stem}_cox"
        else:
            output_name_base = f"{Path(args.labels).stem}_{labels_only_suffixes[args.plot_type]}"

        # Generate plots based on type
        if args.plot_type == 'dimred':
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load embeddings data
            print(f"Loading embeddings from: {args.embeddings}")
            embeddings, sample_names = load_embeddings(args.embeddings)
            print(f"embeddings shape: {embeddings.shape}")

            # Match samples to embeddings
            matched_labels = match_samples_to_embeddings(sample_names, label_data)
            print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction")

            generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base)

        elif args.plot_type == 'kaplan_meier':
            # Load labels
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)
            # Load survival data
            print(f"Loading survival data from: {args.survival_data}")
            survival_data = load_survival_data(args.survival_data)
            print(f"Survival data shape: {survival_data.shape}")

            generate_km_plots(survival_data, label_data, args, output_dir, output_name_base)

        elif args.plot_type == 'cox':
            # Load model and datasets
            print(f"Loading model from: {args.model}")
            model = load_model(args.model)
            print(f"Loading training dataset from: {args.clinical_train}")
            clinical_train = load_omics(args.clinical_train)
            print(f"Loading test dataset from: {args.clinical_test}")
            clinical_test = load_omics(args.clinical_test)
            print(f"Loading training omics dataset from: {args.omics_train}")
            omics_train = load_omics(args.omics_train)
            print(f"Loading test omics dataset from: {args.omics_test}")
            omics_test = load_omics(args.omics_test)

            generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base)

        else:
            # Labels-only plot types share the same loading step and generator signature
            print(f"Loading labels from: {args.labels}")
            label_data = load_labels(args.labels)

            generators = {
                'scatter': generate_plot_scatter,
                'concordance_heatmap': generate_label_concordance_heatmap,
                'pr_curve': generate_pr_curves,
                'roc_curve': generate_roc_curves,
                'box_plot': generate_box_plots,
            }
            generators[args.plot_type](label_data, args, output_dir, output_name_base)

        print("All plots generated successfully!")

    except (FileNotFoundError, ValueError, pd.errors.ParserError) as e:
        print(f"Error: {e}")
        return 1

    return 0
| 1279 | |
| 1280 | |
if __name__ == "__main__":
    # sys.exit (not the site-injected `exit`, which is meant for the
    # interactive shell and is absent under `python -S`) propagates main()'s
    # return code as the process exit status.
    import sys

    sys.exit(main())
