Mercurial > repos > bgruening > sklearn_to_categorical
diff model_prediction.py @ 0:59e8b4328c82 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 22:40:10 +0000 |
parents | |
children | f93f0cdbaf18 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/model_prediction.py Tue Apr 13 22:40:10 2021 +0000 @@ -0,0 +1,232 @@ +import argparse +import json +import warnings + +import numpy as np +import pandas as pd +from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr +from scipy.io import mmread +from sklearn.pipeline import Pipeline + +N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) + + +def main( + inputs, + infile_estimator, + outfile_predict, + infile_weights=None, + infile1=None, + fasta_path=None, + ref_seq=None, + vcf_path=None, +): + """ + Parameter + --------- + inputs : str + File path to galaxy tool parameter + + infile_estimator : strgit + File path to trained estimator input + + outfile_predict : str + File path to save the prediction results, tabular + + infile_weights : str + File path to weights input + + infile1 : str + File path to dataset containing features + + fasta_path : str + File path to dataset containing fasta file + + ref_seq : str + File path to dataset containing the reference genome sequence. + + vcf_path : str + File path to dataset containing variants info. + """ + warnings.filterwarnings("ignore") + + with open(inputs, "r") as param_handler: + params = json.load(param_handler) + + # load model + with open(infile_estimator, "rb") as est_handler: + estimator = load_model(est_handler) + + main_est = estimator + if isinstance(estimator, Pipeline): + main_est = estimator.steps[-1][-1] + if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): + if not infile_weights or infile_weights == "None": + raise ValueError( + "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" + ) + main_est.load_weights(infile_weights) + + # handle data input + input_type = params["input_options"]["selected_input"] + # tabular input + if input_type == "tabular": + header = "infer" if params["input_options"]["header1"] else None + column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + c = params["input_options"]["column_selector_options_1"]["col1"] + else: + c = None + + df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) + + X = read_columns(df, c=c, c_option=column_option).astype(float) + + if params["method"] == "predict": + preds = estimator.predict(X) + else: + preds = estimator.predict_proba(X) + + # sparse input + elif input_type == "sparse": + X = mmread(open(infile1, "r")) + if params["method"] == "predict": + preds = estimator.predict(X) + else: + preds = estimator.predict_proba(X) + + # fasta input + elif input_type == "seq_fasta": + if not hasattr(estimator, "data_batch_generator"): + raise ValueError( + "To do prediction on sequences in fasta input, " + "the estimator must be a `KerasGBatchClassifier`" + "equipped with data_batch_generator!" + ) + pyfaidx = get_module("pyfaidx") + sequences = pyfaidx.Fasta(fasta_path) + n_seqs = len(sequences.keys()) + X = np.arange(n_seqs)[:, np.newaxis] + seq_length = estimator.data_batch_generator.seq_length + batch_size = getattr(estimator, "batch_size", 32) + steps = (n_seqs + batch_size - 1) // batch_size + + seq_type = params["input_options"]["seq_type"] + klass = try_get_attr("galaxy_ml.preprocessors", seq_type) + + pred_data_generator = klass(fasta_path, seq_length=seq_length) + + if params["method"] == "predict": + preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) + else: + preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) + + # vcf input + elif input_type == "variant_effect": + klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") + + options = params["input_options"] + options.pop("selected_input") + if options["blacklist_regions"] == "none": + options["blacklist_regions"] = None + + pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) + + pred_data_generator.set_processing_attrs() + + variants = pred_data_generator.variants + + # predict 1600 sample at once then write to file + gen_flow = pred_data_generator.flow(batch_size=1600) + + file_writer = open(outfile_predict, "w") + header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) + file_writer.write(header_row) + header_done = False + + steps_done = 0 + + # TODO: multiple threading + try: + while steps_done < len(gen_flow): + index_array = next(gen_flow.index_generator) + batch_X = gen_flow._get_batches_of_transformed_samples(index_array) + + if params["method"] == "predict": + batch_preds = estimator.predict( + batch_X, + # The presence of `pred_data_generator` below is to + # override model carrying data_generator if there + # is any. + data_generator=pred_data_generator, + ) + else: + batch_preds = estimator.predict_proba( + batch_X, + # The presence of `pred_data_generator` below is to + # override model carrying data_generator if there + # is any. + data_generator=pred_data_generator, + ) + + if batch_preds.ndim == 1: + batch_preds = batch_preds[:, np.newaxis] + + batch_meta = variants[index_array] + batch_out = np.column_stack([batch_meta, batch_preds]) + + if not header_done: + heads = np.arange(batch_preds.shape[-1]).astype(str) + heads_str = "\t".join(heads) + file_writer.write("\t%s\n" % heads_str) + header_done = True + + for row in batch_out: + row_str = "\t".join(row) + file_writer.write("%s\n" % row_str) + + steps_done += 1 + + finally: + file_writer.close() + # TODO: make api `pred_data_generator.close()` + pred_data_generator.close() + return 0 + # end input + + # output + if len(preds.shape) == 1: + rval = pd.DataFrame(preds, columns=["Predicted"]) + else: + rval = pd.DataFrame(preds) + + rval.to_csv(outfile_predict, sep="\t", header=True, index=False) + + +if __name__ == "__main__": + aparser = argparse.ArgumentParser() + aparser.add_argument("-i", "--inputs", dest="inputs", required=True) + aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") + aparser.add_argument("-w", "--infile_weights", dest="infile_weights") + aparser.add_argument("-X", "--infile1", dest="infile1") + aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") + aparser.add_argument("-f", "--fasta_path", dest="fasta_path") + aparser.add_argument("-r", "--ref_seq", dest="ref_seq") + aparser.add_argument("-v", "--vcf_path", dest="vcf_path") + args = aparser.parse_args() + + main( + args.inputs, + args.infile_estimator, + args.outfile_predict, + infile_weights=args.infile_weights, + infile1=args.infile1, + fasta_path=args.fasta_path, + ref_seq=args.ref_seq, + vcf_path=args.vcf_path, + )