diff result_heatmap/result_heatmap.py @ 0:e94586e24004 draft default tip

planemo upload for repository https://github.com/jaidevjoshi83/MicroBiomML commit 5ef78d4decc95ac107c468499328e7f086289ff9-dirty
author jay
date Tue, 17 Feb 2026 10:52:24 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/result_heatmap/result_heatmap.py	Tue Feb 17 10:52:24 2026 +0000
@@ -0,0 +1,249 @@
+import pandas as pd #pandas==2.1.4
+import plotly.graph_objects as go #plotly==5.20.0
+import os
+import argparse
+
+
+def Analysis(values, thr=0.05):
+    # print(values)
+    better = []
+    comparable = []
+    thr = 0.05 
+
+    last_value = values[4]
+    
+    for v in values[0:4]:
+        # print(v)
+        better.append(round(last_value -v, 2) > thr)
+        comparable.append(abs(round(last_value -v, 2)) <= thr)
+
+    if all(better):
+        return (True, 'better_all' )
+    elif True in better:
+        return (True, 'better_one' )
+    elif all( comparable):
+        return (True, 'Comp_with_all' )
+    elif True in comparable:
+        return (True, 'Comp_with_one' )
+    
+
+color_scale=[
+    [0, 'green'],    # Value -1 will be red
+    [0.5, 'red'], # Value 0 will be yellow
+    [1, 'yellow']    # Value 1 will be blue
+]
+
+# Define the color scale constant
+COLOR_SCALE = {
+    'Comp_with_all': 'blue',
+    'better_all': 'violet',
+    'Comp_with_one': 'black',
+    'better_one': 'red'
+}
+def ResultSummary(file, threshold, column_list=None):
+    print(file)
+    new_DF = pd.read_csv(file, sep='\t')
+    new_DF.set_index('name', inplace=True)
+
+    DF = new_DF.T
+    DF.columns = new_DF.index
+    DF.index = new_DF.columns
+    
+    # If no column_list provided, use all columns
+    if column_list is None:
+        df = DF
+    else:
+        df = DF.iloc[column_list]
+
+    column_anno_per = {}
+    comparable = {}
+
+    for n in df.columns.to_list():
+        comparable[n] = Analysis(df[n].values, threshold)
+    return comparable
+
+def Plot(input_file, width=2460, height=800, color_labels='Greens', font_size=22, tick_font=26, tick_angle=-80, threshold=0.05, column_list=None, outfile='out.html'):
+    
+    # Parse column_list if it's a string (from command line)
+    # Convert from 1-indexed (XML) to 0-indexed (Python)
+    if isinstance(column_list, str) and column_list:
+        column_list = [int(i) - 2 for i in column_list.split(',')]
+    
+    figure_size = (width, height)
+
+    print(column_list)
+    
+    result_1 = ResultSummary(input_file, threshold, column_list)
+
+    true_columns = []
+    true_column_comp = []
+
+    for i, k in enumerate(result_1.keys()):
+        if result_1[k]:
+            true_column_comp.append((i, result_1[k], k))
+
+    plotting_columns = {
+        'Comp_with_all': [],
+        'better_all': [],
+        'Comp_with_one': [],
+        'better_one': [],
+        'None': [],
+    }
+
+    colors = COLOR_SCALE
+    arranged_columns = []
+    counter = 0
+
+    for c in colors.keys():
+        for i, a in enumerate(true_column_comp):
+            if c == a[1][1]:
+                counter += 1
+                plotting_columns[c].append((a[2], counter - 1))
+                arranged_columns.append(a[2])
+
+    # Read and prepare data for plotting - use the same processing as ResultSummary
+    new_DF = pd.read_csv(input_file, sep='\t')
+    new_DF.set_index('name', inplace=True)
+    
+    # Transpose to get classifiers as rows and metrics as columns
+    DF = new_DF.T
+    DF.columns = new_DF.index
+    DF.index = new_DF.columns
+    
+    column_list
+
+    # Apply column_list filter if provided
+    if column_list is None:
+        df = DF
+    else:
+        df = DF.iloc[column_list]
+
+    print(df)
+    
+    # Filter to only keep the arranged_columns (columns that pass the analysis)
+    if arranged_columns:
+        df = df[arranged_columns]
+    
+    df.index.name = 'name'
+
+    # print(height, width)
+
+    heatmap = go.Heatmap(
+        z=df.values,
+        x=df.columns,
+        zmin=0,
+        zmax=1,
+        y=df.index,
+        # colorbar=dict(title='Value'),
+        text=df.values,  # Display values in each cell
+        texttemplate="%{text}",  # Format for text
+        colorscale=color_labels, 
+        textfont=dict(size=font_size, color='white')
+    )
+
+    shapes = []
+
+    for i in range(5, len(df), 5):
+        shapes.append(
+            go.layout.Shape(
+                type='line',
+                x0=-0.5,
+                x1=len(df.columns) - 0.5,
+                y0=i - 0.5,
+                y1=i - 0.5,
+                line=dict(color='white', width=1),
+            )
+        )
+
+    ind = 0
+    for t in plotting_columns.keys():
+        if t != 'None' and len(plotting_columns[t]) > 0:
+            col_idx = plotting_columns[t][0][1]
+            row_idx = 4
+            shape1 = go.layout.Shape(
+                type='rect',
+                x0=col_idx - 0.48,
+                x1=plotting_columns[t][-1][1] + 0.48,
+                y0=row_idx - 4.5,
+                y1=row_idx + 0.5,
+                line=dict(color=colors[t], width=2.5),  # Use color from the color scale constant
+                fillcolor='rgba(255, 255, 255, 0)',  # Transparent fill
+            )
+            shapes.append(shape1)
+
+    fig = go.Figure(data=[heatmap])
+
+    print(input_file.split('/')[len(input_file.split('/'))-1].split('.')[0])
+
+    # Create legend annotations for border colors at the top
+    legend_annotations = []
+    legend_labels = {
+        'better_all': 'Better than all (≥threshold)',
+        'better_one': 'Better than some',
+        'Comp_with_all': 'Comparable with all',
+        'Comp_with_one': 'Comparable with some'
+    }
+    
+    x_position = 0.0
+    for color_key, label_text in legend_labels.items():
+        legend_annotations.append(
+            dict(
+                x=x_position,
+                y=1.12,
+                xref='paper',
+                yref='paper',
+                text=f'<b style="color:{colors[color_key]};font-size:14px;">■</b> {label_text}',
+                showarrow=False,
+                xanchor='left',
+                yanchor='bottom',
+                font=dict(size=11)
+            )
+        )
+        x_position += 0.25
+
+    fig.update_layout(
+        width=figure_size[0],
+        height=figure_size[1],
+        shapes=shapes,
+        title='',
+        xaxis=dict(title='Study', tickfont=dict(size=24),  tickangle=tick_angle),
+        yaxis=dict(title='Classifier', tickfont=dict(size=24) ),
+        yaxis_autorange='reversed',
+        # colorscale=[[1, 'blue'], [-1, 'red']],
+        autosize=False,
+        annotations=legend_annotations,
+        margin=dict(t=200)  # Add top margin for legend
+    )
+
+    # Save the figure as HTML
+    fig.write_html(outfile)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Plot heatmap from TSV data with classification results.")
+    parser.add_argument("--input_file", type=str, default="test_data_age_category.tsv", help="Path to input TSV file (default: test_data_age_category.tsv)")
+    parser = argparse.ArgumentParser(description="Plot heatmap from TSV data with classification results.")
+    parser.add_argument("--input_file", type=str, default="test_data_age_category.tsv", help="Path to input TSV file (default: test_data_age_category.tsv)")
+    parser.add_argument("--column_list", type=str, default=None, help="Comma-separated column indices to plot (default: None - plots all data)")
+    parser.add_argument("--width", type=int, default=2460, help="Figure width in pixels (default: 2460)")
+    parser.add_argument("--height", type=int, default=800, help="Figure height in pixels (default: 800)")
+    parser.add_argument("--color_labels", type=str, default="Greens", help="Color scheme for heatmap (default: Greens)")
+    parser.add_argument("--font_size", type=int, default=22, help="Font size for cell text (default: 22)")
+    parser.add_argument("--tick_font", type=int, default=26, help="Font size for tick labels (default: 26)")
+    parser.add_argument("--tick_angle", type=int, default=-80, help="Angle of x-axis tick labels in degrees (default: -80)")
+    parser.add_argument("--threshold", type=float, default=0.05, help="Threshold for comparison analysis (default: 0.05)")
+    parser.add_argument("--output", type=str, default="out.html", help="Output file path (default: out.html)")
+
+    args = parser.parse_args()
+
+    Plot(
+        input_file=args.input_file,
+        width=args.width,
+        height=args.height,
+        color_labels=args.color_labels,
+        font_size=int(args.font_size),
+        tick_font=int(args.tick_font),
+        tick_angle=int(args.tick_angle),
+        threshold=float(args.threshold),
+        column_list=args.column_list,
+        outfile=args.output
+    )
\ No newline at end of file