diff ml_visualization_ex.py @ 0:af2624d5ab32 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:24:32 +0000
parents
children 9349ed2749c6
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ml_visualization_ex.py	Sat May 01 01:24:32 2021 +0000
@@ -0,0 +1,645 @@
+import argparse
+import json
+import os
+import warnings
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import plotly
+import plotly.graph_objs as go
+from galaxy_ml.utils import load_model, read_columns, SafeEval
+from keras.models import model_from_json
+from keras.utils import plot_model
+from sklearn.feature_selection.base import SelectorMixin
+from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
+                             precision_recall_curve, roc_curve)
+from sklearn.pipeline import Pipeline
+
+safe_eval = SafeEval()
+
+# plotly default colors
+default_colors = [
+    "#1f77b4",  # muted blue
+    "#ff7f0e",  # safety orange
+    "#2ca02c",  # cooked asparagus green
+    "#d62728",  # brick red
+    "#9467bd",  # muted purple
+    "#8c564b",  # chestnut brown
+    "#e377c2",  # raspberry yogurt pink
+    "#7f7f7f",  # middle gray
+    "#bcbd22",  # curry yellow-green
+    "#17becf",  # blue-teal
+]
+
+
+def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
+    """output pr-curve in html using plotly
+
+    df1 : pandas.DataFrame
+        Containing y_true
+    df2 : pandas.DataFrame
+        Containing y_score
+    pos_label : None
+        The label of positive class
+    title : str
+        Plot title
+    """
+    data = []
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        precision, recall, _ = precision_recall_curve(
+            y_true, y_score, pos_label=pos_label
+        )
+        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
+
+        trace = go.Scatter(
+            x=recall,
+            y=precision,
+            mode="lines",
+            marker=dict(color=default_colors[idx % len(default_colors)]),
+            name="%s (area = %.3f)" % (idx, ap),
+        )
+        data.append(trace)
+
+    layout = go.Layout(
+        xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
+        yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
+        title=dict(
+            text=title or "Precision-Recall Curve",
+            x=0.5,
+            y=0.92,
+            xanchor="center",
+            yanchor="top",
+        ),
+        font=dict(family="sans-serif", size=11),
+        # control backgroud colors
+        plot_bgcolor="rgba(255,255,255,0)",
+    )
+    """
+    legend=dict(
+        x=0.95,
+        y=0,
+        traceorder="normal",
+        font=dict(
+            family="sans-serif",
+            size=9,
+            color="black"
+        ),
+        bgcolor="LightSteelBlue",
+        bordercolor="Black",
+        borderwidth=2
+    ),"""
+
+    fig = go.Figure(data=data, layout=layout)
+
+    plotly.offline.plot(fig, filename="output.html", auto_open=False)
+    # to be discovered by `from_work_dir`
+    os.rename("output.html", "output")
+
+
+def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
+    """visualize pr-curve using matplotlib and output svg image"""
+    backend = matplotlib.get_backend()
+    if "inline" not in backend:
+        matplotlib.use("SVG")
+    plt.style.use("seaborn-colorblind")
+    plt.figure()
+
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        precision, recall, _ = precision_recall_curve(
+            y_true, y_score, pos_label=pos_label
+        )
+        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
+
+        plt.step(
+            recall,
+            precision,
+            "r-",
+            color="black",
+            alpha=0.3,
+            lw=1,
+            where="post",
+            label="%s (area = %.3f)" % (idx, ap),
+        )
+
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel("Recall")
+    plt.ylabel("Precision")
+    title = title or "Precision-Recall Curve"
+    plt.title(title)
+    folder = os.getcwd()
+    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
+    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
+
+
+def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
+    """output roc-curve in html using plotly
+
+    df1 : pandas.DataFrame
+        Containing y_true
+    df2 : pandas.DataFrame
+        Containing y_score
+    pos_label : None
+        The label of positive class
+    drop_intermediate : bool
+        Whether to drop some suboptimal thresholds
+    title : str
+        Plot title
+    """
+    data = []
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        fpr, tpr, _ = roc_curve(
+            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
+        )
+        roc_auc = auc(fpr, tpr)
+
+        trace = go.Scatter(
+            x=fpr,
+            y=tpr,
+            mode="lines",
+            marker=dict(color=default_colors[idx % len(default_colors)]),
+            name="%s (area = %.3f)" % (idx, roc_auc),
+        )
+        data.append(trace)
+
+    layout = go.Layout(
+        xaxis=dict(
+            title="False Positive Rate", linecolor="lightslategray", linewidth=1
+        ),
+        yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
+        title=dict(
+            text=title or "Receiver Operating Characteristic (ROC) Curve",
+            x=0.5,
+            y=0.92,
+            xanchor="center",
+            yanchor="top",
+        ),
+        font=dict(family="sans-serif", size=11),
+        # control backgroud colors
+        plot_bgcolor="rgba(255,255,255,0)",
+    )
+    """
+    # legend=dict(
+            # x=0.95,
+            # y=0,
+            # traceorder="normal",
+            # font=dict(
+            #    family="sans-serif",
+            #    size=9,
+            #    color="black"
+            # ),
+            # bgcolor="LightSteelBlue",
+            # bordercolor="Black",
+            # borderwidth=2
+        # ),
+    """
+
+    fig = go.Figure(data=data, layout=layout)
+
+    plotly.offline.plot(fig, filename="output.html", auto_open=False)
+    # to be discovered by `from_work_dir`
+    os.rename("output.html", "output")
+
+
+def visualize_roc_curve_matplotlib(
+    df1, df2, pos_label, drop_intermediate=True, title=None
+):
+    """visualize roc-curve using matplotlib and output svg image"""
+    backend = matplotlib.get_backend()
+    if "inline" not in backend:
+        matplotlib.use("SVG")
+    plt.style.use("seaborn-colorblind")
+    plt.figure()
+
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        fpr, tpr, _ = roc_curve(
+            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
+        )
+        roc_auc = auc(fpr, tpr)
+
+        plt.step(
+            fpr,
+            tpr,
+            "r-",
+            color="black",
+            alpha=0.3,
+            lw=1,
+            where="post",
+            label="%s (area = %.3f)" % (idx, roc_auc),
+        )
+
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel("False Positive Rate")
+    plt.ylabel("True Positive Rate")
+    title = title or "Receiver Operating Characteristic (ROC) Curve"
+    plt.title(title)
+    folder = os.getcwd()
+    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
+    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
+
+
+def get_dataframe(file_path, plot_selection, header_name, column_name):
+    header = "infer" if plot_selection[header_name] else None
+    column_option = plot_selection[column_name]["selected_column_selector_option"]
+    if column_option in [
+        "by_index_number",
+        "all_but_by_index_number",
+        "by_header_name",
+        "all_but_by_header_name",
+    ]:
+        col = plot_selection[column_name]["col1"]
+    else:
+        col = None
+    _, input_df = read_columns(
+        file_path,
+        c=col,
+        c_option=column_option,
+        return_df=True,
+        sep="\t",
+        header=header,
+        parse_dates=True,
+    )
+    return input_df
+
+
+def main(
+    inputs,
+    infile_estimator=None,
+    infile1=None,
+    infile2=None,
+    outfile_result=None,
+    outfile_object=None,
+    groups=None,
+    ref_seq=None,
+    intervals=None,
+    targets=None,
+    fasta_path=None,
+    model_config=None,
+    true_labels=None,
+    predicted_labels=None,
+    plot_color=None,
+    title=None,
+):
+    """
+    Parameter
+    ---------
+    inputs : str
+        File path to galaxy tool parameter
+
+    infile_estimator : str, default is None
+        File path to estimator
+
+    infile1 : str, default is None
+        File path to dataset containing features or true labels.
+
+    infile2 : str, default is None
+        File path to dataset containing target values or predicted
+        probabilities.
+
+    outfile_result : str, default is None
+        File path to save the results, either cv_results or test result
+
+    outfile_object : str, default is None
+        File path to save searchCV object
+
+    groups : str, default is None
+        File path to dataset containing groups labels
+
+    ref_seq : str, default is None
+        File path to dataset containing genome sequence file
+
+    intervals : str, default is None
+        File path to dataset containing interval file
+
+    targets : str, default is None
+        File path to dataset compressed target bed file
+
+    fasta_path : str, default is None
+        File path to dataset containing fasta file
+
+    model_config : str, default is None
+        File path to dataset containing JSON config for neural networks
+
+    true_labels : str, default is None
+        File path to dataset containing true labels
+
+    predicted_labels : str, default is None
+        File path to dataset containing true predicted labels
+
+    plot_color : str, default is None
+        Color of the confusion matrix heatmap
+
+    title : str, default is None
+        Title of the confusion matrix heatmap
+    """
+    warnings.simplefilter("ignore")
+
+    with open(inputs, "r") as param_handler:
+        params = json.load(param_handler)
+
+    title = params["plotting_selection"]["title"].strip()
+    plot_type = params["plotting_selection"]["plot_type"]
+    plot_format = params["plotting_selection"]["plot_format"]
+
+    if plot_type == "feature_importances":
+        with open(infile_estimator, "rb") as estimator_handler:
+            estimator = load_model(estimator_handler)
+
+        column_option = params["plotting_selection"]["column_selector_options"][
+            "selected_column_selector_option"
+        ]
+        if column_option in [
+            "by_index_number",
+            "all_but_by_index_number",
+            "by_header_name",
+            "all_but_by_header_name",
+        ]:
+            c = params["plotting_selection"]["column_selector_options"]["col1"]
+        else:
+            c = None
+
+        _, input_df = read_columns(
+            infile1,
+            c=c,
+            c_option=column_option,
+            return_df=True,
+            sep="\t",
+            header="infer",
+            parse_dates=True,
+        )
+
+        feature_names = input_df.columns.values
+
+        if isinstance(estimator, Pipeline):
+            for st in estimator.steps[:-1]:
+                if isinstance(st[-1], SelectorMixin):
+                    mask = st[-1].get_support()
+                    feature_names = feature_names[mask]
+            estimator = estimator.steps[-1][-1]
+
+        if hasattr(estimator, "coef_"):
+            coefs = estimator.coef_
+        else:
+            coefs = getattr(estimator, "feature_importances_", None)
+        if coefs is None:
+            raise RuntimeError(
+                "The classifier does not expose "
+                '"coef_" or "feature_importances_" '
+                "attributes"
+            )
+
+        threshold = params["plotting_selection"]["threshold"]
+        if threshold is not None:
+            mask = (coefs > threshold) | (coefs < -threshold)
+            coefs = coefs[mask]
+            feature_names = feature_names[mask]
+
+        # sort
+        indices = np.argsort(coefs)[::-1]
+
+        trace = go.Bar(x=feature_names[indices], y=coefs[indices])
+        layout = go.Layout(title=title or "Feature Importances")
+        fig = go.Figure(data=[trace], layout=layout)
+
+        plotly.offline.plot(fig, filename="output.html", auto_open=False)
+        # to be discovered by `from_work_dir`
+        os.rename("output.html", "output")
+
+        return 0
+
+    elif plot_type in ("pr_curve", "roc_curve"):
+        df1 = pd.read_csv(infile1, sep="\t", header="infer")
+        df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)
+
+        minimum = params["plotting_selection"]["report_minimum_n_positives"]
+        # filter out columns whose n_positives is beblow the threhold
+        if minimum:
+            mask = df1.sum(axis=0) >= minimum
+            df1 = df1.loc[:, mask]
+            df2 = df2.loc[:, mask]
+
+        pos_label = params["plotting_selection"]["pos_label"].strip() or None
+
+        if plot_type == "pr_curve":
+            if plot_format == "plotly_html":
+                visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
+            else:
+                visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
+        else:  # 'roc_curve'
+            drop_intermediate = params["plotting_selection"]["drop_intermediate"]
+            if plot_format == "plotly_html":
+                visualize_roc_curve_plotly(
+                    df1,
+                    df2,
+                    pos_label,
+                    drop_intermediate=drop_intermediate,
+                    title=title,
+                )
+            else:
+                visualize_roc_curve_matplotlib(
+                    df1,
+                    df2,
+                    pos_label,
+                    drop_intermediate=drop_intermediate,
+                    title=title,
+                )
+
+        return 0
+
+    elif plot_type == "rfecv_gridscores":
+        input_df = pd.read_csv(infile1, sep="\t", header="infer")
+        scores = input_df.iloc[:, 0]
+        steps = params["plotting_selection"]["steps"].strip()
+        steps = safe_eval(steps)
+
+        data = go.Scatter(
+            x=list(range(len(scores))),
+            y=scores,
+            text=[str(_) for _ in steps] if steps else None,
+            mode="lines",
+        )
+        layout = go.Layout(
+            xaxis=dict(title="Number of features selected"),
+            yaxis=dict(title="Cross validation score"),
+            title=dict(
+                text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
+            ),
+            font=dict(family="sans-serif", size=11),
+            # control backgroud colors
+            plot_bgcolor="rgba(255,255,255,0)",
+        )
+        """
+        # legend=dict(
+                # x=0.95,
+                # y=0,
+                # traceorder="normal",
+                # font=dict(
+                #    family="sans-serif",
+                #    size=9,
+                #    color="black"
+                # ),
+                # bgcolor="LightSteelBlue",
+                # bordercolor="Black",
+                # borderwidth=2
+            # ),
+        """
+
+        fig = go.Figure(data=[data], layout=layout)
+        plotly.offline.plot(fig, filename="output.html", auto_open=False)
+        # to be discovered by `from_work_dir`
+        os.rename("output.html", "output")
+
+        return 0
+
+    elif plot_type == "learning_curve":
+        input_df = pd.read_csv(infile1, sep="\t", header="infer")
+        plot_std_err = params["plotting_selection"]["plot_std_err"]
+        data1 = go.Scatter(
+            x=input_df["train_sizes_abs"],
+            y=input_df["mean_train_scores"],
+            error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
+            mode="lines",
+            name="Train Scores",
+        )
+        data2 = go.Scatter(
+            x=input_df["train_sizes_abs"],
+            y=input_df["mean_test_scores"],
+            error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
+            mode="lines",
+            name="Test Scores",
+        )
+        layout = dict(
+            xaxis=dict(title="No. of samples"),
+            yaxis=dict(title="Performance Score"),
+            # modify these configurations to customize image
+            title=dict(
+                text=title or "Learning Curve",
+                x=0.5,
+                y=0.92,
+                xanchor="center",
+                yanchor="top",
+            ),
+            font=dict(family="sans-serif", size=11),
+            # control backgroud colors
+            plot_bgcolor="rgba(255,255,255,0)",
+        )
+        """
+        # legend=dict(
+                # x=0.95,
+                # y=0,
+                # traceorder="normal",
+                # font=dict(
+                #    family="sans-serif",
+                #    size=9,
+                #    color="black"
+                # ),
+                # bgcolor="LightSteelBlue",
+                # bordercolor="Black",
+                # borderwidth=2
+            # ),
+        """
+
+        fig = go.Figure(data=[data1, data2], layout=layout)
+        plotly.offline.plot(fig, filename="output.html", auto_open=False)
+        # to be discovered by `from_work_dir`
+        os.rename("output.html", "output")
+
+        return 0
+
+    elif plot_type == "keras_plot_model":
+        with open(model_config, "r") as f:
+            model_str = f.read()
+        model = model_from_json(model_str)
+        plot_model(model, to_file="output.png")
+        os.rename("output.png", "output")
+
+        return 0
+
+    elif plot_type == "classification_confusion_matrix":
+        plot_selection = params["plotting_selection"]
+        input_true = get_dataframe(
+            true_labels, plot_selection, "header_true", "column_selector_options_true"
+        )
+        header_predicted = "infer" if plot_selection["header_predicted"] else None
+        input_predicted = pd.read_csv(
+            predicted_labels, sep="\t", parse_dates=True, header=header_predicted
+        )
+        true_classes = input_true.iloc[:, -1].copy()
+        predicted_classes = input_predicted.iloc[:, -1].copy()
+        axis_labels = list(set(true_classes))
+        c_matrix = confusion_matrix(true_classes, predicted_classes)
+        fig, ax = plt.subplots(figsize=(7, 7))
+        im = plt.imshow(c_matrix, cmap=plot_color)
+        for i in range(len(c_matrix)):
+            for j in range(len(c_matrix)):
+                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
+        ax.set_ylabel("True class labels")
+        ax.set_xlabel("Predicted class labels")
+        ax.set_title(title)
+        ax.set_xticks(axis_labels)
+        ax.set_yticks(axis_labels)
+        fig.colorbar(im, ax=ax)
+        fig.tight_layout()
+        plt.savefig("output.png", dpi=125)
+        os.rename("output.png", "output")
+
+        return 0
+
+    # save pdf file to disk
+    # fig.write_image("image.pdf", format='pdf')
+    # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
+
+
+if __name__ == "__main__":
+    aparser = argparse.ArgumentParser()
+    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
+    aparser.add_argument("-e", "--estimator", dest="infile_estimator")
+    aparser.add_argument("-X", "--infile1", dest="infile1")
+    aparser.add_argument("-y", "--infile2", dest="infile2")
+    aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
+    aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
+    aparser.add_argument("-g", "--groups", dest="groups")
+    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
+    aparser.add_argument("-b", "--intervals", dest="intervals")
+    aparser.add_argument("-t", "--targets", dest="targets")
+    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
+    aparser.add_argument("-c", "--model_config", dest="model_config")
+    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
+    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
+    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
+    aparser.add_argument("-pt", "--title", dest="title")
+    args = aparser.parse_args()
+
+    main(
+        args.inputs,
+        args.infile_estimator,
+        args.infile1,
+        args.infile2,
+        args.outfile_result,
+        outfile_object=args.outfile_object,
+        groups=args.groups,
+        ref_seq=args.ref_seq,
+        intervals=args.intervals,
+        targets=args.targets,
+        fasta_path=args.fasta_path,
+        model_config=args.model_config,
+        true_labels=args.true_labels,
+        predicted_labels=args.predicted_labels,
+        plot_color=args.plot_color,
+        title=args.title,
+    )