diff model_prediction.py @ 36:92e09b827300 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:43:23 +0000
parents 318484f56b6a
children 5af054432771
line wrap: on
line diff
--- a/model_prediction.py	Tue Apr 13 22:34:09 2021 +0000
+++ b/model_prediction.py	Sat May 01 01:43:23 2021 +0000
@@ -63,7 +63,8 @@
     if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
         if not infile_weights or infile_weights == "None":
             raise ValueError(
-                "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!"
+                "The selected model skeleton asks for weights, "
+                "but dataset for weights wan not selected!"
             )
         main_est.load_weights(infile_weights)
 
@@ -72,7 +73,9 @@
     # tabular input
     if input_type == "tabular":
         header = "infer" if params["input_options"]["header1"] else None
-        column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
+        column_option = params["input_options"]["column_selector_options_1"][
+            "selected_column_selector_option"
+        ]
         if column_option in [
             "by_index_number",
             "all_but_by_index_number",
@@ -122,9 +125,13 @@
         pred_data_generator = klass(fasta_path, seq_length=seq_length)
 
         if params["method"] == "predict":
-            preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps)
+            preds = estimator.predict(
+                X, data_generator=pred_data_generator, steps=steps
+            )
         else:
-            preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps)
+            preds = estimator.predict_proba(
+                X, data_generator=pred_data_generator, steps=steps
+            )
 
     # vcf input
     elif input_type == "variant_effect":
@@ -135,7 +142,9 @@
         if options["blacklist_regions"] == "none":
             options["blacklist_regions"] = None
 
-        pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
+        pred_data_generator = klass(
+            ref_genome_path=ref_seq, vcf_path=vcf_path, **options
+        )
 
         pred_data_generator.set_processing_attrs()