comparison flexynesis_plot.py @ 8:9c91d13827ef draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flexynesis commit 6b520305ec30e6dc37eba92c67a5368cea0fc5ad
| field | value |
|---|---|
| author | bgruening |
| date | Wed, 23 Jul 2025 07:50:31 +0000 |
| parents | 9450286c42ab |
| children | e0a67265f3ed |
| 7:9450286c42ab | 8:9c91d13827ef |
|---|---|
| 9 | 9 |
| 10 import matplotlib.pyplot as plt | 10 import matplotlib.pyplot as plt |
| 11 import numpy as np | 11 import numpy as np |
| 12 import pandas as pd | 12 import pandas as pd |
| 13 import seaborn as sns | 13 import seaborn as sns |
| 14 import torch | |
| 15 from flexynesis import ( | 14 from flexynesis import ( |
| 16 build_cox_model, | 15 build_cox_model, |
| 17 get_important_features, | |
| 18 plot_dim_reduced, | 16 plot_dim_reduced, |
| 19 plot_hazard_ratios, | 17 plot_hazard_ratios, |
| 20 plot_kaplan_meier_curves, | 18 plot_kaplan_meier_curves, |
| 21 plot_pr_curves, | 19 plot_pr_curves, |
| 22 plot_roc_curves, | 20 plot_roc_curves, |
| 53 if file_ext == '.csv': | 51 if file_ext == '.csv': |
| 54 df = pd.read_csv(labels_input) | 52 df = pd.read_csv(labels_input) |
| 55 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: | 53 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: |
| 56 df = pd.read_csv(labels_input, sep='\t') | 54 df = pd.read_csv(labels_input, sep='\t') |
| 57 | 55 |
| 58 # Check if this is the specific format with sample_id, known_label, predicted_label | 56 print(f"Available columns: {df.columns.tolist()}") |
| 59 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | 57 return df |
| 60 if all(col in df.columns for col in required_cols): | |
| 61 return df | |
| 62 else: | |
| 63 raise ValueError(f"Labels file {labels_input} does not contain required columns: {required_cols}") | |
| 64 | 58 |
| 65 except Exception as e: | 59 except Exception as e: |
| 66 raise ValueError(f"Error loading labels from {labels_input}: {e}") from e | 60 raise ValueError(f"Error loading labels from {labels_input}: {e}") from e |
| 67 | |
| 68 | |
| 69 def load_survival_data(survival_path): | |
| 70 """Load survival data from a file. First column should be sample_id""" | |
| 71 try: | |
| 72 # Determine file extension | |
| 73 file_ext = Path(survival_path).suffix.lower() | |
| 74 | |
| 75 if file_ext == '.csv': | |
| 76 df = pd.read_csv(survival_path, index_col=0) | |
| 77 elif file_ext in ['.tsv', '.txt', '.tab', '.tabular']: | |
| 78 df = pd.read_csv(survival_path, sep='\t', index_col=0) | |
| 79 else: | |
| 80 raise ValueError(f"Unsupported file extension: {file_ext}") | |
| 81 return df | |
| 82 | |
| 83 except Exception as e: | |
| 84 raise ValueError(f"Error loading survival data from {survival_path}: {e}") from e | |
| 85 | 61 |
| 86 | 62 |
| 87 def load_omics(omics_path): | 63 def load_omics(omics_path): |
| 88 """Load omics data from a file. First column should be features""" | 64 """Load omics data from a file. First column should be features""" |
| 89 try: | 65 try: |
| 100 | 76 |
| 101 except Exception as e: | 77 except Exception as e: |
| 102 raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e | 78 raise ValueError(f"Error loading omics data from {omics_path}: {e}") from e |
| 103 | 79 |
| 104 | 80 |
| 105 def load_model(model_path): | 81 def match_samples_to_embeddings(sample_names, labels): |
| 106 """Load flexynesis model from pickle file""" | |
| 107 try: | |
| 108 with open(model_path, 'rb') as f: | |
| 109 model = torch.load(f, weights_only=False) | |
| 110 return model | |
| 111 except Exception as e: | |
| 112 raise ValueError(f"Error loading model from {model_path}: {e}") from e | |
| 113 | |
| 114 | |
| 115 def match_samples_to_embeddings(sample_names, label_data): | |
| 116 """Filter label data to match sample names in the embeddings""" | 82 """Filter label data to match sample names in the embeddings""" |
| 117 df_matched = label_data[label_data['sample_id'].isin(sample_names)] | 83 # Create a DataFrame from sample_names to preserve order |
| 84 sample_df = pd.DataFrame({'sample_names': sample_names}) | |
| 85 | |
| 86 # Left join: align labels to the embedding sample order | 
| 87 first_column = labels.columns[0] | |
| 88 df_matched = sample_df.merge(labels, left_on='sample_names', right_on=first_column, how='left') | |
| 89 | |
| 90 # remove sample_names to keep the initial structure | |
| 91 df_matched = df_matched.drop('sample_names', axis=1) | |
| 118 return df_matched | 92 return df_matched |
| 119 | 93 |
| 120 | 94 |
| 121 def detect_color_type(labels_series): | 95 def detect_color_type(labels_series): |
| 122 """Auto-detect whether target variables should be treated as categorical or numerical""" | 96 """Auto-detect whether target variables should be treated as categorical or numerical""" |
| 212 | 186 |
| 213 | 187 |
| 214 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): | 188 def generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base): |
| 215 """Generate dimensionality reduction plots""" | 189 """Generate dimensionality reduction plots""" |
| 216 | 190 |
| 217 # Parse target values from comma-separated string | 191 # Check if this is the specific format with sample_id, known_label, predicted_label |
| 218 if args.target_value: | 192 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] |
| 219 target_values = [val.strip() for val in args.target_value.split(',')] | 193 is_flexynesis_format = all(col in matched_labels.columns for col in required_cols) |
| 194 | |
| 195 if not args.color: | |
| 196 if is_flexynesis_format: | |
| 197 print("Detected flexynesis labels format") | |
| 198 print(f"Generating {args.method.upper()} plots for known and predicted labels...") | |
| 199 else: | |
| 200 print("Labels are not in flexynesis format (Custom labels), please specify a color variable with --color") | |
| 201 | |
| 202 # Parse target values from comma-separated string | |
| 203 if args.target_value: | |
| 204 target_values = [val.strip() for val in args.target_value.split(',')] | |
| 205 else: | |
| 206 # If no target values specified, use all unique variables | |
| 207 target_values = matched_labels['variable'].unique().tolist() | |
| 208 | |
| 209 print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}") | |
| 210 | |
| 211 # Check variables | |
| 212 available_vars = matched_labels['variable'].unique() | |
| 213 missing_vars = [var for var in target_values if var not in available_vars] | |
| 214 | |
| 215 if missing_vars: | |
| 216 print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") | |
| 217 print(f"Available variables: {', '.join(available_vars)}") | |
| 218 | |
| 219 # Filter to only process available variables | |
| 220 valid_vars = [var for var in target_values if var in available_vars] | |
| 221 | |
| 222 if not valid_vars: | |
| 223 raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}") | |
| 224 | |
| 225 # Generate plots for each valid target variable | |
| 226 for var in valid_vars: | |
| 227 print(f"\nPlotting variable: {var}") | |
| 228 | |
| 229 # Filter matched labels for current variable | |
| 230 var_labels = matched_labels[matched_labels['variable'] == var].copy() | |
| 231 var_labels = var_labels.drop_duplicates(subset='sample_id') | |
| 232 | |
| 233 if var_labels.empty: | |
| 234 print(f"Warning: No data found for variable '{var}', skipping...") | |
| 235 continue | |
| 236 | |
| 237 # Auto-detect color type | |
| 238 known_color_type = detect_color_type(var_labels['known_label']) | |
| 239 predicted_color_type = detect_color_type(var_labels['predicted_label']) | |
| 240 | |
| 241 print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") | |
| 242 | |
| 243 try: | |
| 244 # Plot 1: Known labels | |
| 245 print(f" Creating known labels plot for {var}...") | |
| 246 fig_known = plot_dim_reduced( | |
| 247 matrix=embeddings, | |
| 248 labels=var_labels['known_label'], | |
| 249 method=args.method, | |
| 250 color_type=known_color_type | |
| 251 ) | |
| 252 | |
| 253 output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" | |
| 254 print(f" Saving known labels plot to: {output_path_known.name}") | |
| 255 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') | |
| 256 | |
| 257 # Plot 2: Predicted labels | |
| 258 print(f" Creating predicted labels plot for {var}...") | |
| 259 fig_predicted = plot_dim_reduced( | |
| 260 matrix=embeddings, | |
| 261 labels=var_labels['predicted_label'], | |
| 262 method=args.method, | |
| 263 color_type=predicted_color_type | |
| 264 ) | |
| 265 | |
| 266 output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}" | |
| 267 print(f" Saving predicted labels plot to: {output_path_predicted.name}") | |
| 268 fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight') | |
| 269 | |
| 270 print(f" ✓ Successfully created plots for variable '{var}'") | |
| 271 | |
| 272 except Exception as e: | |
| 273 print(f" ✗ Error creating plots for variable '{var}': {e}") | |
| 274 continue | |
| 275 | |
| 276 print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!") | |
| 277 | |
| 220 else: | 278 else: |
| 221 # If no target values specified, use all unique variables | 279 # check if the color variable exists in matched_labels |
| 222 target_values = matched_labels['variable'].unique().tolist() | 280 if args.color not in matched_labels.columns: |
| 223 | 281 raise ValueError(f"Color variable '{args.color}' not found in matched labels. Available columns: {matched_labels.columns.tolist()}") |
| 224 print(f"Generating {args.method.upper()} plots for {len(target_values)} target variable(s): {', '.join(target_values)}") | |
| 225 | |
| 226 # Check variables | |
| 227 available_vars = matched_labels['variable'].unique() | |
| 228 missing_vars = [var for var in target_values if var not in available_vars] | |
| 229 | |
| 230 if missing_vars: | |
| 231 print(f"Warning: The following target variables were not found in the data: {', '.join(missing_vars)}") | |
| 232 print(f"Available variables: {', '.join(available_vars)}") | |
| 233 | |
| 234 # Filter to only process available variables | |
| 235 valid_vars = [var for var in target_values if var in available_vars] | |
| 236 | |
| 237 if not valid_vars: | |
| 238 raise ValueError(f"None of the specified target variables were found in the data. Available: {', '.join(available_vars)}") | |
| 239 | |
| 240 # Generate plots for each valid target variable | |
| 241 for var in valid_vars: | |
| 242 print(f"\nPlotting variable: {var}") | |
| 243 | |
| 244 # Filter matched labels for current variable | |
| 245 var_labels = matched_labels[matched_labels['variable'] == var].copy() | |
| 246 var_labels = var_labels.drop_duplicates(subset='sample_id') | |
| 247 | |
| 248 if var_labels.empty: | |
| 249 print(f"Warning: No data found for variable '{var}', skipping...") | |
| 250 continue | |
| 251 | 282 |
| 252 # Auto-detect color type | 283 # Auto-detect color type |
| 253 known_color_type = detect_color_type(var_labels['known_label']) | 284 color_type = detect_color_type(matched_labels[args.color]) |
| 254 predicted_color_type = detect_color_type(var_labels['predicted_label']) | 285 |
| 255 | 286 print(f" Auto-detected color type: {color_type}") |
| 256 print(f" Auto-detected color types - Known: {known_color_type}, Predicted: {predicted_color_type}") | 287 |
| 257 | 288 # Plot: Specified color column |
| 258 try: | 289 print(f" Creating plot for {args.color}...") |
| 259 # Plot 1: Known labels | 290 fig = plot_dim_reduced( |
| 260 print(f" Creating known labels plot for {var}...") | 291 matrix=embeddings, |
| 261 fig_known = plot_dim_reduced( | 292 labels=matched_labels[args.color], |
| 262 matrix=embeddings, | 293 method=args.method, |
| 263 labels=var_labels['known_label'], | 294 color_type=color_type |
| 264 method=args.method, | 295 ) |
| 265 color_type=known_color_type | 296 |
| 266 ) | 297 output_path = output_dir / f"{output_name_base}_{args.color}.{args.format}" |
| 267 | 298 print(f" Saving plot to: {output_path.name}") |
| 268 output_path_known = output_dir / f"{output_name_base}_{var}_known.{args.format}" | 299 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') |
| 269 print(f" Saving known labels plot to: {output_path_known.name}") | 300 |
| 270 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') | 301 print(f" ✓ Successfully created plot for variable '{args.color}'") |
| 271 | 302 |
| 272 # Plot 2: Predicted labels | 303 |
| 273 print(f" Creating predicted labels plot for {var}...") | 304 def generate_km_plots(survival_data, labels, args, output_dir, output_name_base): |
| 274 fig_predicted = plot_dim_reduced( | |
| 275 matrix=embeddings, | |
| 276 labels=var_labels['predicted_label'], | |
| 277 method=args.method, | |
| 278 color_type=predicted_color_type | |
| 279 ) | |
| 280 | |
| 281 output_path_predicted = output_dir / f"{output_name_base}_{var}_predicted.{args.format}" | |
| 282 print(f" Saving predicted labels plot to: {output_path_predicted.name}") | |
| 283 fig_predicted.save(output_path_predicted, dpi=args.dpi, bbox_inches='tight') | |
| 284 | |
| 285 print(f" ✓ Successfully created plots for variable '{var}'") | |
| 286 | |
| 287 except Exception as e: | |
| 288 print(f" ✗ Error creating plots for variable '{var}': {e}") | |
| 289 continue | |
| 290 | |
| 291 print(f"\nDimensionality reduction plots completed for {len(valid_vars)} variable(s)!") | |
| 292 | |
| 293 | |
| 294 def generate_km_plots(survival_data, label_data, args, output_dir, output_name_base): | |
| 295 """Generate Kaplan-Meier plots""" | 305 """Generate Kaplan-Meier plots""" |
| 306 | |
| 307 # Check if this is the specific format with sample_id, known_label, predicted_label | |
| 308 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
| 309 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
| 310 | |
| 311 if not is_flexynesis_format: | |
| 312 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
| 313 | |
| 296 print("Generating Kaplan-Meier curves of risk subtypes...") | 314 print("Generating Kaplan-Meier curves of risk subtypes...") |
| 297 | 315 |
| 298 # Reset index and rename the index column to sample_id | |
| 299 survival_data = survival_data.reset_index() | |
| 300 if survival_data.columns[0] != 'sample_id': | 316 if survival_data.columns[0] != 'sample_id': |
| 301 survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'}) | 317 survival_data = survival_data.rename(columns={survival_data.columns[0]: 'sample_id'}) |
| 302 | 318 |
| 303 # Convert survival event column to binary (0/1) based on event_value | |
| 304 # Check if the event column exists | 319 # Check if the event column exists |
| 305 if args.surv_event_var not in survival_data.columns: | 320 if args.surv_event_var not in survival_data.columns: |
| 306 raise ValueError(f"Column '{args.surv_event_var}' not found in survival data") | 321 raise ValueError(f"Column '{args.surv_event_var}' not found in survival data") |
| 307 | 322 |
| 308 # Convert to string for comparison to handle mixed types | 323 labels = labels[(labels['variable'] == args.surv_event_var)] |
| 309 survival_data[args.surv_event_var] = survival_data[args.surv_event_var].astype(str) | |
| 310 event_value_str = str(args.event_value) | |
| 311 | |
| 312 # Create binary event column (1 if matches event_value, 0 otherwise) | |
| 313 survival_data[f'{args.surv_event_var}_binary'] = ( | |
| 314 survival_data[args.surv_event_var] == event_value_str | |
| 315 ).astype(int) | |
| 316 | |
| 317 # Filter for survival category and class_label == '1:DECEASED' | |
| 318 label_data['class_label'] = label_data['class_label'].astype(str) | |
| 319 | |
| 320 label_data = label_data[(label_data['variable'] == args.surv_event_var) & (label_data['class_label'] == event_value_str)] | |
| 321 | |
| 322 # check survival data | |
| 323 for col in [args.surv_time_var, args.surv_event_var]: | |
| 324 if col not in survival_data.columns: | |
| 325 raise ValueError(f"Column '{col}' not found in survival data") | |
| 326 | 324 |
| 327 # Merge survival data with labels | 325 # Merge survival data with labels |
| 328 df_deceased = pd.merge(survival_data, label_data, on='sample_id', how='inner') | 326 df_deceased = pd.merge(survival_data, labels, on='sample_id', how='inner') |
| 327 df_deceased = df_deceased.dropna(subset=[args.surv_time_var, args.surv_event_var]) | |
| 329 | 328 |
| 330 if df_deceased.empty: | 329 if df_deceased.empty: |
| 331 raise ValueError("No matching samples found after merging survival and label data.") | 330 raise ValueError("No matching samples found after merging survival and label data.") |
| 332 | 331 |
| 333 # Get risk scores | 332 # Get risk scores |
| 334 risk_scores = df_deceased['probability'].values | 333 risk_scores = df_deceased['predicted_label'].values |
| 335 | 334 |
| 336 # Compute groups (e.g., median split) | 335 # Compute groups (e.g., median split) |
| 337 quantiles = np.quantile(risk_scores, [0.5]) | 336 quantiles = np.quantile(risk_scores, [0.5]) |
| 338 groups = np.digitize(risk_scores, quantiles) | 337 groups = np.digitize(risk_scores, quantiles) |
| 339 group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups] | 338 group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups] |
| 340 | 339 |
| 341 fig_known = plot_kaplan_meier_curves( | 340 fig_known = plot_kaplan_meier_curves( |
| 342 durations=df_deceased[args.surv_time_var], | 341 durations=df_deceased[args.surv_time_var], |
| 343 events=df_deceased[f'{args.surv_event_var}_binary'], | 342 events=df_deceased[args.surv_event_var], |
| 344 categorical_variable=group_labels | 343 categorical_variable=group_labels |
| 345 ) | 344 ) |
| 346 | 345 |
| 347 output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}" | 346 output_path_known = output_dir / f"{output_name_base}_km_risk_subtypes.{args.format}" |
| 348 print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}") | 347 print(f"Saving Kaplan-Meier plot to: {output_path_known.absolute()}") |
| 349 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') | 348 fig_known.save(output_path_known, dpi=args.dpi, bbox_inches='tight') |
| 350 | 349 |
| 351 print("Kaplan-Meier plot saved successfully!") | 350 print("Kaplan-Meier plot saved successfully!") |
| 352 | 351 |
| 353 | 352 |
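
In the new revision, `generate_km_plots` derives risk groups from a median split of the `predicted_label` scores instead of the event-matched probabilities. A standalone sketch of the split, with made-up risk scores:

```python
import numpy as np

# Median split of risk scores into low/high groups, as in generate_km_plots.
risk_scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7])
quantiles = np.quantile(risk_scores, [0.5])   # median -> array([0.4])
groups = np.digitize(risk_scores, quantiles)  # 0 below the median, 1 at or above
group_labels = ['low_risk' if g == 0 else 'high_risk' for g in groups]
print(group_labels)
# ['low_risk', 'high_risk', 'low_risk', 'high_risk', 'high_risk']
```
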
| 354 def generate_cox_plots(model, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): | 353 def generate_cox_plots(important_features, clinical_train, clinical_test, omics_train, omics_test, args, output_dir, output_name_base): |
| 355 """Generate Cox proportional hazards plots""" | 354 """Generate Cox proportional hazards plots""" |
| 356 print("Generating Cox proportional hazards analysis...") | 355 print("Generating Cox proportional hazards analysis...") |
| 357 | 356 |
| 357 # Check if this is the specific format with target_variable, layer, importance | 
| 358 required_cols = ['target_variable', 'layer', 'importance'] | |
| 359 is_flexynesis_format = all(col in important_features.columns for col in required_cols) | |
| 360 | |
| 361 if not is_flexynesis_format: | |
| 362 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid important_features file with the required columns, {required_cols}.") | |
| 363 | |
| 358 # Parse clinical variables | 364 # Parse clinical variables |
| 359 clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] | 365 clinical_vars = [] |
| 366 if args.clinical_variables: | |
| 367 clinical_vars = [var.strip() for var in args.clinical_variables.split(',')] | |
| 360 | 368 |
| 361 # Validate that survival variables are included | 369 # Validate that survival variables are included |
| 362 required_vars = [args.surv_time_var, args.surv_event_var] | 370 required_vars = [args.surv_time_var, args.surv_event_var] |
| 363 for var in required_vars: | 371 for var in required_vars: |
| 364 if var not in clinical_vars: | 372 if var not in clinical_vars: |
| 380 df_clin = pd.concat([df_clin_train, df_clin_test], axis=0) | 388 df_clin = pd.concat([df_clin_train, df_clin_test], axis=0) |
| 381 | 389 |
| 382 # Get top survival markers | 390 # Get top survival markers |
| 383 print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...") | 391 print(f"Extracting top {args.top_features} important features for {args.surv_event_var}...") |
| 384 try: | 392 try: |
| 385 imp = get_important_features(model, | 393 print(f"Loading {args.top_features} important features from: {args.important_features}") |
| 386 var=args.surv_event_var, | 394 imp_features = load_labels(args.important_features) |
| 387 top=args.top_features | 395 imp_features = imp_features[imp_features['target_variable'] == args.surv_event_var] |
| 388 )['name'].unique().tolist() | 396 if args.layer not in imp_features['layer'].unique(): |
| 397 print(f"Available class labels: {imp_features['layer'].unique()}") | |
| 398 raise ValueError(f"Class label '{args.layer}' not found in important features data: {args.important_features}") | |
| 399 imp_features = imp_features[imp_features['layer'] == args.layer] | |
| 400 if imp_features.empty: | |
| 401 raise ValueError(f"No important features found for target variable '{args.surv_event_var}' in {args.important_features}") | |
| 402 imp_features = imp_features.sort_values(by='importance', ascending=False) | |
| 403 | |
| 404 if len(imp_features) < args.top_features: | |
| 405 raise ValueError(f"Requested top {args.top_features} features, but only {len(imp_features)} available in {args.important_features}") | |
| 406 | |
| 407 imp = imp_features['name'].unique().tolist()[0:args.top_features] | |
| 408 | |
| 389 print(f"Top features: {', '.join(imp)}") | 409 print(f"Top features: {', '.join(imp)}") |
| 390 except Exception as e: | 410 except Exception as e: |
| 391 raise ValueError(f"Error getting important features: {e}") | 411 raise ValueError(f"Error getting important features: {e}") |
| 392 | 412 |
| 393 # Extract feature data from omics datasets | 413 # Extract feature data from omics datasets |
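
The Cox workflow no longer loads a pickled model and calls `get_important_features`; it reads the exported importance table, filters by target variable and layer, sorts by importance, and keeps the top N feature names. A sketch of that selection with a hypothetical table (the variable, layer, and gene names below are invented):

```python
import pandas as pd

# Hypothetical importance table with the required columns plus 'name'.
imp_features = pd.DataFrame({
    'target_variable': ['OS_STATUS'] * 4,
    'layer': ['mutation', 'mutation', 'rna', 'mutation'],
    'importance': [0.9, 0.2, 0.8, 0.5],
    'name': ['TP53', 'KRAS', 'EGFR', 'BRCA1'],
})
surv_event_var, layer, top_features = 'OS_STATUS', 'mutation', 2

subset = imp_features[(imp_features['target_variable'] == surv_event_var)
                      & (imp_features['layer'] == layer)]
subset = subset.sort_values(by='importance', ascending=False)
imp = subset['name'].unique().tolist()[:top_features]
print(imp)  # ['TP53', 'BRCA1']
```
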
| 415 final_samples = len(df) | 435 final_samples = len(df) |
| 416 print(f"Removed {initial_samples - final_samples} samples without survival data") | 436 print(f"Removed {initial_samples - final_samples} samples without survival data") |
| 417 | 437 |
| 418 if df.empty: | 438 if df.empty: |
| 419 raise ValueError("No samples remain after filtering for survival data") | 439 raise ValueError("No samples remain after filtering for survival data") |
| 420 | |
| 421 # Convert survival event column to binary (0/1) based on event_value | |
| 422 # Convert to string for comparison to handle mixed types | |
| 423 df[args.surv_event_var] = df[args.surv_event_var].astype(str) | |
| 424 event_value_str = str(args.event_value) | |
| 425 | |
| 426 df[f'{args.surv_event_var}'] = ( | |
| 427 df[args.surv_event_var] == event_value_str | |
| 428 ).astype(int) | |
| 429 | 440 |
| 430 # Build Cox model | 441 # Build Cox model |
| 431 print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}") | 442 print(f"Building Cox model with time variable: {args.surv_time_var}, event variable: {args.surv_event_var}") |
| 432 try: | 443 try: |
| 433 coxm = build_cox_model(df, | 444 coxm = build_cox_model(df, |
| 457 | 468 |
| 458 def generate_plot_scatter(labels, args, output_dir, output_name_base): | 469 def generate_plot_scatter(labels, args, output_dir, output_name_base): |
| 459 """Generate scatter plot of known vs predicted labels""" | 470 """Generate scatter plot of known vs predicted labels""" |
| 460 print("Generating scatter plots of known vs predicted labels...") | 471 print("Generating scatter plots of known vs predicted labels...") |
| 461 | 472 |
| 473 # Check if this is the specific format with sample_id, known_label, predicted_label | |
| 474 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
| 475 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
| 476 | |
| 477 if is_flexynesis_format: | |
| 478 # Parse target values from comma-separated string | |
| 479 if args.target_value: | |
| 480 target_values = [val.strip() for val in args.target_value.split(',')] | |
| 481 else: | |
| 482 # If no target values specified, use all unique variables | |
| 483 target_values = labels['variable'].unique().tolist() | |
| 484 | |
| 485 print(f"Processing target values: {target_values}") | |
| 486 | |
| 487 successful_plots = 0 | |
| 488 skipped_plots = 0 | |
| 489 | |
| 490 for target_value in target_values: | |
| 491 print(f"\nProcessing target value: '{target_value}'") | |
| 492 | |
| 493 # Filter labels for the current target value | |
| 494 target_labels = labels[labels['variable'] == target_value] | |
| 495 | |
| 496 if target_labels.empty: | |
| 497 print(f" Warning: No data found for target value '{target_value}' - skipping") | |
| 498 skipped_plots += 1 | |
| 499 continue | |
| 500 | |
| 501 # Check if labels are numeric and convert | |
| 502 true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') | |
| 503 predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') | |
| 504 | |
| 505 if true_values.isna().all() or predicted_values.isna().all(): | |
| 506 print(f"No valid numeric values found for known or predicted labels in '{target_value}'") | |
| 507 skipped_plots += 1 | |
| 508 continue | |
| 509 | |
| 510 try: | |
| 511 print(f" Generating scatter plot for '{target_value}'...") | |
| 512 fig = plot_scatter(true_values, predicted_values) | |
| 513 | |
| 514 # Create output filename with target value | |
| 515 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
| 516 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
| 517 | |
| 518 output_path = output_dir / output_filename | |
| 519 print(f" Saving scatter plot to: {output_path.absolute()}") | |
| 520 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | |
| 521 | |
| 522 successful_plots += 1 | |
| 523 print(f" Scatter plot for '{target_value}' generated successfully!") | |
| 524 | |
| 525 except Exception as e: | |
| 526 print(f" Error generating plot for '{target_value}': {str(e)}") | |
| 527 skipped_plots += 1 | |
| 528 | |
| 529 # Summary | |
| 530 print(" Summary:") | |
| 531 print(f" Successfully generated: {successful_plots} plots") | |
| 532 print(f" Skipped: {skipped_plots} plots") | |
| 533 | |
| 534 if successful_plots == 0: | |
| 535 raise ValueError("No scatter plots could be generated. Check your data and target values.") | |
| 536 | |
| 537 print("Scatter plot generation completed!") | |
| 538 | |
| 539 if not is_flexynesis_format: | |
| 540 print("Labels are not in flexynesis format (Custom labels)") | |
| 541 | |
| 542 if not args.true_label or not args.predicted_label: | |
| 543 raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.") | |
| 544 | |
| 545 # Check if labels are numeric and convert | |
| 546 true_values = pd.to_numeric(labels[args.true_label], errors='coerce') | |
| 547 predicted_values = pd.to_numeric(labels[args.predicted_label], errors='coerce') | |
| 548 | |
| 549 if true_values.isna().all() or predicted_values.isna().all(): | |
| 550 print("No valid numeric values found for known or predicted labels") | |
| 551 | |
| 552 try: | |
| 553 print(" Generating scatter plot...") | |
| 554 fig = plot_scatter(true_values, predicted_values) | |
| 555 | |
| 556 # Create output filename with target value | |
| 557 output_filename = f"{output_name_base}.{args.format}" | |
| 558 | |
| 559 output_path = output_dir / output_filename | |
| 560 print(f" Saving scatter plot to: {output_path.absolute()}") | |
| 561 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | |
| 562 | |
| 563 except Exception as e: | |
| 564 print(f" Error generating plot: {str(e)}") | |
| 565 | |
| 566 print("Scatter plot generation completed!") | |
| 567 | |
| 568 | |
| 569 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base): | |
| 570 """Generate label concordance heatmap""" | |
| 571 print("Generating label concordance heatmaps...") | |
| 572 | |
| 573 # Check if this is the specific format with sample_id, known_label, predicted_label | |
| 574 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
| 575 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
| 576 | |
| 577 if is_flexynesis_format: | |
| 578 # Parse target values from comma-separated string | |
| 579 if args.target_value: | |
| 580 target_values = [val.strip() for val in args.target_value.split(',')] | |
| 581 else: | |
| 582 # If no target values specified, use all unique variables | |
| 583 target_values = labels['variable'].unique().tolist() | |
| 584 | |
| 585 print(f"Processing target values: {target_values}") | |
| 586 | |
| 587 for target_value in target_values: | |
| 588 print(f"\nProcessing target value: '{target_value}'") | |
| 589 | |
| 590 # Filter labels for the current target value | |
| 591 target_labels = labels[labels['variable'] == target_value] | |
| 592 | |
| 593 if target_labels.empty: | |
| 594 print(f" Warning: No data found for target value '{target_value}' - skipping") | |
| 595 continue | |
| 596 | |
| 597 true_values = target_labels['known_label'].tolist() | |
| 598 predicted_values = target_labels['predicted_label'].tolist() | |
| 599 | |
| 600 try: | |
| 601 print(f" Generating heatmap for '{target_value}'...") | |
| 602 fig = plot_label_concordance_heatmap(true_values, predicted_values) | |
| 603 plt.close(fig) | |
| 604 | |
| 605 # Create output filename with target value | |
| 606 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
| 607 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
| 608 | |
| 609 output_path = output_dir / output_filename | |
| 610 print(f" Saving heatmap to: {output_path.absolute()}") | |
| 611 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') | |
| 612 | |
| 613 except Exception as e: | |
| 614 print(f" Error generating heatmap for '{target_value}': {str(e)}") | |
| 615 continue | |
| 616 | |
| 617 print("Label concordance heatmap generated successfully!") | |
| 618 | |
| 619 if not is_flexynesis_format: | |
| 620 print("Labels are not in flexynesis format (Custom labels)") | |
| 621 | |
| 622 if not args.true_label or not args.predicted_label: | |
| 623 raise ValueError("For custom labels, please specify --true_label and --predicted_label arguments.") | |
| 624 | |
| 625 true_values = labels[args.true_label].tolist() | |
| 626 predicted_values = labels[args.predicted_label].tolist() | |
| 627 | |
| 628 try: | |
| 629 print(" Generating heatmap for...") | |
| 630 fig = plot_label_concordance_heatmap(true_values, predicted_values) | |
| 631 plt.close(fig) | |
| 632 | |
| 633 # Create output filename with target value | |
| 634 output_filename = f"{output_name_base}.{args.format}" | |
| 635 | |
| 636 output_path = output_dir / output_filename | |
| 637 print(f" Saving heatmap to: {output_path.absolute()}") | |
| 638 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') | |
| 639 | |
| 640 except Exception as e: | |
| 641 print(f" Error generating heatmap': {str(e)}") | |
| 642 | |
| 643 print("Label concordance heatmap generated successfully!") | |
| 644 | |
| 645 | |
| 646 def generate_pr_curves(labels, args, output_dir, output_name_base): | |
| 647 """Generate precision-recall curves""" | |
| 648 print("Generating precision-recall curves...") | |
| 649 | |
| 650 # Check if this is the specific format with sample_id, known_label, predicted_label | |
| 651 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
| 652 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
| 653 | |
| 654 if not is_flexynesis_format: | |
| 655 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
| 656 | |
| 462 # Parse target values from comma-separated string | 657 # Parse target values from comma-separated string |
| 463 if args.target_value: | 658 if args.target_value: |
| 464 target_values = [val.strip() for val in args.target_value.split(',')] | 659 target_values = [val.strip() for val in args.target_value.split(',')] |
| 465 else: | 660 else: |
| 466 # If no target values specified, use all unique variables | 661 # If no target values specified, use all unique variables |
| 467 target_values = labels['variable'].unique().tolist() | 662 target_values = labels['variable'].unique().tolist() |
| 468 | 663 |
| 469 print(f"Processing target values: {target_values}") | 664 print(f"Processing target values: {target_values}") |
| 470 | 665 |
| 471 successful_plots = 0 | |
| 472 skipped_plots = 0 | |
| 473 | |
| 474 for target_value in target_values: | 666 for target_value in target_values: |
| 475 print(f"\nProcessing target value: '{target_value}'") | 667 print(f"\nProcessing target value: '{target_value}'") |
| 476 | 668 |
| 477 # Filter labels for the current target value | 669 # Filter labels for the current target value |
| 478 target_labels = labels[labels['variable'] == target_value] | 670 target_labels = labels[labels['variable'] == target_value] |
| 479 | 671 |
| 480 if target_labels.empty: | 672 # Check if this is a regression problem (no class probabilities) |
| 481 print(f" Warning: No data found for target value '{target_value}' - skipping") | 673 prob_columns = target_labels['class_label'].unique() |
| 482 skipped_plots += 1 | 674 non_na_probs = target_labels['probability'].notna().sum() |
| 483 continue | 675 |
| 484 | 676 print(f" Class labels found: {list(prob_columns)}") |
| 485 # Check if labels are numeric and convert | 677 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") |
| 486 true_values = pd.to_numeric(target_labels['known_label'], errors='coerce') | 678 |
| 487 predicted_values = pd.to_numeric(target_labels['predicted_label'], errors='coerce') | 679 # If most probabilities are NaN, this is likely a regression problem |
| 488 | 680 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities |
| 489 if true_values.isna().all() or predicted_values.isna().all(): | 681 print(" Detected regression problem - precision-recall curves not applicable") |
| 490 print(f"No valid numeric values found for known or predicted labels in '{target_value}'") | 682 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") |
| 491 skipped_plots += 1 | 683 continue |
| 684 | |
| 685 # Debug: Check data quality | |
| 686 total_rows = len(target_labels) | |
| 687 missing_labels = target_labels['known_label'].isna().sum() | |
| 688 missing_probs = target_labels['probability'].isna().sum() | |
| 689 unique_samples = target_labels['sample_id'].nunique() | |
| 690 unique_classes = target_labels['class_label'].nunique() | |
| 691 | |
| 692 print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes") | |
| 693 print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability") | |
| 694 | |
| 695 if missing_labels > 0: | |
| 696 print(f" Warning: Found {missing_labels} missing known_label values") | |
| 697 missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5] | |
| 698 print(f" Sample IDs with missing known_label: {list(missing_samples)}") | |
| 699 | |
| 700 # Remove rows with missing known_label | |
| 701 target_labels = target_labels.dropna(subset=['known_label']) | |
| 702 if target_labels.empty: | |
| 703 print(f" Error: No valid known_label data remaining for '{target_value}' - skipping") | |
| 704 continue | |
| 705 | |
| 706 # 1. Pivot to wide format | |
| 707 prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability') | |
| 708 | |
| 709 print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes") | |
| 710 print(f" Class columns: {list(prob_df.columns)}") | |
| 711 | |
| 712 # Check for NaN values in probability data | |
| 713 nan_counts = prob_df.isna().sum() | |
| 714 if nan_counts.any(): | |
| 715 print(f" NaN counts per class: {dict(nan_counts)}") | |
| 716 print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}") | |
| 717 | |
| 718 # Drop only rows where ALL probabilities are NaN | |
| 719 all_nan_rows = prob_df.isna().all(axis=1) | |
| 720 if all_nan_rows.any(): | |
| 721 print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities") | |
| 722 prob_df = prob_df[~all_nan_rows] | |
| 723 | |
| 724 remaining_nans = prob_df.isna().sum().sum() | |
| 725 if remaining_nans > 0: | |
| 726 print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0") | |
| 727 prob_df = prob_df.fillna(0) | |
| 728 | |
| 729 if prob_df.empty: | |
| 730 print(f" Error: No valid probability data remaining for '{target_value}' - skipping") | |
| 731 continue | |
| 732 | |
| 733 # 2. Get true labels | |
| 734 true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id') | |
| 735 | |
| 736 # 3. Align indices - only keep samples that exist in both datasets | |
| 737 common_indices = prob_df.index.intersection(true_labels_df.index) | |
| 738 if len(common_indices) == 0: | |
| 739 print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping") | |
| 740 continue | |
| 741 | |
| 742 print(f" Found {len(common_indices)} samples with both probability and true label data") | |
| 743 | |
| 744 # Filter both datasets to common indices | |
| 745 prob_df_aligned = prob_df.loc[common_indices] | |
| 746 y_true = true_labels_df.loc[common_indices]['known_label'] | |
| 747 | |
| 748 # 4. Final check for NaN values | |
| 749 if y_true.isna().any(): | |
| 750 print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping") | |
| 751 continue | |
| 752 | |
| 753 if prob_df_aligned.isna().any().any(): | |
| 754 print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping") | |
| 755 continue | |
| 756 | |
| 757 # 5. Convert categorical labels to integer labels | |
| 758 # Create a mapping from class names to integers | |
| 759 class_names = list(prob_df_aligned.columns) | |
| 760 class_to_int = {class_name: i for i, class_name in enumerate(class_names)} | |
| 761 | |
| 762 print(f" Class mapping: {class_to_int}") | |
| 763 | |
| 764 # Convert true labels to integers | |
| 765 y_true_np = y_true.map(class_to_int).to_numpy() | |
| 766 y_probs_np = prob_df_aligned.to_numpy() | |
| 767 | |
| 768 print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}") | |
| 769 print(f" Unique true labels (integers): {set(y_true_np)}") | |
| 770 print(f" Class labels (columns): {class_names}") | |
| 771 print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}") | |
| 772 | |
| 773 # Check for any unmapped labels (will be NaN) | |
| 774 if pd.isna(y_true_np).any(): | |
| 775 print(" Error: Some true labels could not be mapped to class columns") | |
| 776 unmapped_labels = set(y_true[y_true.map(class_to_int).isna()]) | |
| 777 print(f" Unmapped labels: {unmapped_labels}") | |
| 778 print(f" Available classes: {class_names}") | |
| 492 continue | 779 continue |
| 493 | 780 |
| 494 try: | 781 try: |
| 495 print(f" Generating scatter plot for '{target_value}'...") | 782 print(f" Generating precision-recall curve for '{target_value}'...") |
| 496 fig = plot_scatter(true_values, predicted_values) | 783 fig = plot_pr_curves(y_true_np, y_probs_np) |
| 497 | 784 |
| 498 # Create output filename with target value | 785 # Create output filename with target value |
| 499 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | 786 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') |
| 500 if len(target_values) > 1: | 787 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" |
| 501 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
| 502 else: | |
| 503 output_filename = f"{output_name_base}.{args.format}" | |
| 504 | 788 |
| 505 output_path = output_dir / output_filename | 789 output_path = output_dir / output_filename |
| 506 print(f" Saving scatter plot to: {output_path.absolute()}") | 790 print(f" Saving precision-recall curve to: {output_path.absolute()}") |
| 507 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | 791 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') |
| 508 | 792 |
| 509 successful_plots += 1 | |
| 510 print(f" Scatter plot for '{target_value}' generated successfully!") | |
| 511 | |
| 512 except Exception as e: | 793 except Exception as e: |
| 513 print(f" Error generating plot for '{target_value}': {str(e)}") | 794 print(f" Error generating precision-recall curve for '{target_value}': {str(e)}") |
| 514 skipped_plots += 1 | 795 print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}") |
| 515 | 796 print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}") |
| 516 # Summary | 797 continue |
| 517 print(" Summary:") | 798 |
| 518 print(f" Successfully generated: {successful_plots} plots") | 799 print("Precision-recall curves generated successfully!") |
| 519 print(f" Skipped: {skipped_plots} plots") | 800 |
| 520 | 801 |
| 521 if successful_plots == 0: | 802 def generate_roc_curves(labels, args, output_dir, output_name_base): |
| 522 raise ValueError("No scatter plots could be generated. Check your data and target values.") | 803 """Generate ROC curves""" |
| 523 | 804 print("Generating ROC curves...") |
| 524 print("Scatter plot generation completed!") | 805 |
| 525 | 806 # Check if this is the specific format with sample_id, known_label, predicted_label |
| 526 | 807 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] |
| 527 def generate_label_concordance_heatmap(labels, args, output_dir, output_name_base): | 808 is_flexynesis_format = all(col in labels.columns for col in required_cols) |
| 528 """Generate label concordance heatmap""" | 809 |
| 529 print("Generating label concordance heatmaps...") | 810 if not is_flexynesis_format: |
| 811 raise ValueError(f"Labels are not in flexynesis format (Custom labels). Please provide a valid label file with the required columns, {required_cols}.") | |
| 530 | 812 |
| 531 # Parse target values from comma-separated string | 813 # Parse target values from comma-separated string |
| 532 if args.target_value: | 814 if args.target_value: |
| 533 target_values = [val.strip() for val in args.target_value.split(',')] | 815 target_values = [val.strip() for val in args.target_value.split(',')] |
| 534 else: | 816 else: |
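
Both curve generators share the same reshaping steps before plotting: pivot the long-format predictions into a samples-by-classes probability matrix, align one true label per sample, and map class names to the integer column positions the plotting functions expect. A compact sketch with toy data:

```python
import pandas as pd

# Long format: one row per (sample, class) with a probability.
target_labels = pd.DataFrame({
    'sample_id':   ['s1', 's1', 's2', 's2'],
    'class_label': ['A', 'B', 'A', 'B'],
    'probability': [0.9, 0.1, 0.3, 0.7],
    'known_label': ['A', 'A', 'B', 'B'],
})

# 1. Pivot to wide: rows are samples, columns are classes.
prob_df = target_labels.pivot(index='sample_id', columns='class_label',
                              values='probability')

# 2. One true label per sample, aligned to the pivoted index.
y_true = (target_labels.drop_duplicates('sample_id')
          .set_index('sample_id')['known_label'].loc[prob_df.index])

# 3. Map class names to integer column positions.
class_to_int = {name: i for i, name in enumerate(prob_df.columns)}
y_true_np = y_true.map(class_to_int).to_numpy()  # array([0, 1])
y_probs_np = prob_df.to_numpy()                  # shape (2, 2)
print(y_true_np, y_probs_np.shape)
```
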
| 541 print(f"\nProcessing target value: '{target_value}'") | 823 print(f"\nProcessing target value: '{target_value}'") |
| 542 | 824 |
| 543 # Filter labels for the current target value | 825 # Filter labels for the current target value |
| 544 target_labels = labels[labels['variable'] == target_value] | 826 target_labels = labels[labels['variable'] == target_value] |
| 545 | 827 |
| 546 if target_labels.empty: | |
| 547 print(f" Warning: No data found for target value '{target_value}' - skipping") | |
| 548 continue | |
| 549 | |
| 550 true_values = target_labels['known_label'].tolist() | |
| 551 predicted_values = target_labels['predicted_label'].tolist() | |
| 552 | |
| 553 try: | |
| 554 print(f" Generating heatmap for '{target_value}'...") | |
| 555 fig = plot_label_concordance_heatmap(true_values, predicted_values) | |
| 556 plt.close(fig) | |
| 557 | |
| 558 # Create output filename with target value | |
| 559 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
| 560 if len(target_values) > 1: | |
| 561 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
| 562 else: | |
| 563 output_filename = f"{output_name_base}.{args.format}" | |
| 564 | |
| 565 output_path = output_dir / output_filename | |
| 566 print(f" Saving heatmap to: {output_path.absolute()}") | |
| 567 fig.savefig(output_path, dpi=args.dpi, bbox_inches='tight') | |
| 568 | |
| 569 except Exception as e: | |
| 570 print(f" Error generating heatmap for '{target_value}': {str(e)}") | |
| 571 continue | |
| 572 | |
| 573 print("Label concordance heatmap generated successfully!") | |
| 574 | |
| 575 | |
| 576 def generate_pr_curves(labels, args, output_dir, output_name_base): | |
| 577 """Generate precision-recall curves""" | |
| 578 print("Generating precision-recall curves...") | |
| 579 | |
| 580 # Parse target values from comma-separated string | |
| 581 if args.target_value: | |
| 582 target_values = [val.strip() for val in args.target_value.split(',')] | |
| 583 else: | |
| 584 # If no target values specified, use all unique variables | |
| 585 target_values = labels['variable'].unique().tolist() | |
| 586 | |
| 587 print(f"Processing target values: {target_values}") | |
| 588 | |
| 589 for target_value in target_values: | |
| 590 print(f"\nProcessing target value: '{target_value}'") | |
| 591 | |
| 592 # Filter labels for the current target value | |
| 593 target_labels = labels[labels['variable'] == target_value] | |
| 594 | |
| 595 # Check if this is a regression problem (no class probabilities) | 828 # Check if this is a regression problem (no class probabilities) |
| 596 prob_columns = target_labels['class_label'].unique() | 829 prob_columns = target_labels['class_label'].unique() |
| 597 non_na_probs = target_labels['probability'].notna().sum() | 830 non_na_probs = target_labels['probability'].notna().sum() |
| 598 | 831 |
| 599 print(f" Class labels found: {list(prob_columns)}") | 832 print(f" Class labels found: {list(prob_columns)}") |
| 600 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") | 833 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") |
| 601 | 834 |
| 602 # If most probabilities are NaN, this is likely a regression problem | 835 # If most probabilities are NaN, this is likely a regression problem |
| 603 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities | 836 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities |
| 604 print(" Detected regression problem - precision-recall curves not applicable") | 837 print(" Detected regression problem - ROC curves not applicable") |
| 605 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") | 838 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") |
| 606 continue | 839 continue |
| 607 | 840 |
| 608 # Debug: Check data quality | 841 # Debug: Check data quality |
| 609 total_rows = len(target_labels) | 842 total_rows = len(target_labels) |
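
The regression check that both generators apply is a simple heuristic: when fewer than 10% of the probability entries are non-NaN, the target is treated as a regression output and the curve is skipped. A toy illustration:

```python
import pandas as pd

# Hypothetical regression-style predictions: almost no class probabilities.
target_labels = pd.DataFrame({'probability': [None] * 19 + [0.8]})

non_na_probs = target_labels['probability'].notna().sum()
is_regression = non_na_probs < len(target_labels) * 0.1  # under 10% valid
print(is_regression)  # True -> skip ROC/PR curves for this variable
```
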
| 700 print(f" Unmapped labels: {unmapped_labels}") | 933 print(f" Unmapped labels: {unmapped_labels}") |
| 701 print(f" Available classes: {class_names}") | 934 print(f" Available classes: {class_names}") |
| 702 continue | 935 continue |
| 703 | 936 |
| 704 try: | 937 try: |
| 705 print(f" Generating precision-recall curve for '{target_value}'...") | 938 print(f" Generating ROC curve for '{target_value}'...") |
| 706 fig = plot_pr_curves(y_true_np, y_probs_np) | 939 fig = plot_roc_curves(y_true_np, y_probs_np) |
| 707 | 940 |
| 708 # Create output filename with target value | 941 # Create output filename with target value |
| 709 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | 942 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') |
| 710 if len(target_values) > 1: | 943 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" |
| 711 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
| 712 else: | |
| 713 output_filename = f"{output_name_base}.{args.format}" | |
| 714 | |
| 715 output_path = output_dir / output_filename | |
| 716 print(f" Saving precision-recall curve to: {output_path.absolute()}") | |
| 717 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | |
| 718 | |
| 719 except Exception as e: | |
| 720 print(f" Error generating precision-recall curve for '{target_value}': {str(e)}") | |
| 721 print(f" Debug info - y_true type: {type(y_true_np)}, contains NaN: {pd.isna(y_true_np).any()}") | |
| 722 print(f" Debug info - y_probs type: {type(y_probs_np)}, contains NaN: {pd.isna(y_probs_np).any()}") | |
| 723 continue | |
| 724 | |
| 725 print("Precision-recall curves generated successfully!") | |
| 726 | |
| 727 | |
| 728 def generate_roc_curves(labels, args, output_dir, output_name_base): | |
| 729 """Generate ROC curves""" | |
| 730 print("Generating ROC curves...") | |
| 731 | |
| 732 # Parse target values from comma-separated string | |
| 733 if args.target_value: | |
| 734 target_values = [val.strip() for val in args.target_value.split(',')] | |
| 735 else: | |
| 736 # If no target values specified, use all unique variables | |
| 737 target_values = labels['variable'].unique().tolist() | |
| 738 | |
| 739 print(f"Processing target values: {target_values}") | |
| 740 | |
| 741 for target_value in target_values: | |
| 742 print(f"\nProcessing target value: '{target_value}'") | |
| 743 | |
| 744 # Filter labels for the current target value | |
| 745 target_labels = labels[labels['variable'] == target_value] | |
| 746 | |
| 747 # Check if this is a regression problem (no class probabilities) | |
| 748 prob_columns = target_labels['class_label'].unique() | |
| 749 non_na_probs = target_labels['probability'].notna().sum() | |
| 750 | |
| 751 print(f" Class labels found: {list(prob_columns)}") | |
| 752 print(f" Non-NaN probabilities: {non_na_probs}/{len(target_labels)}") | |
| 753 | |
| 754 # If most probabilities are NaN, this is likely a regression problem | |
| 755 if non_na_probs < len(target_labels) * 0.1: # Less than 10% valid probabilities | |
| 756 print(" Detected regression problem - ROC curves not applicable") | |
| 757 print(f" Skipping '{target_value}' (use regression evaluation metrics instead)") | |
| 758 continue | |
| 759 | |
| 760 # Debug: Check data quality | |
| 761 total_rows = len(target_labels) | |
| 762 missing_labels = target_labels['known_label'].isna().sum() | |
| 763 missing_probs = target_labels['probability'].isna().sum() | |
| 764 unique_samples = target_labels['sample_id'].nunique() | |
| 765 unique_classes = target_labels['class_label'].nunique() | |
| 766 | |
| 767 print(f" Data summary: {total_rows} total rows, {unique_samples} unique samples, {unique_classes} unique classes") | |
| 768 print(f" Missing data: {missing_labels} missing known_label, {missing_probs} missing probability") | |
| 769 | |
| 770 if missing_labels > 0: | |
| 771 print(f" Warning: Found {missing_labels} missing known_label values") | |
| 772 missing_samples = target_labels[target_labels['known_label'].isna()]['sample_id'].unique()[:5] | |
| 773 print(f" Sample IDs with missing known_label: {list(missing_samples)}") | |
| 774 | |
| 775 # Remove rows with missing known_label | |
| 776 target_labels = target_labels.dropna(subset=['known_label']) | |
| 777 if target_labels.empty: | |
| 778 print(f" Error: No valid known_label data remaining for '{target_value}' - skipping") | |
| 779 continue | |
| 780 | |
| 781 # 1. Pivot to wide format | |
| 782 prob_df = target_labels.pivot(index='sample_id', columns='class_label', values='probability') | |
| 783 | |
| 784 print(f" After pivot: {prob_df.shape[0]} samples x {prob_df.shape[1]} classes") | |
| 785 print(f" Class columns: {list(prob_df.columns)}") | |
| 786 | |
| 787 # Check for NaN values in probability data | |
| 788 nan_counts = prob_df.isna().sum() | |
| 789 if nan_counts.any(): | |
| 790 print(f" NaN counts per class: {dict(nan_counts)}") | |
| 791 print(f" Samples with any NaN: {prob_df.isna().any(axis=1).sum()}/{len(prob_df)}") | |
| 792 | |
| 793 # Drop only rows where ALL probabilities are NaN | |
| 794 all_nan_rows = prob_df.isna().all(axis=1) | |
| 795 if all_nan_rows.any(): | |
| 796 print(f" Dropping {all_nan_rows.sum()} samples with all NaN probabilities") | |
| 797 prob_df = prob_df[~all_nan_rows] | |
| 798 | |
| 799 remaining_nans = prob_df.isna().sum().sum() | |
| 800 if remaining_nans > 0: | |
| 801 print(f" Warning: {remaining_nans} individual NaN values remain - filling with 0") | |
| 802 prob_df = prob_df.fillna(0) | |
| 803 | |
| 804 if prob_df.empty: | |
| 805 print(f" Error: No valid probability data remaining for '{target_value}' - skipping") | |
| 806 continue | |
| 807 | |
| 808 # 2. Get true labels | |
| 809 true_labels_df = target_labels.drop_duplicates('sample_id')[['sample_id', 'known_label']].set_index('sample_id') | |
| 810 | |
| 811 # 3. Align indices - only keep samples that exist in both datasets | |
| 812 common_indices = prob_df.index.intersection(true_labels_df.index) | |
| 813 if len(common_indices) == 0: | |
| 814 print(f" Error: No common sample_ids between probability and true label data for '{target_value}' - skipping") | |
| 815 continue | |
| 816 | |
| 817 print(f" Found {len(common_indices)} samples with both probability and true label data") | |
| 818 | |
| 819 # Filter both datasets to common indices | |
| 820 prob_df_aligned = prob_df.loc[common_indices] | |
| 821 y_true = true_labels_df.loc[common_indices]['known_label'] | |
| 822 | |
| 823 # 4. Final check for NaN values | |
| 824 if y_true.isna().any(): | |
| 825 print(f" Error: True labels still contain NaN after alignment for '{target_value}' - skipping") | |
| 826 continue | |
| 827 | |
| 828 if prob_df_aligned.isna().any().any(): | |
| 829 print(f" Error: Probability data still contains NaN after alignment for '{target_value}' - skipping") | |
| 830 continue | |
| 831 | |
| 832 # 5. Convert categorical labels to integer labels | |
| 833 # Create a mapping from class names to integers | |
| 834 class_names = list(prob_df_aligned.columns) | |
| 835 class_to_int = {class_name: i for i, class_name in enumerate(class_names)} | |
| 836 | |
| 837 print(f" Class mapping: {class_to_int}") | |
| 838 | |
| 839 # Convert true labels to integers | |
| 840 y_true_np = y_true.map(class_to_int).to_numpy() | |
| 841 y_probs_np = prob_df_aligned.to_numpy() | |
| 842 | |
| 843 print(f" Data shape: y_true={y_true_np.shape}, y_probs={y_probs_np.shape}") | |
| 844 print(f" Unique true labels (integers): {set(y_true_np)}") | |
| 845 print(f" Class labels (columns): {class_names}") | |
| 846 print(f" Label distribution: {dict(zip(*np.unique(y_true_np, return_counts=True)))}") | |
| 847 | |
| 848 # Check for any unmapped labels (will be NaN) | |
| 849 if pd.isna(y_true_np).any(): | |
| 850 print(" Error: Some true labels could not be mapped to class columns") | |
| 851 unmapped_labels = set(y_true[y_true.map(class_to_int).isna()]) | |
| 852 print(f" Unmapped labels: {unmapped_labels}") | |
| 853 print(f" Available classes: {class_names}") | |
| 854 continue | |
| 855 | |
| 856 try: | |
| 857 print(f" Generating ROC curve for '{target_value}'...") | |
| 858 fig = plot_roc_curves(y_true_np, y_probs_np) | |
| 859 | |
| 860 # Create output filename with target value | |
| 861 safe_target_name = target_value.replace('/', '_').replace('\\', '_').replace(' ', '_') | |
| 862 if len(target_values) > 1: | |
| 863 output_filename = f"{output_name_base}_{safe_target_name}.{args.format}" | |
| 864 else: | |
| 865 output_filename = f"{output_name_base}.{args.format}" | |
| 866 | 944 |
| 867 output_path = output_dir / output_filename | 945 output_path = output_dir / output_filename |
| 868 print(f" Saving ROC curve to: {output_path.absolute()}") | 946 print(f" Saving ROC curve to: {output_path.absolute()}") |
| 869 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') | 947 fig.save(output_path, dpi=args.dpi, bbox_inches='tight') |
| 870 | 948 |
| 877 print("ROC curves generated successfully!") | 955 print("ROC curves generated successfully!") |
| 878 | 956 |
| 879 | 957 |
| 880 def generate_box_plots(labels, args, output_dir, output_name_base): | 958 def generate_box_plots(labels, args, output_dir, output_name_base): |
| 881 """Generate box plots for model predictions""" | 959 """Generate box plots for model predictions""" |
| 960 | |
| 961 # Check if this is the specific format with sample_id, known_label, predicted_label | |
| 962 required_cols = ['sample_id', 'variable', 'class_label', 'probability', 'known_label', 'predicted_label'] | |
| 963 is_flexynesis_format = all(col in labels.columns for col in required_cols) | |
| 964 | |
| 965 if not is_flexynesis_format: | |
| 966 raise ValueError(f"Labels are not in flexynesis format (custom labels detected). Please provide a label file with the required columns: {required_cols}") | |
| 882 | 967 |
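A labels file that passes this check is a long-format prediction table with one row per (sample, variable, class) combination. A minimal hypothetical example (tab-separated):

    sample_id	variable	class_label	probability	known_label	predicted_label
    s1	subtype	LumA	0.80	LumA	LumA
    s1	subtype	Basal	0.20	LumA	LumA
    s2	subtype	LumA	0.10	Basal	Basal
    s2	subtype	Basal	0.90	Basal	Basal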
| 883 print("Generating box plots...") | 968 print("Generating box plots...") |
| 884 | 969 |
| 885 # Parse target values from comma-separated string | 970 # Parse target values from comma-separated string |
| 886 if args.target_value: | 971 if args.target_value: |
| 991 # Arguments for dimensionality reduction | 1076 # Arguments for dimensionality reduction |
| 992 parser.add_argument("--embeddings", type=str, | 1077 parser.add_argument("--embeddings", type=str, |
| 993 help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.") | 1078 help="Path to input data embeddings file (CSV or tabular format). Required for dimred plots.") |
| 994 parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'], | 1079 parser.add_argument("--method", type=str, default='pca', choices=['pca', 'umap'], |
| 995 help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.") | 1080 help="Transformation method ('pca' or 'umap'). Default is 'pca'. Used for dimred plots.") |
| 1081 parser.add_argument("--color", type=str, default=None, | |
| 1082 help="User-defined color for the plot.") | |
| 996 | 1083 |
| 997 # Arguments for Kaplan-Meier | 1084 # Arguments for Kaplan-Meier |
| 998 parser.add_argument("--survival_data", type=str, | 1085 parser.add_argument("--survival_data", type=str, |
| 999 help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.") | 1086 help="Path to survival data file with columns: duration and event. Required for kaplan_meier plots.") |
| 1000 parser.add_argument("--surv_time_var", type=str, required=False, | 1087 parser.add_argument("--surv_time_var", type=str, required=False, |
| 1001 help="Column name for survival time") | 1088 help="Column name for survival time") |
| 1002 parser.add_argument("--surv_event_var", type=str, required=False, | 1089 parser.add_argument("--surv_event_var", type=str, required=False, |
| 1003 help="Column name for survival event") | 1090 help="Column name for survival event") |
| 1004 parser.add_argument("--event_value", type=str, required=False, | |
| 1005 help="Value in event column that represents an event (e.g., 'DECEASED')") | |
| 1006 | 1091 |
| 1007 # Arguments for Cox analysis | 1092 # Arguments for Cox analysis |
| 1008 parser.add_argument("--model", type=str, | 1093 parser.add_argument("--important_features", type=str, |
| 1009 help="Path to trained flexynesis model (pickle file). Required for cox plots.") | 1094 help="Path to calculated feature importance file. Required for cox plots.") |
| 1010 parser.add_argument("--clinical_train", type=str, | 1095 parser.add_argument("--clinical_train", type=str, |
| 1011 help="Path to training dataset (pickle file). Required for cox plots.") | 1096 help="Path to training dataset (pickle file). Required for cox plots.") |
| 1012 parser.add_argument("--clinical_test", type=str, | 1097 parser.add_argument("--clinical_test", type=str, |
| 1013 help="Path to test dataset (pickle file). Required for cox plots.") | 1098 help="Path to test dataset (pickle file). Required for cox plots.") |
| 1014 parser.add_argument("--omics_train", type=str, default=None, | 1099 parser.add_argument("--omics_train", type=str, default=None, |
| 1023 help="If True, performs K-fold cross-validation and returns average C-index. Default is False") | 1108 help="If True, performs K-fold cross-validation and returns average C-index. Default is False") |
| 1024 parser.add_argument("--n_splits", type=int, default=5, | 1109 parser.add_argument("--n_splits", type=int, default=5, |
| 1025 help="Number of folds for cross-validation. Default is 5") | 1110 help="Number of folds for cross-validation. Default is 5") |
| 1026 parser.add_argument("--random_state", type=int, default=42, | 1111 parser.add_argument("--random_state", type=int, default=42, |
| 1027 help="Random seed for reproducibility. Default is 42") | 1112 help="Random seed for reproducibility. Default is 42") |
| 1113 parser.add_argument("--layer", type=str, default=None, | |
| 1114 help="Class label for filtering important features.") | |
| 1028 | 1115 |
| 1029 # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots | 1116 # Arguments for dimred, scatter plot, heatmap, PR curves, ROC curves, and box plots |
| 1030 parser.add_argument("--target_value", type=str, default=None, | 1117 parser.add_argument("--target_value", type=str, default=None, |
| 1031 help="Target value for scatter plot.") | 1118 help="Target value for scatter plot.") |
| 1032 | 1119 |
| 1120 # Arguments for scatter plots and concordance heatmaps | |
| 1121 parser.add_argument("--true_label", type=str, default=None, | |
| 1122 help="Column name for true labels in scatter plots and concordance heatmaps.") | |
| 1123 parser.add_argument("--predicted_label", type=str, default=None, | |
| 1124 help="Column name for predicted labels in scatter plots and concordance heatmaps.") | |
| 1033 # Common arguments | 1125 # Common arguments |
| 1034 parser.add_argument("--output_dir", type=str, default='output', | 1126 parser.add_argument("--output_dir", type=str, default='output', |
| 1035 help="Output directory. Default is 'output'") | 1127 help="Output directory. Default is 'output'") |
| 1036 parser.add_argument("--output_name", type=str, default=None, | 1128 parser.add_argument("--output_name", type=str, default=None, |
| 1037 help="Output filename base") | 1129 help="Output filename base") |
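Putting the revised Cox arguments together, a hypothetical invocation could look like the following (file names and column names are placeholders; the --format and --dpi flags are assumed to be defined elsewhere in the parser, since args.format and args.dpi are used above):

    python flexynesis_plot.py \
        --plot_type cox \
        --important_features importance.tsv \
        --clinical_train clin_train.tsv --clinical_test clin_test.tsv \
        --omics_train omics_train.tsv --omics_test omics_test.tsv \
        --surv_time_var OS_MONTHS --surv_event_var OS_STATUS \
        --layer mutation \
        --output_dir output --format png --dpi 300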
| 1071 raise ValueError("--method is required for dimensionality reduction plots") | 1163 raise ValueError("--method is required for dimensionality reduction plots") |
| 1072 if not args.surv_time_var: | 1164 if not args.surv_time_var: |
| 1073 raise ValueError("--surv_time_var is required for Kaplan-Meier plots") | 1165 raise ValueError("--surv_time_var is required for Kaplan-Meier plots") |
| 1074 if not args.surv_event_var: | 1166 if not args.surv_event_var: |
| 1075 raise ValueError("--surv_event_var is required for Kaplan-Meier plots") | 1167 raise ValueError("--surv_event_var is required for Kaplan-Meier plots") |
| 1076 if not args.event_value: | |
| 1077 raise ValueError("--event_value is required for Kaplan-Meier plots") | |
| 1078 | 1168 |
| 1079 if args.plot_type in ['cox']: | 1169 if args.plot_type in ['cox']: |
| 1080 if not args.model: | 1170 if not args.important_features: |
| 1081 raise ValueError("--model is required when plot_type is 'cox'") | 1171 raise ValueError("--important_features is required when plot_type is 'cox'") |
| 1082 if not os.path.isfile(args.model): | 1172 if not os.path.isfile(args.important_features): |
| 1083 raise FileNotFoundError(f"Model file not found: {args.model}") | 1173 raise FileNotFoundError(f"Important features file not found: {args.important_features}") |
| 1084 if not args.clinical_train: | 1174 if not args.clinical_train: |
| 1085 raise ValueError("--clinical_train is required when plot_type is 'cox'") | 1175 raise ValueError("--clinical_train is required when plot_type is 'cox'") |
| 1086 if not os.path.isfile(args.clinical_train): | 1176 if not os.path.isfile(args.clinical_train): |
| 1087 raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}") | 1177 raise FileNotFoundError(f"Training dataset file not found: {args.clinical_train}") |
| 1088 if not args.clinical_test: | 1178 if not args.clinical_test: |
| 1100 if not args.surv_time_var: | 1190 if not args.surv_time_var: |
| 1101 raise ValueError("--surv_time_var is required for Cox plots") | 1191 raise ValueError("--surv_time_var is required for Cox plots") |
| 1102 if not args.surv_event_var: | 1192 if not args.surv_event_var: |
| 1103 raise ValueError("--surv_event_var is required for Cox plots") | 1193 raise ValueError("--surv_event_var is required for Cox plots") |
| 1104 if not args.clinical_variables: | 1194 if not args.clinical_variables: |
| 1105 raise ValueError("--clinical_variables is required for Cox plots") | 1195 print("Warning: --clinical_variables is not set for Cox plots") |
| 1106 if not isinstance(args.top_features, int) or args.top_features <= 0: | 1196 if not isinstance(args.top_features, int) or args.top_features <= 0: |
| 1107 raise ValueError("--top_features must be a positive integer") | 1197 raise ValueError("--top_features must be a positive integer") |
| 1108 if not args.event_value: | |
| 1109 raise ValueError("--event_value is required for Kaplan-Meier plots") | |
| 1110 if not args.crossval: | 1198 if not args.crossval: |
| 1111 args.crossval = False | 1199 args.crossval = False |
| 1112 if not isinstance(args.n_splits, int) or args.n_splits <= 0: | 1200 if not isinstance(args.n_splits, int) or args.n_splits <= 0: |
| 1113 raise ValueError("--n_splits must be a positive integer") | 1201 raise ValueError("--n_splits must be a positive integer") |
| 1114 if not isinstance(args.random_state, int): | 1202 if not isinstance(args.random_state, int): |
| 1115 raise ValueError("--random_state must be an integer") | 1203 raise ValueError("--random_state must be an integer") |
| 1204 if not args.layer: | |
| 1205 print("--layer is not specified; using all classes from labels") | |
| 1116 | 1206 |
| 1117 if args.plot_type in ['scatter']: | 1207 if args.plot_type in ['scatter']: |
| 1118 if not args.labels: | 1208 if not args.labels: |
| 1119 raise ValueError("--labels is required for scatter plots") | 1209 raise ValueError("--labels is required for scatter plots") |
| 1120 if not args.target_value: | 1210 if not args.target_value: |
| 1172 output_name_base = f"{embeddings_name}_{args.method}" | 1262 output_name_base = f"{embeddings_name}_{args.method}" |
| 1173 elif args.plot_type == 'kaplan_meier': | 1263 elif args.plot_type == 'kaplan_meier': |
| 1174 survival_name = Path(args.survival_data).stem | 1264 survival_name = Path(args.survival_data).stem |
| 1175 output_name_base = f"{survival_name}_km" | 1265 output_name_base = f"{survival_name}_km" |
| 1176 elif args.plot_type == 'cox': | 1266 elif args.plot_type == 'cox': |
| 1177 model_name = Path(args.model).stem | 1267 model_name = Path(args.important_features).stem |
| 1178 output_name_base = f"{model_name}_cox" | 1268 output_name_base = f"{model_name}_cox" |
| 1179 elif args.plot_type == 'scatter': | 1269 elif args.plot_type == 'scatter': |
| 1180 labels_name = Path(args.labels).stem | 1270 labels_name = Path(args.labels).stem |
| 1181 output_name_base = f"{labels_name}_scatter" | 1271 output_name_base = f"{labels_name}_scatter" |
| 1182 elif args.plot_type == 'concordance_heatmap': | 1272 elif args.plot_type == 'concordance_heatmap': |
| 1194 | 1284 |
| 1195 # Generate plots based on type | 1285 # Generate plots based on type |
| 1196 if args.plot_type in ['dimred']: | 1286 if args.plot_type in ['dimred']: |
| 1197 # Load labels | 1287 # Load labels |
| 1198 print(f"Loading labels from: {args.labels}") | 1288 print(f"Loading labels from: {args.labels}") |
| 1199 label_data = load_labels(args.labels) | 1289 labels = load_labels(args.labels) |
| 1200 # Load embeddings data | 1290 # Load embeddings data |
| 1201 print(f"Loading embeddings from: {args.embeddings}") | 1291 print(f"Loading embeddings from: {args.embeddings}") |
| 1202 embeddings, sample_names = load_embeddings(args.embeddings) | 1292 embeddings, sample_names = load_embeddings(args.embeddings) |
| 1203 print(f"embeddings shape: {embeddings.shape}") | 1293 print(f"embeddings shape: {embeddings.shape}") |
| 1204 | 1294 |
| 1205 # Match samples to embeddings | 1295 # Match samples to embeddings |
| 1206 matched_labels = match_samples_to_embeddings(sample_names, label_data) | 1296 matched_labels = match_samples_to_embeddings(sample_names, labels) |
| 1207 print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction") | 1297 print(f"Successfully matched {len(matched_labels)} samples for dimensionality reduction") |
| 1208 | 1298 print(f"Matched labels shape: {matched_labels.shape}") |
| 1299 print(f"Columns in matched labels: {matched_labels.columns.tolist()}") | |
| 1209 generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base) | 1300 generate_dimred_plots(embeddings, matched_labels, args, output_dir, output_name_base) |
| 1210 | 1301 |
| 1211 elif args.plot_type in ['kaplan_meier']: | 1302 elif args.plot_type in ['kaplan_meier']: |
| 1212 # Load labels | 1303 # Load labels |
| 1213 print(f"Loading labels from: {args.labels}") | 1304 print(f"Loading labels from: {args.labels}") |
| 1214 label_data = load_labels(args.labels) | 1305 labels = load_labels(args.labels) |
| 1215 # Load survival data | 1306 # Load survival data |
| 1216 print(f"Loading survival data from: {args.survival_data}") | 1307 print(f"Loading survival data from: {args.survival_data}") |
| 1217 survival_data = load_survival_data(args.survival_data) | 1308 survival_data = load_labels(args.survival_data) |
| 1218 print(f"Survival data shape: {survival_data.shape}") | 1309 print(f"Survival data shape: {survival_data.shape}") |
| 1219 | 1310 |
| 1220 generate_km_plots(survival_data, label_data, args, output_dir, output_name_base) | 1311 generate_km_plots(survival_data, labels, args, output_dir, output_name_base) |
| 1221 | 1312 |
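Because the survival table is now read with load_labels rather than a dedicated loader (see the change above), it is expected as a plain CSV/TSV file with a header row. A hypothetical example whose column names would be passed via --surv_time_var and --surv_event_var:

    sample_id	OS_MONTHS	OS_STATUS
    s1	12.5	1
    s2	30.1	0
    s3	8.2	1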
| 1222 elif args.plot_type in ['cox']: | 1313 elif args.plot_type in ['cox']: |
| 1223 # Load model and datasets | 1314 # Load important_features and datasets |
| 1224 print(f"Loading model from: {args.model}") | 1315 print(f"Loading important features from: {args.important_features}") |
| 1225 model = load_model(args.model) | 1316 important_features = load_labels(args.important_features) |
| 1226 print(f"Loading training dataset from: {args.clinical_train}") | 1317 print(f"Loading training dataset from: {args.clinical_train}") |
| 1227 clinical_train = load_omics(args.clinical_train) | 1318 clinical_train = load_omics(args.clinical_train) |
| 1228 print(f"Loading test dataset from: {args.clinical_test}") | 1319 print(f"Loading test dataset from: {args.clinical_test}") |
| 1229 clinical_test = load_omics(args.clinical_test) | 1320 clinical_test = load_omics(args.clinical_test) |
| 1230 print(f"Loading training omics dataset from: {args.omics_train}") | 1321 print(f"Loading training omics dataset from: {args.omics_train}") |
| 1231 omics_train = load_omics(args.omics_train) | 1322 omics_train = load_omics(args.omics_train) |
| 1232 print(f"Loading test omics dataset from: {args.omics_test}") | 1323 print(f"Loading test omics dataset from: {args.omics_test}") |
| 1233 omics_test = load_omics(args.omics_test) | 1324 omics_test = load_omics(args.omics_test) |
| 1234 | 1325 |
| 1235 generate_cox_plots(model, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) | 1326 generate_cox_plots(important_features, clinical_train, clinical_test, omics_test, omics_train, args, output_dir, output_name_base) |
| 1236 | 1327 |
| 1237 elif args.plot_type in ['scatter']: | 1328 elif args.plot_type in ['scatter']: |
| 1238 # Load labels | 1329 # Load labels |
| 1239 print(f"Loading labels from: {args.labels}") | 1330 print(f"Loading labels from: {args.labels}") |
| 1240 label_data = load_labels(args.labels) | 1331 labels = load_labels(args.labels) |
| 1241 | 1332 |
| 1242 generate_plot_scatter(label_data, args, output_dir, output_name_base) | 1333 generate_plot_scatter(labels, args, output_dir, output_name_base) |
| 1243 | 1334 |
| 1244 elif args.plot_type in ['concordance_heatmap']: | 1335 elif args.plot_type in ['concordance_heatmap']: |
| 1245 # Load labels | 1336 # Load labels |
| 1246 print(f"Loading labels from: {args.labels}") | 1337 print(f"Loading labels from: {args.labels}") |
| 1247 label_data = load_labels(args.labels) | 1338 labels = load_labels(args.labels) |
| 1248 | 1339 |
| 1249 generate_label_concordance_heatmap(label_data, args, output_dir, output_name_base) | 1340 generate_label_concordance_heatmap(labels, args, output_dir, output_name_base) |
| 1250 | 1341 |
| 1251 elif args.plot_type in ['pr_curve']: | 1342 elif args.plot_type in ['pr_curve']: |
| 1252 # Load labels | 1343 # Load labels |
| 1253 print(f"Loading labels from: {args.labels}") | 1344 print(f"Loading labels from: {args.labels}") |
| 1254 label_data = load_labels(args.labels) | 1345 labels = load_labels(args.labels) |
| 1255 | 1346 |
| 1256 generate_pr_curves(label_data, args, output_dir, output_name_base) | 1347 generate_pr_curves(labels, args, output_dir, output_name_base) |
| 1257 | 1348 |
| 1258 elif args.plot_type in ['roc_curve']: | 1349 elif args.plot_type in ['roc_curve']: |
| 1259 # Load labels | 1350 # Load labels |
| 1260 print(f"Loading labels from: {args.labels}") | 1351 print(f"Loading labels from: {args.labels}") |
| 1261 label_data = load_labels(args.labels) | 1352 labels = load_labels(args.labels) |
| 1262 | 1353 |
| 1263 generate_roc_curves(label_data, args, output_dir, output_name_base) | 1354 generate_roc_curves(labels, args, output_dir, output_name_base) |
| 1264 | 1355 |
| 1265 elif args.plot_type in ['box_plot']: | 1356 elif args.plot_type in ['box_plot']: |
| 1266 # Load labels | 1357 # Load labels |
| 1267 print(f"Loading labels from: {args.labels}") | 1358 print(f"Loading labels from: {args.labels}") |
| 1268 label_data = load_labels(args.labels) | 1359 labels = load_labels(args.labels) |
| 1269 | 1360 |
| 1270 generate_box_plots(label_data, args, output_dir, output_name_base) | 1361 generate_box_plots(labels, args, output_dir, output_name_base) |
| 1271 | 1362 |
| 1272 print("All plots generated successfully!") | 1363 print("All plots generated successfully!") |
| 1273 | 1364 |
| 1274 except (FileNotFoundError, ValueError, pd.errors.ParserError) as e: | 1365 except (FileNotFoundError, ValueError, pd.errors.ParserError) as e: |
| 1275 print(f"Error: {e}") | 1366 print(f"Error: {e}") |
