view ml_visualization_ex.py @ 38:74adae8d7b0f draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 80417bf0158a9b596e485dd66408f738f405145a
author bgruening
date Mon, 02 Oct 2023 10:30:40 +0000
parents 5276bdb49315
children
line wrap: on
line source

import argparse
import json
import os
import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
from galaxy_ml.model_persist import load_model_from_h5
from galaxy_ml.utils import read_columns, SafeEval
from sklearn.feature_selection._base import SelectorMixin
from sklearn.metrics import (
    auc,
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)
from sklearn.pipeline import Pipeline
from tensorflow.keras.models import model_from_json
from tensorflow.keras.utils import plot_model

# Restricted expression evaluator from galaxy_ml. main() uses it to turn the
# user-supplied "steps" text of the rfecv_gridscores plot into a Python value.
safe_eval = SafeEval()

# plotly default colors
# The plotly curve plotters pick default_colors[idx % len(default_colors)],
# so this palette repeats once there are more than ten curves.
default_colors = [
    "#1f77b4",  # muted blue
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#d62728",  # brick red
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]


def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """output pr-curve in html using plotly

    One trace is drawn per column pair of ``df1``/``df2``. Curves are drawn
    as post-steps (``line.shape="hv"``), the same way the matplotlib variant
    draws them. This is also the step-wise curve whose area
    ``average_precision_score`` computes, so the plotted area agrees with
    the "area" reported in the legend. The page is written to ``./output``.

    df1 : pandas.DataFrame
        Containing y_true, one column per curve
    df2 : pandas.DataFrame
        Containing y_score, columns aligned with ``df1``
    pos_label : None or label
        The label of positive class. ``None`` lets ``precision_recall_curve``
        infer it, and ``average_precision_score`` then uses its own default
        of 1.
    title : str
        Plot title
    """
    # average_precision_score does not accept pos_label=None for binary
    # targets, so substitute its default (1). An explicit falsy label such
    # as 0 must be passed through unchanged, hence no `pos_label or 1`.
    ap_pos_label = 1 if pos_label is None else pos_label

    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=ap_pos_label)

        trace = go.Scatter(
            x=recall,
            y=precision,
            mode="lines",
            line=dict(
                color=default_colors[idx % len(default_colors)],
                # horizontal-then-vertical == matplotlib's steps-post
                shape="hv",
            ),
            name="%s (area = %.3f)" % (idx, ap),
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
        yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Precision-Recall Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """visualize pr-curve using matplotlib and output svg image

    Draws one step-wise (steps-post) precision-recall curve per column pair
    of ``df1``/``df2``, adds a legend with each curve's average precision,
    and writes the figure as SVG to ``./output``.

    df1 : pandas.DataFrame
        Containing y_true, one column per curve
    df2 : pandas.DataFrame
        Containing y_score, columns aligned with ``df1``
    pos_label : None or label
        The label of positive class. ``None`` lets ``precision_recall_curve``
        infer it, and ``average_precision_score`` then uses its own default
        of 1.
    title : str
        Plot title
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    # "seaborn-colorblind" was renamed "seaborn-v0_8-colorblind" in
    # matplotlib 3.6 and the old name was removed in 3.8. Try the new name
    # first and fall back to the old one on older matplotlib releases.
    try:
        plt.style.use("seaborn-v0_8-colorblind")
    except OSError:
        plt.style.use("seaborn-colorblind")
    fig = plt.figure()

    # average_precision_score does not accept pos_label=None for binary
    # targets, so substitute its default (1). An explicit falsy label such
    # as 0 must be passed through unchanged, hence no `pos_label or 1`.
    ap_pos_label = 1 if pos_label is None else pos_label

    n_curves = df1.shape[1]
    for idx in range(n_curves):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=ap_pos_label)

        # Colour is given only by keyword: mixing a colour-bearing fmt string
        # with color= makes matplotlib warn about the redundant definition.
        plt.step(
            recall,
            precision,
            linestyle="-",
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, ap),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    title = title or "Precision-Recall Curve"
    plt.title(title)
    # Without a legend the per-curve labels (and their AP values) are never
    # shown; skip it when there is nothing to label to avoid a warning.
    if n_curves > 0:
        plt.legend(loc="lower left")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
    """Write one ROC curve per column pair as an interactive plotly page.

    The HTML is rendered to ``output.html`` and then renamed to ``output``
    so that Galaxy's ``from_work_dir`` can pick it up.

    df1 : pandas.DataFrame
        True labels, one column per curve
    df2 : pandas.DataFrame
        Scores, columns aligned with ``df1``
    pos_label : None
        Label of the positive class, forwarded to ``roc_curve``
    drop_intermediate : bool
        Forwarded to ``roc_curve``: drop suboptimal thresholds
    title : str
        Plot title; a default title is used when empty
    """
    traces = []
    n_curves = df1.shape[1]
    for col in range(n_curves):
        labels_col = df1.iloc[:, col].values
        scores_col = df2.iloc[:, col].values

        false_pos, true_pos, _thresholds = roc_curve(
            labels_col,
            scores_col,
            pos_label=pos_label,
            drop_intermediate=drop_intermediate,
        )
        area = auc(false_pos, true_pos)

        colour = default_colors[col % len(default_colors)]
        traces.append(
            go.Scatter(
                x=false_pos,
                y=true_pos,
                mode="lines",
                marker=dict(color=colour),
                name="%s (area = %.3f)" % (col, area),
            )
        )

    # both axes share the same line styling
    axis_style = dict(linecolor="lightslategray", linewidth=1)
    heading = dict(
        text=title or "Receiver Operating Characteristic (ROC) Curve",
        x=0.5,
        y=0.92,
        xanchor="center",
        yanchor="top",
    )
    layout = go.Layout(
        xaxis=dict(title="False Positive Rate", **axis_style),
        yaxis=dict(title="True Positive Rate", **axis_style),
        title=heading,
        font=dict(family="sans-serif", size=11),
        # transparent plot background
        plot_bgcolor="rgba(255,255,255,0)",
    )

    figure = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(figure, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def visualize_roc_curve_matplotlib(
    df1, df2, pos_label, drop_intermediate=True, title=None
):
    """visualize roc-curve using matplotlib and output svg image

    Draws one ROC curve per column pair of ``df1``/``df2``, adds a legend
    with each curve's AUC, and writes the figure as SVG to ``./output``.

    df1 : pandas.DataFrame
        Containing y_true, one column per curve
    df2 : pandas.DataFrame
        Containing y_score, columns aligned with ``df1``
    pos_label : None or label
        The label of positive class, forwarded to ``roc_curve``
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds
    title : str
        Plot title
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    # "seaborn-colorblind" was renamed "seaborn-v0_8-colorblind" in
    # matplotlib 3.6 and the old name was removed in 3.8. Try the new name
    # first and fall back to the old one on older matplotlib releases.
    try:
        plt.style.use("seaborn-v0_8-colorblind")
    except OSError:
        plt.style.use("seaborn-colorblind")
    fig = plt.figure()

    n_curves = df1.shape[1]
    for idx in range(n_curves):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        # ROC points are joined by straight segments. That is the curve whose
        # trapezoidal area `auc` reports (and how sklearn's RocCurveDisplay
        # draws it); a post-step would under-draw segments between tied
        # scores. Colour is given only by keyword, to avoid matplotlib's
        # "color redundantly defined" warning.
        plt.plot(
            fpr,
            tpr,
            linestyle="-",
            color="black",
            alpha=0.3,
            lw=1,
            label="%s (area = %.3f)" % (idx, roc_auc),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    title = title or "Receiver Operating Characteristic (ROC) Curve"
    plt.title(title)
    # Without a legend the per-curve labels (and their AUC values) are never
    # shown; skip it when there is nothing to label to avoid a warning.
    if n_curves > 0:
        plt.legend(loc="lower right")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    plt.close(fig)
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Load the user-selected columns of a tab-separated file.

    ``plot_selection[header_name]`` decides whether the first row is a
    header. ``plot_selection[column_name]`` is the Galaxy column-selector
    section, whose ``col1`` entry is only consulted for the options that
    actually take a column list.

    Returns the DataFrame produced by ``read_columns``.
    """
    has_header = plot_selection[header_name]
    header = None
    if has_header:
        header = "infer"

    selector = plot_selection[column_name]
    option = selector["selected_column_selector_option"]
    takes_columns = option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    )
    columns = selector["col1"] if takes_columns else None

    result = read_columns(
        file_path,
        c=columns,
        c_option=option,
        return_df=True,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    # read_columns returns (array, DataFrame); only the frame is wanted
    return result[1]


def main(
    inputs,
    infile_estimator=None,
    infile1=None,
    infile2=None,
    outfile_result=None,
    outfile_object=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
    model_config=None,
    true_labels=None,
    predicted_labels=None,
    plot_color=None,
    title=None,
):
    """Render the plot selected in the Galaxy tool parameters to ``./output``.

    Parameter
    ---------
    inputs : str
        File path to galaxy tool parameter (JSON)

    infile_estimator : str, default is None
        File path to estimator

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result.
        Not used by any plot type in this function.

    outfile_object : str, default is None
        File path to save searchCV object. Not used here.

    groups : str, default is None
        File path to dataset containing groups labels. Not used here.

    ref_seq : str, default is None
        File path to dataset containing genome sequence file. Not used here.

    intervals : str, default is None
        File path to dataset containing interval file. Not used here.

    targets : str, default is None
        File path to dataset compressed target bed file. Not used here.

    fasta_path : str, default is None
        File path to dataset containing fasta file. Not used here.

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks

    true_labels : str, default is None
        File path to dataset containing true labels

    predicted_labels : str, default is None
        File path to dataset containing predicted labels

    plot_color : str, default is None
        Colormap of the confusion matrix heatmap

    title : str, default is None
        Fallback plot title, used only when the ``title`` entry of the tool
        parameters is empty or whitespace

    Returns
    -------
    int
        0 once the plot has been written.

    Raises
    ------
    RuntimeError
        If the estimator exposes neither ``coef_`` nor
        ``feature_importances_``.
    ValueError
        If ``coef_`` holds more than one row of coefficients, or if
        ``plot_type`` is not one this function handles.
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    plot_selection = params["plotting_selection"]
    # The tool parameter wins; the `title` argument is only a fallback (it
    # used to be silently discarded).
    title = plot_selection["title"].strip() or (title or "").strip()
    plot_type = plot_selection["plot_type"]
    plot_format = plot_selection["plot_format"]

    if plot_type == "feature_importances":
        estimator = load_model_from_h5(infile_estimator)

        column_option = plot_selection["column_selector_options"][
            "selected_column_selector_option"
        ]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = plot_selection["column_selector_options"]["col1"]
        else:
            c = None

        _, input_df = read_columns(
            infile1,
            c=c,
            c_option=column_option,
            return_df=True,
            sep="\t",
            header="infer",
            parse_dates=True,
        )

        feature_names = input_df.columns.values

        if isinstance(estimator, Pipeline):
            # Narrow the names through every feature-selection step so they
            # line up with the features the final step was fitted on.
            for st in estimator.steps[:-1]:
                if isinstance(st[-1], SelectorMixin):
                    mask = st[-1].get_support()
                    feature_names = feature_names[mask]
            estimator = estimator.steps[-1][-1]

        if hasattr(estimator, "coef_"):
            coefs = estimator.coef_
        else:
            coefs = getattr(estimator, "feature_importances_", None)
        if coefs is None:
            raise RuntimeError(
                "The classifier does not expose "
                '"coef_" or "feature_importances_" '
                "attributes"
            )

        coefs = np.asarray(coefs)
        # sklearn linear models may expose coef_ as 2-D, e.g. shape
        # (1, n_features) for a binary LogisticRegression. The masking,
        # sorting and bar plot below all need a 1-D array.
        if coefs.ndim > 1:
            if coefs.shape[0] != 1:
                raise ValueError(
                    "Plotting feature importances needs a single row of "
                    "coefficients, got coef_ with shape %s" % (coefs.shape,)
                )
            coefs = coefs.ravel()

        threshold = plot_selection["threshold"]
        if threshold is not None:
            # keep features whose magnitude exceeds the threshold
            mask = (coefs > threshold) | (coefs < -threshold)
            coefs = coefs[mask]
            feature_names = feature_names[mask]

        # sort, largest first
        indices = np.argsort(coefs)[::-1]

        trace = go.Bar(x=feature_names[indices], y=coefs[indices])
        layout = go.Layout(title=title or "Feature Importances")
        fig = go.Figure(data=[trace], layout=layout)

        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type in ("pr_curve", "roc_curve"):
        df1 = pd.read_csv(infile1, sep="\t", header="infer")
        df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)

        minimum = plot_selection["report_minimum_n_positives"]
        # filter out columns whose n_positives is below the threshold
        if minimum:
            mask = df1.sum(axis=0) >= minimum
            df1 = df1.loc[:, mask]
            df2 = df2.loc[:, mask]

        pos_label = plot_selection["pos_label"].strip() or None

        if plot_type == "pr_curve":
            if plot_format == "plotly_html":
                visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
            else:
                visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
        else:  # 'roc_curve'
            drop_intermediate = plot_selection["drop_intermediate"]
            if plot_format == "plotly_html":
                visualize_roc_curve_plotly(
                    df1,
                    df2,
                    pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title,
                )
            else:
                visualize_roc_curve_matplotlib(
                    df1,
                    df2,
                    pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title,
                )

        return 0

    elif plot_type == "rfecv_gridscores":
        input_df = pd.read_csv(infile1, sep="\t", header="infer")
        scores = input_df.iloc[:, 0]
        steps = plot_selection["steps"].strip()
        steps = safe_eval(steps)

        data = go.Scatter(
            x=list(range(len(scores))),
            y=scores,
            text=[str(_) for _ in steps] if steps else None,
            mode="lines",
        )
        layout = go.Layout(
            xaxis=dict(title="Number of features selected"),
            yaxis=dict(title="Cross validation score"),
            title=dict(
                text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
            ),
            font=dict(family="sans-serif", size=11),
            # control background colors
            plot_bgcolor="rgba(255,255,255,0)",
        )

        fig = go.Figure(data=[data], layout=layout)
        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type == "learning_curve":
        input_df = pd.read_csv(infile1, sep="\t", header="infer")
        plot_std_err = plot_selection["plot_std_err"]
        data1 = go.Scatter(
            x=input_df["train_sizes_abs"],
            y=input_df["mean_train_scores"],
            error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
            mode="lines",
            name="Train Scores",
        )
        data2 = go.Scatter(
            x=input_df["train_sizes_abs"],
            y=input_df["mean_test_scores"],
            error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
            mode="lines",
            name="Test Scores",
        )
        layout = dict(
            xaxis=dict(title="No. of samples"),
            yaxis=dict(title="Performance Score"),
            # modify these configurations to customize image
            title=dict(
                text=title or "Learning Curve",
                x=0.5,
                y=0.92,
                xanchor="center",
                yanchor="top",
            ),
            font=dict(family="sans-serif", size=11),
            # control background colors
            plot_bgcolor="rgba(255,255,255,0)",
        )

        fig = go.Figure(data=[data1, data2], layout=layout)
        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type == "keras_plot_model":
        with open(model_config, "r") as f:
            model_str = f.read()
        model = model_from_json(model_str)
        plot_model(model, to_file="output.png")
        os.rename("output.png", "output")

        return 0

    elif plot_type == "classification_confusion_matrix":
        from sklearn.utils.multiclass import unique_labels

        input_true = get_dataframe(
            true_labels, plot_selection, "header_true", "column_selector_options_true"
        )
        header_predicted = "infer" if plot_selection["header_predicted"] else None
        input_predicted = pd.read_csv(
            predicted_labels, sep="\t", parse_dates=True, header=header_predicted
        )
        true_classes = input_true.iloc[:, -1].copy()
        predicted_classes = input_predicted.iloc[:, -1].copy()
        # confusion_matrix orders its rows/columns by the sorted union of the
        # true and predicted labels (unique_labels). Use exactly that order
        # for the tick labels, and place ticks at matrix positions 0..n-1
        # rather than at the label values, which broke for string labels and
        # for labels that are not 0..n-1.
        axis_labels = unique_labels(true_classes, predicted_classes)
        c_matrix = confusion_matrix(
            true_classes, predicted_classes, labels=axis_labels
        )
        n_classes = len(axis_labels)

        fig, ax = plt.subplots(figsize=(7, 7))
        im = ax.imshow(c_matrix, cmap=plot_color)
        for i in range(n_classes):
            for j in range(n_classes):
                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
        ax.set_ylabel("True class labels")
        ax.set_xlabel("Predicted class labels")
        ax.set_title(title)
        tick_positions = np.arange(n_classes)
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(axis_labels)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(axis_labels)
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.savefig("output.png", dpi=125)
        plt.close(fig)
        os.rename("output.png", "output")

        return 0

    # Previously an unknown plot type fell through and silently produced no
    # output file; fail loudly instead.
    raise ValueError("Unsupported plot_type: %r" % (plot_type,))


if __name__ == "__main__":
    # (short flag, long flag, destination) for every optional CLI argument.
    # Each destination matches a keyword parameter of main().
    optional_flags = [
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
        ("-tl", "--true_labels", "true_labels"),
        ("-pl", "--predicted_labels", "predicted_labels"),
        ("-pc", "--plot_color", "plot_color"),
        ("-pt", "--title", "title"),
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, long_flag, destination in optional_flags:
        parser.add_argument(short_flag, long_flag, dest=destination)
    cli_args = parser.parse_args()

    # every parsed destination is a parameter name of main()
    main(**vars(cli_args))