Mercurial > repos > bgruening > sklearn_sample_generator
diff model_prediction.py @ 35:1e99cfb71f40 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 17:52:15 +0000 |
parents | df579b31311d |
children | 999e07f0a9fa |
line wrap: on
line diff
--- a/model_prediction.py Thu Oct 01 20:27:36 2020 +0000 +++ b/model_prediction.py Tue Apr 13 17:52:15 2021 +0000 @@ -1,23 +1,29 @@ import argparse import json +import warnings + import numpy as np import pandas as pd -import warnings - from scipy.io import mmread from sklearn.pipeline import Pipeline -from galaxy_ml.utils import (load_model, read_columns, - get_module, try_get_attr) +from galaxy_ml.utils import (get_module, load_model, + read_columns, try_get_attr) + + +N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) -N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) - - -def main(inputs, infile_estimator, outfile_predict, - infile_weights=None, infile1=None, - fasta_path=None, ref_seq=None, - vcf_path=None): +def main( + inputs, + infile_estimator, + outfile_predict, + infile_weights=None, + infile1=None, + fasta_path=None, + ref_seq=None, + vcf_path=None, +): """ Parameter --------- @@ -45,96 +51,94 @@ vcf_path : str File path to dataset containing variants info. """ - warnings.filterwarnings('ignore') + warnings.filterwarnings("ignore") - with open(inputs, 'r') as param_handler: + with open(inputs, "r") as param_handler: params = json.load(param_handler) # load model - with open(infile_estimator, 'rb') as est_handler: + with open(infile_estimator, "rb") as est_handler: estimator = load_model(est_handler) main_est = estimator if isinstance(estimator, Pipeline): main_est = estimator.steps[-1][-1] - if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): - if not infile_weights or infile_weights == 'None': - raise ValueError("The selected model skeleton asks for weights, " - "but dataset for weights wan not selected!") + if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): + if not infile_weights or infile_weights == "None": + raise ValueError( + "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" + ) main_est.load_weights(infile_weights) # handle data input - input_type = params['input_options']['selected_input'] + input_type = params["input_options"]["selected_input"] # tabular input - if input_type == 'tabular': - header = 'infer' if params['input_options']['header1'] else None - column_option = (params['input_options'] - ['column_selector_options_1'] - ['selected_column_selector_option']) - if column_option in ['by_index_number', 'all_but_by_index_number', - 'by_header_name', 'all_but_by_header_name']: - c = params['input_options']['column_selector_options_1']['col1'] + if input_type == "tabular": + header = "infer" if params["input_options"]["header1"] else None + column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + c = params["input_options"]["column_selector_options_1"]["col1"] else: c = None - df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) + df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) X = read_columns(df, c=c, c_option=column_option).astype(float) - if params['method'] == 'predict': + if params["method"] == "predict": preds = estimator.predict(X) else: preds = estimator.predict_proba(X) # sparse input - elif input_type == 'sparse': - X = mmread(open(infile1, 'r')) - if params['method'] == 'predict': + elif input_type == "sparse": + X = mmread(open(infile1, "r")) + if params["method"] == "predict": preds = estimator.predict(X) else: preds = estimator.predict_proba(X) # fasta input - elif input_type == 'seq_fasta': - if not hasattr(estimator, 'data_batch_generator'): + elif input_type == "seq_fasta": + if not hasattr(estimator, "data_batch_generator"): raise ValueError( "To do prediction on sequences in fasta input, " "the estimator must be a `KerasGBatchClassifier`" - "equipped with data_batch_generator!") - pyfaidx = get_module('pyfaidx') + "equipped with data_batch_generator!" + ) + pyfaidx = get_module("pyfaidx") sequences = pyfaidx.Fasta(fasta_path) n_seqs = len(sequences.keys()) X = np.arange(n_seqs)[:, np.newaxis] seq_length = estimator.data_batch_generator.seq_length - batch_size = getattr(estimator, 'batch_size', 32) + batch_size = getattr(estimator, "batch_size", 32) steps = (n_seqs + batch_size - 1) // batch_size - seq_type = params['input_options']['seq_type'] - klass = try_get_attr( - 'galaxy_ml.preprocessors', seq_type) + seq_type = params["input_options"]["seq_type"] + klass = try_get_attr("galaxy_ml.preprocessors", seq_type) - pred_data_generator = klass( - fasta_path, seq_length=seq_length) + pred_data_generator = klass(fasta_path, seq_length=seq_length) - if params['method'] == 'predict': - preds = estimator.predict( - X, data_generator=pred_data_generator, steps=steps) + if params["method"] == "predict": + preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) else: - preds = estimator.predict_proba( - X, data_generator=pred_data_generator, steps=steps) + preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) # vcf input - elif input_type == 'variant_effect': - klass = try_get_attr('galaxy_ml.preprocessors', - 'GenomicVariantBatchGenerator') + elif input_type == "variant_effect": + klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") - options = params['input_options'] - options.pop('selected_input') - if options['blacklist_regions'] == 'none': - options['blacklist_regions'] = None + options = params["input_options"] + options.pop("selected_input") + if options["blacklist_regions"] == "none": + options["blacklist_regions"] = None - pred_data_generator = klass( - ref_genome_path=ref_seq, vcf_path=vcf_path, **options) + pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) pred_data_generator.set_processing_attrs() @@ -143,9 +147,8 @@ # predict 1600 sample at once then write to file gen_flow = pred_data_generator.flow(batch_size=1600) - file_writer = open(outfile_predict, 'w') - header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', - 'alt', 'strand']) + file_writer = open(outfile_predict, "w") + header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) file_writer.write(header_row) header_done = False @@ -155,23 +158,24 @@ try: while steps_done < len(gen_flow): index_array = next(gen_flow.index_generator) - batch_X = gen_flow._get_batches_of_transformed_samples( - index_array) + batch_X = gen_flow._get_batches_of_transformed_samples(index_array) - if params['method'] == 'predict': + if params["method"] == "predict": batch_preds = estimator.predict( batch_X, # The presence of `pred_data_generator` below is to # override model carrying data_generator if there # is any. - data_generator=pred_data_generator) + data_generator=pred_data_generator, + ) else: batch_preds = estimator.predict_proba( batch_X, # The presence of `pred_data_generator` below is to # override model carrying data_generator if there # is any. - data_generator=pred_data_generator) + data_generator=pred_data_generator, + ) if batch_preds.ndim == 1: batch_preds = batch_preds[:, np.newaxis] @@ -181,12 +185,12 @@ if not header_done: heads = np.arange(batch_preds.shape[-1]).astype(str) - heads_str = '\t'.join(heads) + heads_str = "\t".join(heads) file_writer.write("\t%s\n" % heads_str) header_done = True for row in batch_out: - row_str = '\t'.join(row) + row_str = "\t".join(row) file_writer.write("%s\n" % row_str) steps_done += 1 @@ -200,14 +204,14 @@ # output if len(preds.shape) == 1: - rval = pd.DataFrame(preds, columns=['Predicted']) + rval = pd.DataFrame(preds, columns=["Predicted"]) else: rval = pd.DataFrame(preds) - rval.to_csv(outfile_predict, sep='\t', header=True, index=False) + rval.to_csv(outfile_predict, sep="\t", header=True, index=False) -if __name__ == '__main__': +if __name__ == "__main__": aparser = argparse.ArgumentParser() aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") @@ -219,7 +223,13 @@ aparser.add_argument("-v", "--vcf_path", dest="vcf_path") args = aparser.parse_args() - main(args.inputs, args.infile_estimator, args.outfile_predict, - infile_weights=args.infile_weights, infile1=args.infile1, - fasta_path=args.fasta_path, ref_seq=args.ref_seq, - vcf_path=args.vcf_path) + main( + args.inputs, + args.infile_estimator, + args.outfile_predict, + infile_weights=args.infile_weights, + infile1=args.infile1, + fasta_path=args.fasta_path, + ref_seq=args.ref_seq, + vcf_path=args.vcf_path, + )