Mercurial > repos > bgruening > sklearn_searchcv
diff ml_visualization_ex.py @ 17:1ae5dfd5ac17 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9e28f4466084464d38d3f8db2aff07974be4ba69"
author | bgruening |
---|---|
date | Wed, 11 Mar 2020 14:02:20 -0400 |
parents | c1ca24a1509d |
children | cb5635e30842 |
line wrap: on
line diff
--- a/ml_visualization_ex.py Wed Jan 22 08:05:12 2020 -0500 +++ b/ml_visualization_ex.py Wed Mar 11 14:02:20 2020 -0400 @@ -13,7 +13,7 @@ from keras.utils import plot_model from sklearn.feature_selection.base import SelectorMixin from sklearn.metrics import precision_recall_curve, average_precision_score -from sklearn.metrics import roc_curve, auc +from sklearn.metrics import roc_curve, auc, confusion_matrix from sklearn.pipeline import Pipeline from galaxy_ml.utils import load_model, read_columns, SafeEval @@ -266,12 +266,29 @@ os.path.join(folder, "output")) +def get_dataframe(file_path, plot_selection, header_name, column_name): + header = 'infer' if plot_selection[header_name] else None + column_option = plot_selection[column_name]["selected_column_selector_option"] + if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: + col = plot_selection[column_name]["col1"] + else: + col = None + _, input_df = read_columns(file_path, c=col, + c_option=column_option, + return_df=True, + sep='\t', header=header, + parse_dates=True) + return input_df + + def main(inputs, infile_estimator=None, infile1=None, infile2=None, outfile_result=None, outfile_object=None, groups=None, ref_seq=None, intervals=None, targets=None, fasta_path=None, - model_config=None): + model_config=None, true_labels=None, + predicted_labels=None, plot_color=None, + title=None): """ Parameter --------- @@ -311,6 +328,18 @@ model_config : str, default is None File path to dataset containing JSON config for neural networks + + true_labels : str, default is None + File path to dataset containing true labels + + predicted_labels : str, default is None + File path to dataset containing true predicted labels + + plot_color : str, default is None + Color of the confusion matrix heatmap + + title : str, default is None + Title of the confusion matrix heatmap """ warnings.simplefilter('ignore') @@ -543,6 +572,32 @@ return 0 + elif plot_type == 'classification_confusion_matrix': + plot_selection = params["plotting_selection"] + input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") + header_predicted = 'infer' if plot_selection["header_predicted"] else None + input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted) + true_classes = input_true.iloc[:, -1].copy() + predicted_classes = input_predicted.iloc[:, -1].copy() + axis_labels = list(set(true_classes)) + c_matrix = confusion_matrix(true_classes, predicted_classes) + fig, ax = plt.subplots(figsize=(7, 7)) + im = plt.imshow(c_matrix, cmap=plot_color) + for i in range(len(c_matrix)): + for j in range(len(c_matrix)): + ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") + ax.set_ylabel('True class labels') + ax.set_xlabel('Predicted class labels') + ax.set_title(title) + ax.set_xticks(axis_labels) + ax.set_yticks(axis_labels) + fig.colorbar(im, ax=ax) + fig.tight_layout() + plt.savefig("output.png", dpi=125) + os.rename('output.png', 'output') + + return 0 + # save pdf file to disk # fig.write_image("image.pdf", format='pdf') # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) @@ -562,10 +617,17 @@ aparser.add_argument("-t", "--targets", dest="targets") aparser.add_argument("-f", "--fasta_path", dest="fasta_path") aparser.add_argument("-c", "--model_config", dest="model_config") + aparser.add_argument("-tl", "--true_labels", dest="true_labels") + aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") + aparser.add_argument("-pc", "--plot_color", dest="plot_color") + aparser.add_argument("-pt", "--title", dest="title") args = aparser.parse_args() main(args.inputs, args.infile_estimator, args.infile1, args.infile2, args.outfile_result, outfile_object=args.outfile_object, groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, targets=args.targets, fasta_path=args.fasta_path, - model_config=args.model_config) + model_config=args.model_config, true_labels=args.true_labels, + predicted_labels=args.predicted_labels, + plot_color=args.plot_color, + title=args.title)