Mercurial > repos > bgruening > keras_train_and_eval
view model_prediction.py @ 6:ead8f1822587 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author | bgruening |
---|---|
date | Sat, 01 May 2021 00:50:57 +0000 |
parents | 785dd890e27d |
children | b3093f953091 |
line wrap: on
line source
import argparse import json import warnings import numpy as np import pandas as pd from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr from scipy.io import mmread from sklearn.pipeline import Pipeline N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) def main( inputs, infile_estimator, outfile_predict, infile_weights=None, infile1=None, fasta_path=None, ref_seq=None, vcf_path=None, ): """ Parameter --------- inputs : str File path to galaxy tool parameter infile_estimator : strgit File path to trained estimator input outfile_predict : str File path to save the prediction results, tabular infile_weights : str File path to weights input infile1 : str File path to dataset containing features fasta_path : str File path to dataset containing fasta file ref_seq : str File path to dataset containing the reference genome sequence. vcf_path : str File path to dataset containing variants info. """ warnings.filterwarnings("ignore") with open(inputs, "r") as param_handler: params = json.load(param_handler) # load model with open(infile_estimator, "rb") as est_handler: estimator = load_model(est_handler) main_est = estimator if isinstance(estimator, Pipeline): main_est = estimator.steps[-1][-1] if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): if not infile_weights or infile_weights == "None": raise ValueError( "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" ) main_est.load_weights(infile_weights) # handle data input input_type = params["input_options"]["selected_input"] # tabular input if input_type == "tabular": header = "infer" if params["input_options"]["header1"] else None column_option = params["input_options"]["column_selector_options_1"][ "selected_column_selector_option" ] if column_option in [ "by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name", ]: c = params["input_options"]["column_selector_options_1"]["col1"] else: c = None df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) X = read_columns(df, c=c, c_option=column_option).astype(float) if params["method"] == "predict": preds = estimator.predict(X) else: preds = estimator.predict_proba(X) # sparse input elif input_type == "sparse": X = mmread(open(infile1, "r")) if params["method"] == "predict": preds = estimator.predict(X) else: preds = estimator.predict_proba(X) # fasta input elif input_type == "seq_fasta": if not hasattr(estimator, "data_batch_generator"): raise ValueError( "To do prediction on sequences in fasta input, " "the estimator must be a `KerasGBatchClassifier`" "equipped with data_batch_generator!" ) pyfaidx = get_module("pyfaidx") sequences = pyfaidx.Fasta(fasta_path) n_seqs = len(sequences.keys()) X = np.arange(n_seqs)[:, np.newaxis] seq_length = estimator.data_batch_generator.seq_length batch_size = getattr(estimator, "batch_size", 32) steps = (n_seqs + batch_size - 1) // batch_size seq_type = params["input_options"]["seq_type"] klass = try_get_attr("galaxy_ml.preprocessors", seq_type) pred_data_generator = klass(fasta_path, seq_length=seq_length) if params["method"] == "predict": preds = estimator.predict( X, data_generator=pred_data_generator, steps=steps ) else: preds = estimator.predict_proba( X, data_generator=pred_data_generator, steps=steps ) # vcf input elif input_type == "variant_effect": klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") options = params["input_options"] options.pop("selected_input") if options["blacklist_regions"] == "none": options["blacklist_regions"] = None pred_data_generator = klass( ref_genome_path=ref_seq, vcf_path=vcf_path, **options ) pred_data_generator.set_processing_attrs() variants = pred_data_generator.variants # predict 1600 sample at once then write to file gen_flow = pred_data_generator.flow(batch_size=1600) file_writer = open(outfile_predict, "w") header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) file_writer.write(header_row) header_done = False steps_done = 0 # TODO: multiple threading try: while steps_done < len(gen_flow): index_array = next(gen_flow.index_generator) batch_X = gen_flow._get_batches_of_transformed_samples(index_array) if params["method"] == "predict": batch_preds = estimator.predict( batch_X, # The presence of `pred_data_generator` below is to # override model carrying data_generator if there # is any. data_generator=pred_data_generator, ) else: batch_preds = estimator.predict_proba( batch_X, # The presence of `pred_data_generator` below is to # override model carrying data_generator if there # is any. data_generator=pred_data_generator, ) if batch_preds.ndim == 1: batch_preds = batch_preds[:, np.newaxis] batch_meta = variants[index_array] batch_out = np.column_stack([batch_meta, batch_preds]) if not header_done: heads = np.arange(batch_preds.shape[-1]).astype(str) heads_str = "\t".join(heads) file_writer.write("\t%s\n" % heads_str) header_done = True for row in batch_out: row_str = "\t".join(row) file_writer.write("%s\n" % row_str) steps_done += 1 finally: file_writer.close() # TODO: make api `pred_data_generator.close()` pred_data_generator.close() return 0 # end input # output if len(preds.shape) == 1: rval = pd.DataFrame(preds, columns=["Predicted"]) else: rval = pd.DataFrame(preds) rval.to_csv(outfile_predict, sep="\t", header=True, index=False) if __name__ == "__main__": aparser = argparse.ArgumentParser() aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") aparser.add_argument("-w", "--infile_weights", dest="infile_weights") aparser.add_argument("-X", "--infile1", dest="infile1") aparser.add_argument("-O", "--outfile_predict", dest="outfile_predict") aparser.add_argument("-f", "--fasta_path", dest="fasta_path") aparser.add_argument("-r", "--ref_seq", dest="ref_seq") aparser.add_argument("-v", "--vcf_path", dest="vcf_path") args = aparser.parse_args() main( args.inputs, args.infile_estimator, args.outfile_predict, infile_weights=args.infile_weights, infile1=args.infile1, fasta_path=args.fasta_path, ref_seq=args.ref_seq, vcf_path=args.vcf_path, )