diff utils.py @ 5:4f7e6612906b draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author bgruening
date Fri, 06 May 2022 09:05:18 +0000
parents afec8c595124
children e94dc7945639
--- a/utils.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/utils.py	Fri May 06 09:05:18 2022 +0000
@@ -1,10 +1,11 @@
-import numpy as np
 import json
-import h5py
 import random
+
+import h5py
+import numpy as np
+import tensorflow as tf
 from numpy.random import choice
-
-from keras import backend as K
+from tensorflow.keras import backend
 
 
 def read_file(file_path):
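
Note on this hunk: the change drops the standalone keras package in favour of the Keras API bundled with TensorFlow 2.x, so "from keras import backend as K" becomes "from tensorflow.keras import backend". A minimal sketch of the migrated imports in use (assumes TensorFlow 2.x is installed):

    import tensorflow as tf
    from tensorflow.keras import backend

    # Both names now point at the TF-backed Keras backend operations:
    ones = backend.ones((2, 2))                # 2x2 tensor of ones
    col = tf.expand_dims([1.0, 2.0], axis=-1)  # list -> tensor of shape (2, 1)
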
@@ -29,10 +30,10 @@
     """
     Create an h5 file with the trained weights and associated dicts
     """
-    hf_file = h5py.File(dump_file, 'w')
+    hf_file = h5py.File(dump_file, "w")
     for key in model_values:
         value = model_values[key]
-        if key == 'model_weights':
+        if key == "model_weights":
             for idx, item in enumerate(value):
                 w_key = "weight_" + str(idx)
                 if w_key in hf_file:
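
Note on this hunk: only the quote style changes (Black formatting); the logic of set_trained_model, which stores each weight array under a "weight_<idx>" dataset, is untouched. A self-contained sketch of that storage pattern (simplified helper, not the repository's exact code):

    import h5py
    import numpy as np

    def dump_weights(dump_file, weights):
        # One h5 dataset per weight array, keyed "weight_0", "weight_1", ...
        with h5py.File(dump_file, "w") as hf_file:
            for idx, item in enumerate(weights):
                hf_file.create_dataset("weight_" + str(idx), data=np.asarray(item))

    dump_weights("model.h5", [np.zeros((3, 3)), np.ones(3)])
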
@@ -54,14 +55,19 @@
     """
     weight_values = list(class_weights.values())
     weight_values.extend(weight_values)
+
     def weighted_binary_crossentropy(y_true, y_pred):
         # add another dimension to compute dot product
-        expanded_weights = K.expand_dims(weight_values, axis=-1)
-        return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights)
+        expanded_weights = tf.expand_dims(weight_values, axis=-1)
+        bce = backend.binary_crossentropy(y_true, y_pred)
+        return backend.dot(bce, expanded_weights)
+
     return weighted_binary_crossentropy
 
 
-def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary):
+def balanced_sample_generator(
+    train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary
+):
     while True:
         dimension = train_data.shape[1]
         n_classes = train_labels.shape[1]
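
Note on this hunk: the loss closure is ported from keras.backend to TensorFlow equivalents; weight_values is a plain Python list, so tf.expand_dims first converts it to a tensor before the dot product with the per-class binary cross-entropy. The enclosing function's name and signature are not visible in this hunk, so the sketch below uses a placeholder name, and the dtype cast is an addition for safety, not part of the original:

    import tensorflow as tf
    from tensorflow.keras import backend

    def get_weighted_loss(class_weights):  # placeholder name for the outer function
        weight_values = list(class_weights.values())
        weight_values.extend(weight_values)  # duplicated, as in the hunk above

        def weighted_binary_crossentropy(y_true, y_pred):
            # add another dimension to compute dot product; the cast is
            # added in this sketch so the weights match y_pred's dtype
            weights = tf.cast(weight_values, y_pred.dtype)
            expanded_weights = tf.expand_dims(weights, axis=-1)
            bce = backend.binary_crossentropy(y_true, y_pred)
            return backend.dot(bce, expanded_weights)

        return weighted_binary_crossentropy

    # Usage sketch: the closure is a regular Keras loss.
    # model.compile(optimizer="rmsprop", loss=get_weighted_loss(class_weights))
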
@@ -80,7 +86,18 @@
         yield generator_batch_data, generator_batch_labels
 
 
-def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids):
+def compute_precision(
+    model,
+    x,
+    y,
+    reverse_data_dictionary,
+    usage_scores,
+    actual_classes_pos,
+    topk,
+    standard_conn,
+    last_tool_id,
+    lowest_tool_ids,
+):
     """
     Compute absolute and compatible precision
     """
@@ -137,7 +154,9 @@
                 else:
                     lowest_pub_prec = np.nan
                 if standard_topk_prediction_pos in usage_scores:
-                    usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0))
+                    usage_wt_score.append(
+                        np.log(usage_scores[standard_topk_prediction_pos] + 1.0)
+                    )
         else:
             # count precision only when there is actually true published tools
             # else set to np.nan. Set to 0 only when there is wrong prediction
@@ -148,7 +167,9 @@
         pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
         if pred_t_name in actual_next_tool_names:
             if normal_topk_prediction_pos in usage_scores:
-                usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0))
+                usage_wt_score.append(
+                    np.log(usage_scores[normal_topk_prediction_pos] + 1.0)
+                )
             top_precision = 1.0
             if last_tool_id in lowest_tool_ids:
                 lowest_norm_prec = 1.0
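
Note on the two hunks above: the wrapped expression np.log(score + 1.0) is a log1p-style damping of the raw usage score, well defined at a score of 0. A quick numeric check:

    import numpy as np

    scores = np.array([0.0, 9.0, 99.0])
    weights = np.log(scores + 1.0)  # equivalent to np.log1p(scores)
    print(weights)                  # [0.         2.30258509 4.60517019]
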
@@ -166,7 +187,16 @@
     return lowest_ids
 
 
-def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]):
+def verify_model(
+    model,
+    x,
+    y,
+    reverse_data_dictionary,
+    usage_scores,
+    standard_conn,
+    lowest_tool_ids,
+    topk_list=[1, 2, 3],
+):
     """
     Verify the model on test data
     """
@@ -187,7 +217,24 @@
         test_sample = x[i, :]
         last_tool_id = str(int(test_sample[-1]))
         for index, abs_topk in enumerate(topk_list):
-            usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids)
+            (
+                usg_wt_score,
+                absolute_precision,
+                pub_prec,
+                lowest_p_prec,
+                lowest_n_prec,
+            ) = compute_precision(
+                model,
+                test_sample,
+                y,
+                reverse_data_dictionary,
+                usage_scores,
+                actual_classes_pos,
+                abs_topk,
+                standard_conn,
+                last_tool_id,
+                lowest_tool_ids,
+            )
             precision[i][index] = absolute_precision
             usage_weights[i][index] = usg_wt_score
             epo_pub_prec[i][index] = pub_prec
@@ -202,22 +249,36 @@
     mean_pub_prec = np.nanmean(epo_pub_prec, axis=0)
     mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0)
     mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0)
-    return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter
+    return (
+        mean_usage,
+        mean_precision,
+        mean_pub_prec,
+        mean_lowest_pub_prec,
+        mean_lowest_norm_prec,
+        lowest_counter,
+    )
 
 
-def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections):
+def save_model(
+    results,
+    data_dictionary,
+    compatible_next_tools,
+    trained_model_path,
+    class_weights,
+    standard_connections,
+):
     # save files
     trained_model = results["model"]
     best_model_parameters = results["best_parameters"]
     model_config = trained_model.to_json()
     model_weights = trained_model.get_weights()
     model_values = {
-        'data_dictionary': data_dictionary,
-        'model_config': model_config,
-        'best_parameters': best_model_parameters,
-        'model_weights': model_weights,
+        "data_dictionary": data_dictionary,
+        "model_config": model_config,
+        "best_parameters": best_model_parameters,
+        "model_weights": model_weights,
         "compatible_tools": compatible_next_tools,
         "class_weights": class_weights,
-        "standard_connections": standard_connections
+        "standard_connections": standard_connections,
     }
     set_trained_model(trained_model_path, model_values)
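
Note on this hunk: save_model splits the model into a JSON architecture string (to_json) and a list of weight arrays (get_weights) before handing both to set_trained_model. A minimal roundtrip sketch showing how a loader can rebuild the model from those two pieces (assumes TensorFlow 2.x Keras; model shape hypothetical):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
    model_config = model.to_json()       # architecture as a JSON string
    model_weights = model.get_weights()  # list of numpy arrays

    restored = tf.keras.models.model_from_json(model_config)
    restored.set_weights(model_weights)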