diff model_prediction.py @ 27:6edcaa8dbb9f draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author bgruening
date Fri, 13 Sep 2019 12:27:50 -0400
parents 37e193b3fdd7
children 83938131dd46
line wrap: on
line diff
--- a/model_prediction.py	Fri Aug 09 07:10:13 2019 -0400
+++ b/model_prediction.py	Fri Sep 13 12:27:50 2019 -0400
@@ -2,11 +2,13 @@
 import json
 import numpy as np
 import pandas as pd
+import tabix
 import warnings
 
 from scipy.io import mmread
 from sklearn.pipeline import Pipeline
 
+from galaxy_ml.externals.selene_sdk.sequences import Genome
 from galaxy_ml.utils import (load_model, read_columns,
                              get_module, try_get_attr)
 
@@ -138,53 +140,108 @@
 
         pred_data_generator.fit()
 
-        preds = estimator.model_.predict_generator(
-            pred_data_generator.flow(batch_size=32),
-            workers=N_JOBS,
-            use_multiprocessing=True)
+        variants = pred_data_generator.variants
+        # TODO : remove the following block after galaxy-ml v0.7.13
+        blacklist_tabix = getattr(pred_data_generator.reference_genome_,
+                                  '_blacklist_tabix', None)
+        clean_variants = []
+        if blacklist_tabix:
+            start_radius = pred_data_generator.start_radius_
+            end_radius = pred_data_generator.end_radius_
+
+            for chrom, pos, name, ref, alt, strand in variants:
+                center = pos + len(ref) // 2
+                start = center - start_radius
+                end = center + end_radius
 
-        if preds.min() < 0. or preds.max() > 1.:
-            warnings.warn('Network returning invalid probability values. '
-                          'The last layer might not normalize predictions '
-                          'into probabilities '
-                          '(like softmax or sigmoid would).')
+                if isinstance(pred_data_generator.reference_genome_, Genome):
+                    if "chr" not in chrom:
+                        chrom = "chr" + chrom
+                    if "MT" in chrom:
+                        chrom = chrom[:-1]
+                try:
+                    rows = blacklist_tabix.query(chrom, start, end)
+                    found = 0
+                    for row in rows:
+                        found = 1
+                        break
+                    if found:
+                        continue
+                except tabix.TabixError:
+                    pass
 
-        if params['method'] == 'predict_proba' and preds.shape[1] == 1:
-            # first column is probability of class 0 and second is of class 1
-            preds = np.hstack([1 - preds, preds])
+                clean_variants.append((chrom, pos, name, ref, alt, strand))
+        else:
+            clean_variants = variants
+
+        setattr(pred_data_generator, 'variants', clean_variants)
+
+        variants = np.array(clean_variants)
+        # predict 1600 sample at once then write to file
+        gen_flow = pred_data_generator.flow(batch_size=1600)
+
+        file_writer = open(outfile_predict, 'w')
+        header_row = '\t'.join(['chrom', 'pos', 'name', 'ref',
+                                'alt', 'strand'])
+        file_writer.write(header_row)
+        header_done = False
 
-        elif params['method'] == 'predict':
-            if preds.shape[-1] > 1:
-                # if the last activation is `softmax`, the sum of all
-                # probibilities will 1, the classification is considered as
-                # multi-class problem, otherwise, we take it as multi-label.
-                act = getattr(estimator.model_.layers[-1], 'activation', None)
-                if act and act.__name__ == 'softmax':
-                    classes = preds.argmax(axis=-1)
+        steps_done = 0
+
+        # TODO: multiple threading
+        try:
+            while steps_done < len(gen_flow):
+                index_array = next(gen_flow.index_generator)
+                batch_X = gen_flow._get_batches_of_transformed_samples(
+                    index_array)
+
+                if params['method'] == 'predict':
+                    batch_preds = estimator.predict(
+                        batch_X,
+                        # The presence of `pred_data_generator` below is to
+                        # override model carrying data_generator if there
+                        # is any.
+                        data_generator=pred_data_generator)
                 else:
-                    preds = (preds > 0.5).astype('int32')
-            else:
-                classes = (preds > 0.5).astype('int32')
+                    batch_preds = estimator.predict_proba(
+                        batch_X,
+                        # The presence of `pred_data_generator` below is to
+                        # override model carrying data_generator if there
+                        # is any.
+                        data_generator=pred_data_generator)
+
+                if batch_preds.ndim == 1:
+                    batch_preds = batch_preds[:, np.newaxis]
+
+                batch_meta = variants[index_array]
+                batch_out = np.column_stack([batch_meta, batch_preds])
 
-            preds = estimator.classes_[classes]
+                if not header_done:
+                    heads = np.arange(batch_preds.shape[-1]).astype(str)
+                    heads_str = '\t'.join(heads)
+                    file_writer.write("\t%s\n" % heads_str)
+                    header_done = True
+
+                for row in batch_out:
+                    row_str = '\t'.join(row)
+                    file_writer.write("%s\n" % row_str)
+
+                steps_done += 1
+
+        finally:
+            file_writer.close()
+            # TODO: make api `pred_data_generator.close()`
+            pred_data_generator.close()
+        return 0
     # end input
 
     # output
-    if input_type == 'variant_effect':   # TODO: save in batchs
-        rval = pd.DataFrame(preds)
-        meta = pd.DataFrame(
-            pred_data_generator.variants,
-            columns=['chrom', 'pos', 'name', 'ref', 'alt', 'strand'])
-
-        rval = pd.concat([meta, rval], axis=1)
-
-    elif len(preds.shape) == 1:
+    if len(preds.shape) == 1:
         rval = pd.DataFrame(preds, columns=['Predicted'])
     else:
         rval = pd.DataFrame(preds)
 
-    rval.to_csv(outfile_predict, sep='\t',
-                header=True, index=False)
+    rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
 
 
 if __name__ == '__main__':