Repository 'create_tool_recommendation_model'
hg clone https://toolshed.g2.bx.psu.edu/repos/bgruening/create_tool_recommendation_model

Changeset 5:4f7e6612906b (2022-05-06)
Previous changeset 4:afec8c595124 (2020-07-07) Next changeset 6:e94dc7945639 (2022-10-16)
Commit message:
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
modified:
create_tool_recommendation_model.xml
extract_workflow_connections.py
main.py
optimise_hyperparameters.py
predict_tool_usage.py
prepare_data.py
test-data/test_tool_usage
test-data/test_workflows
utils.py
diff -r afec8c595124 -r 4f7e6612906b create_tool_recommendation_model.xml
--- a/create_tool_recommendation_model.xml Tue Jul 07 03:25:49 2020 -0400
+++ b/create_tool_recommendation_model.xml Fri May 06 09:05:18 2022 +0000
@@ -1,13 +1,13 @@
-<tool id="create_tool_recommendation_model" name="Create a model to recommend tools" version="0.0.3">
+<tool id="create_tool_recommendation_model" name="Create a model to recommend tools" version="0.0.4">
     <description>using deep learning</description>
     <requirements>
-        <requirement type="package" version="3.6">python</requirement>
-        <requirement type="package" version="1.13.1">tensorflow</requirement>
-        <requirement type="package" version="2.3.0">keras</requirement>
-        <requirement type="package" version="0.21.3">scikit-learn</requirement>
-        <requirement type="package" version="2.9.0">h5py</requirement>
+        <requirement type="package" version="3.9.7">python</requirement>
+        <requirement type="package" version="2.7.0">tensorflow</requirement>
+        <requirement type="package" version="2.7.0">keras</requirement>
+        <requirement type="package" version="1.0.2">scikit-learn</requirement>
+        <requirement type="package" version="3.6.0">h5py</requirement>
         <requirement type="package" version="1.0.4">csvkit</requirement>
-        <requirement type="package" version="0.1.2">hyperopt</requirement>
+        <requirement type="package" version="0.2.5">hyperopt</requirement>
     </requirements>
     <version_command>echo "@VERSION@"</version_command>
     <command detect_errors="aggressive">
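The requirement bump moves the tool from the Python 3.6 / TensorFlow 1.13 / Keras 2.3 stack to Python 3.9 with TensorFlow 2.7, Keras 2.7, scikit-learn 1.0 and h5py 3.6, which is what drives the keras-to-tensorflow.keras import changes in the Python modules below. A minimal sketch, assuming the Conda environment resolved from these requirements, for checking the versions the tool actually gets at runtime:

    import h5py
    import keras
    import sklearn
    import tensorflow as tf

    # Expected to match the requirement pins above (2.7.0 / 2.7.0 / 1.0.2 / 3.6.0).
    print("tensorflow   :", tf.__version__)
    print("keras        :", keras.__version__)
    print("scikit-learn :", sklearn.__version__)
    print("h5py         :", h5py.__version__)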
diff -r afec8c595124 -r 4f7e6612906b extract_workflow_connections.py
--- a/extract_workflow_connections.py Tue Jul 07 03:25:49 2020 -0400
+++ b/extract_workflow_connections.py Fri May 06 09:05:18 2022 +0000
@@ -10,7 +10,6 @@
 
 
 class ExtractWorkflowConnections:
-
     def __init__(self):
         """ Init method. """
 
@@ -33,12 +32,12 @@
         workflow_paths = list()
         unique_paths = dict()
         standard_connections = dict()
-        with open(raw_file_path, 'rt') as workflow_connections_file:
-            workflow_connections = csv.reader(workflow_connections_file, delimiter='\t')
+        with open(raw_file_path, "rt") as workflow_connections_file:
+            workflow_connections = csv.reader(workflow_connections_file, delimiter="\t")
             for index, row in enumerate(workflow_connections):
                 wf_id = str(row[0])
-                in_tool = row[3]
-                out_tool = row[6]
+                in_tool = row[3].strip()
+                out_tool = row[6].strip()
                 if wf_id not in workflows:
                     workflows[wf_id] = list()
                 if out_tool and in_tool and out_tool != in_tool:
@@ -144,7 +143,9 @@
         if end in graph:
             for node in graph[end]:
                 if node not in path:
-                    new_tools_paths = self.find_tool_paths_workflow(graph, start, node, path)
+                    new_tools_paths = self.find_tool_paths_workflow(
+                        graph, start, node, path
+                    )
                     for tool_path in new_tools_paths:
                         path_list.append(tool_path)
         return path_list
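The new .strip() calls guard against workflow exports whose tool-name columns carry padding whitespace (the updated test-data files below are padded this way). A minimal sketch of the reading pattern, assuming a tab-separated file with the in-tool in column 3 and the out-tool in column 6 as in read_tabular_file:

    import csv

    def read_connections(raw_file_path):
        """Collect (in_tool, out_tool) edges per workflow id, skipping blanks and self-loops."""
        workflows = {}
        with open(raw_file_path, "rt") as fh:
            for row in csv.reader(fh, delimiter="\t"):
                wf_id = str(row[0])
                in_tool = row[3].strip()   # strip padding around tool names
                out_tool = row[6].strip()
                if out_tool and in_tool and out_tool != in_tool:
                    workflows.setdefault(wf_id, []).append((in_tool, out_tool))
        return workflows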
diff -r afec8c595124 -r 4f7e6612906b main.py
--- a/main.py Tue Jul 07 03:25:49 2020 -0400
+++ b/main.py Fri May 06 09:05:18 2022 +0000
[The diff body for main.py is captured here only as a truncated byte string. The recoverable hunks show: the imports reordered (extract_workflow_connections, keras.callbacks, numpy, optimise_hyperparameters, prepare_data, utils) with the `import tensorflow as tf` and `from keras import backend as K` lines dropped; the `tf.ConfigProto` / `K.set_session` CPU-pinning block removed from `PredictTool.__init__`; single-quoted strings switched to double quotes; and the long signatures and calls (`find_train_best_network`, `hyper_opt.train_model`, the `EarlyStopping` and `PredictCallback` callbacks, the argparse arguments, the `config` dict, the tuple unpacking of `connections.read_tabular_file` and `data.get_data_labels_matrices`, and `utils.save_model`) reflowed to one argument per line. The middle of the diff is truncated in this capture.]
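Dropping the tf.ConfigProto / K.set_session block is required on TensorFlow 2.x, where graph sessions no longer exist. A hedged sketch of the closest TF 2 equivalent for pinning CPU thread counts (tf.config.threading is the standard replacement; whether the updated tool still limits threads anywhere is not visible in the truncated diff):

    import tensorflow as tf

    num_cpus = 16
    # TF 2.x replacement for the removed ConfigProto-based session setup;
    # these must run before any op initialises the TensorFlow runtime.
    tf.config.threading.set_intra_op_parallelism_threads(num_cpus)
    tf.config.threading.set_inter_op_parallelism_threads(num_cpus)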
diff -r afec8c595124 -r 4f7e6612906b optimise_hyperparameters.py
--- a/optimise_hyperparameters.py Tue Jul 07 03:25:49 2020 -0400
+++ b/optimise_hyperparameters.py Fri May 06 09:05:18 2022 +0000
@@ -3,24 +3,29 @@
 """
 
 import numpy as np
-from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
-
-from keras.models import Sequential
-from keras.layers import Dense, GRU, Dropout
-from keras.layers.embeddings import Embedding
-from keras.layers.core import SpatialDropout1D
-from keras.optimizers import RMSprop
-from keras.callbacks import EarlyStopping
-
 import utils
+from hyperopt import fmin, hp, STATUS_OK, tpe, Trials
+from tensorflow.keras.callbacks import EarlyStopping
+from tensorflow.keras.layers import Dense, Dropout, Embedding, GRU, SpatialDropout1D
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.optimizers import RMSprop
 
 
 class HyperparameterOptimisation:
-
     def __init__(self):
         """ Init method. """
 
-    def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights):
+    def train_model(
+        self,
+        config,
+        reverse_dictionary,
+        train_data,
+        train_labels,
+        test_data,
+        test_labels,
+        tool_tr_samples,
+        class_weights,
+    ):
         """
         Train a model and report accuracy
         """
@@ -40,52 +45,101 @@
         # get dimensions
         dimensions = len(reverse_dictionary) + 1
         best_model_params = dict()
-        early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True)
+        early_stopping = EarlyStopping(
+            monitor="val_loss",
+            mode="min",
+            verbose=1,
+            min_delta=1e-1,
+            restore_best_weights=True,
+        )
 
         # specify the search space for finding the best combination of parameters using Bayesian optimisation
         params = {
-            "embedding_size": hp.quniform("embedding_size", l_embedding_size[0], l_embedding_size[1], 1),
+            "embedding_size": hp.quniform(
+                "embedding_size", l_embedding_size[0], l_embedding_size[1], 1
+            ),
             "units": hp.quniform("units", l_units[0], l_units[1], 1),
-            "batch_size": hp.quniform("batch_size", l_batch_size[0], l_batch_size[1], 1),
-            "learning_rate": hp.loguniform("learning_rate", np.log(l_learning_rate[0]), np.log(l_learning_rate[1])),
+            "batch_size": hp.quniform(
+                "batch_size", l_batch_size[0], l_batch_size[1], 1
+            ),
+            "learning_rate": hp.loguniform(
+                "learning_rate", np.log(l_learning_rate[0]), np.log(l_learning_rate[1])
+            ),
             "dropout": hp.uniform("dropout", l_dropout[0], l_dropout[1]),
-            "spatial_dropout": hp.uniform("spatial_dropout", l_spatial_dropout[0], l_spatial_dropout[1]),
-            "recurrent_dropout": hp.uniform("recurrent_dropout", l_recurrent_dropout[0], l_recurrent_dropout[1])
+            "spatial_dropout": hp.uniform(
+                "spatial_dropout", l_spatial_dropout[0], l_spatial_dropout[1]
+            ),
+            "recurrent_dropout": hp.uniform(
+                "recurrent_dropout", l_recurrent_dropout[0], l_recurrent_dropout[1]
+            ),
         }
 
         def create_model(params):
             model = Sequential()
-            model.add(Embedding(dimensions, int(params["embedding_size"]), mask_zero=True))
+            model.add(
+                Embedding(dimensions, int(params["embedding_size"]), mask_zero=True)
+            )
             model.add(SpatialDropout1D(params["spatial_dropout"]))
-            model.add(GRU(int(params["units"]), dropout=params["dropout"], recurrent_dropout=params["recurrent_dropout"], return_sequences=True, activation="elu"))
+            model.add(
+                GRU(
+                    int(params["units"]),
+                    dropout=params["dropout"],
+                    recurrent_dropout=params["recurrent_dropout"],
+                    return_sequences=True,
+                    activation="elu",
+                )
+            )
             model.add(Dropout(params["dropout"]))
-            model.add(GRU(int(params["units"]), dropout=params["dropout"], recurrent_dropout=params["recurrent_dropout"], return_sequences=False, activation="elu"))
+            model.add(
+                GRU(
+                    int(params["units"]),
+                    dropout=params["dropout"],
+                    recurrent_dropout=params["recurrent_dropout"],
+                    return_sequences=False,
+                    activation="elu",
+                )
+            )
             model.add(Dropout(params["dropout"]))
             model.add(Dense(2 * dimensions, activation="sigmoid"))
             optimizer_rms = RMSprop(lr=params["learning_rate"])
             batch_size = int(params["batch_size"])
-            model.compile(loss=utils.weighted_loss(class_weights), optimizer=optimizer_rms)
+            model.compile(
+                loss=utils.weighted_loss(class_weights), optimizer=optimizer_rms
+            )
             print(model.summary())
-            model_fit = model.fit_generator(
+            model_fit = model.fit(
                 utils.balanced_sample_generator(
                     train_data,
                     train_labels,
                     batch_size,
                     tool_tr_samples,
-                    reverse_dictionary
+                    reverse_dictionary,
                 ),
                 steps_per_epoch=len(train_data) // batch_size,
                 epochs=optimize_n_epochs,
                 callbacks=[early_stopping],
                 validation_data=(test_data, test_labels),
                 verbose=2,
-                shuffle=True
+                shuffle=True,
             )
-            return {'loss': model_fit.history["val_loss"][-1], 'status': STATUS_OK, 'model': model}
+            return {
+                "loss": model_fit.history["val_loss"][-1],
+                "status": STATUS_OK,
+                "model": model,
+            }
+
         # minimize the objective function using the set of parameters above
         trials = Trials()
-        learned_params = fmin(create_model, params, trials=trials, algo=tpe.suggest, max_evals=int(config["max_evals"]))
-        best_model = trials.results[np.argmin([r['loss'] for r in trials.results])]['model']
+        learned_params = fmin(
+            create_model,
+            params,
+            trials=trials,
+            algo=tpe.suggest,
+            max_evals=int(config["max_evals"]),
+        )
+        best_model = trials.results[np.argmin([r["loss"] for r in trials.results])][
+            "model"
+        ]
         # set the best params with respective values
         for item in learned_params:
             item_val = learned_params[item]
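The reorganised hyperopt import keeps the same Bayesian-optimisation pattern: fmin minimises the objective over the hp search space with TPE, and the best Keras model is pulled back out of the Trials object by lowest validation loss. A stripped-down sketch of that pattern (the objective and search space here are illustrative placeholders, not the tool's real ones):

    import numpy as np
    from hyperopt import STATUS_OK, Trials, fmin, hp, tpe

    space = {
        "units": hp.quniform("units", 32, 512, 1),
        "learning_rate": hp.loguniform("learning_rate", np.log(1e-4), np.log(1e-2)),
    }

    def objective(params):
        # Train a model with `params` and report its validation loss;
        # returning the model lets us recover it from Trials afterwards.
        val_loss = (params["units"] - 128) ** 2 * params["learning_rate"]  # placeholder
        return {"loss": val_loss, "status": STATUS_OK, "model": None}

    trials = Trials()
    best_params = fmin(objective, space, algo=tpe.suggest, max_evals=20, trials=trials)
    best_model = trials.results[np.argmin([r["loss"] for r in trials.results])]["model"]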
diff -r afec8c595124 -r 4f7e6612906b predict_tool_usage.py
--- a/predict_tool_usage.py Tue Jul 07 03:25:49 2020 -0400
+++ b/predict_tool_usage.py Fri May 06 09:05:18 2022 +0000
@@ -2,17 +2,16 @@
 Predict tool usage to weigh the predicted tools
 """
 
-import os
-import numpy as np
-import warnings
+import collections
 import csv
-import collections
+import os
+import warnings
 
-from sklearn.svm import SVR
+import numpy as np
+import utils
 from sklearn.model_selection import GridSearchCV
 from sklearn.pipeline import Pipeline
-
-import utils
+from sklearn.svm import SVR
 
 warnings.filterwarnings("ignore")
 
@@ -20,7 +19,6 @@
 
 
 class ToolPopularity:
-
     def __init__(self):
         """ Init method. """
 
@@ -31,10 +29,11 @@
         tool_usage_dict = dict()
         all_dates = list()
         all_tool_list = list(dictionary.keys())
-        with open(tool_usage_file, 'rt') as usage_file:
-            tool_usage = csv.reader(usage_file, delimiter='\t')
+        with open(tool_usage_file, "rt") as usage_file:
+            tool_usage = csv.reader(usage_file, delimiter="\t")
             for index, row in enumerate(tool_usage):
-                if (str(row[1]) > cutoff_date) is True:
+                row = [item.strip() for item in row]
+                if (str(row[1]).strip() > cutoff_date) is True:
                     tool_id = utils.format_tool_id(row[0])
                     if tool_id in all_tool_list:
                         all_dates.append(row[1])
@@ -67,18 +66,25 @@
         """
         epsilon = 0.0
         cv = 5
-        s_typ = 'neg_mean_absolute_error'
+        s_typ = "neg_mean_absolute_error"
         n_jobs = 4
         s_error = 1
-        iid = True
         tr_score = False
         try:
-            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
+            pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
             param_grid = {
-                'regressor__kernel': ['rbf', 'poly', 'linear'],
-                'regressor__degree': [2, 3]
+                "regressor__kernel": ["rbf", "poly", "linear"],
+                "regressor__degree": [2, 3],
             }
-            search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score)
+            search = GridSearchCV(
+                pipe,
+                param_grid,
+                cv=cv,
+                scoring=s_typ,
+                n_jobs=n_jobs,
+                error_score=s_error,
+                return_train_score=tr_score,
+            )
             search.fit(x_reshaped, y_reshaped.ravel())
             model = search.best_estimator_
             # set the next time point to get prediction for
@@ -87,7 +93,8 @@
             if prediction < epsilon:
                 prediction = [epsilon]
             return prediction[0]
-        except Exception:
+        except Exception as e:
+            print(e)
             return epsilon
 
     def get_pupularity_prediction(self, tools_usage):
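Dropping `iid=True` is what lets the grid search run on scikit-learn 1.0.2, where that argument has been removed. A minimal sketch of the updated SVR pipeline search, assuming `x` and `y` stand in for the per-tool usage time series the tool prepares:

    import numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.pipeline import Pipeline
    from sklearn.svm import SVR

    x = np.arange(12).reshape(-1, 1).astype(float)  # placeholder time points
    y = np.random.rand(12)                          # placeholder usage counts

    pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
    param_grid = {
        "regressor__kernel": ["rbf", "poly", "linear"],
        "regressor__degree": [2, 3],
    }
    # `iid` no longer exists in scikit-learn >= 0.24, so it is simply omitted.
    search = GridSearchCV(
        pipe, param_grid, cv=5, scoring="neg_mean_absolute_error",
        n_jobs=4, error_score=1, return_train_score=False,
    )
    search.fit(x, y.ravel())
    next_point = search.best_estimator_.predict(np.array([[12.0]]))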
diff -r afec8c595124 -r 4f7e6612906b prepare_data.py
--- a/prepare_data.py Tue Jul 07 03:25:49 2020 -0400
+++ b/prepare_data.py Fri May 06 09:05:18 2022 +0000
@@ -4,18 +4,17 @@
 into the test and training sets
 """
 
+import collections
 import os
-import collections
-import numpy as np
 import random
 
+import numpy as np
 import predict_tool_usage
 
 main_path = os.getcwd()
 
 
 class PrepareData:
-
     def __init__(self, max_seq_length, test_data_share):
         """ Init method. """
         self.max_tool_sequence_len = max_seq_length
@@ -27,15 +26,20 @@
         """
         tokens = list()
         raw_paths = workflow_paths
-        raw_paths = [x.replace("\n", '') for x in raw_paths]
+        raw_paths = [x.replace("\n", "") for x in raw_paths]
         for item in raw_paths:
             split_items = item.split(",")
             for token in split_items:
-                if token is not "":
+                if token != "":
                     tokens.append(token)
         tokens = list(set(tokens))
         tokens = np.array(tokens)
-        tokens = np.reshape(tokens, [-1, ])
+        tokens = np.reshape(
+            tokens,
+            [
+                -1,
+            ],
+        )
         return tokens, raw_paths
 
     def create_new_dict(self, new_data_dict):
@@ -60,7 +64,10 @@
         dictionary = dict()
         for word, _ in count:
             dictionary[word] = len(dictionary) + 1
-        dictionary, reverse_dictionary = self.assemble_dictionary(dictionary, old_data_dictionary)
+            word = word.strip()
+        dictionary, reverse_dictionary = self.assemble_dictionary(
+            dictionary, old_data_dictionary
+        )
         return dictionary, reverse_dictionary
 
     def decompose_paths(self, paths, dictionary):
@@ -74,13 +81,17 @@
             if len_tools <= self.max_tool_sequence_len:
                 for window in range(1, len_tools):
                     sequence = tools[0: window + 1]
-                    tools_pos = [str(dictionary[str(tool_item)]) for tool_item in sequence]
+                    tools_pos = [
+                        str(dictionary[str(tool_item)]) for tool_item in sequence
+                    ]
                     if len(tools_pos) > 1:
                         sub_paths_pos.append(",".join(tools_pos))
         sub_paths_pos = list(set(sub_paths_pos))
         return sub_paths_pos
 
-    def prepare_paths_labels_dictionary(self, dictionary, reverse_dictionary, paths, compatible_next_tools):
+    def prepare_paths_labels_dictionary(
+        self, dictionary, reverse_dictionary, paths, compatible_next_tools
+    ):
         """
         Create a dictionary of sequences with their labels for training and test paths
         """
@@ -90,14 +101,18 @@
             if item and item not in "":
                 tools = item.split(",")
                 label = tools[-1]
-                train_tools = tools[:len(tools) - 1]
+                train_tools = tools[: len(tools) - 1]
                 last_but_one_name = reverse_dictionary[int(train_tools[-1])]
                 try:
-                    compatible_tools = compatible_next_tools[last_but_one_name].split(",")
+                    compatible_tools = compatible_next_tools[last_but_one_name].split(
+                        ","
+                    )
                 except Exception:
                     continue
                 if len(compatible_tools) > 0:
-                    compatible_tools_ids = [str(dictionary[x]) for x in compatible_tools]
+                    compatible_tools_ids = [
+                        str(dictionary[x]) for x in compatible_tools
+                    ]
                     compatible_tools_ids.append(label)
                     composite_labels = ",".join(compatible_tools_ids)
                 train_tools = ",".join(train_tools)
@@ -127,7 +142,9 @@
             train_counter += 1
         return data_mat, label_mat
 
-    def pad_paths(self, paths_dictionary, num_classes, standard_connections, reverse_dictionary):
+    def pad_paths(
+        self, paths_dictionary, num_classes, standard_connections, reverse_dictionary
+    ):
         """
         Add padding to the tools sequences and create multi-hot encoded labels
         """
@@ -231,12 +248,22 @@
                     l_tool_tr_samples[last_tool_id].append(index)
         return l_tool_tr_samples
 
-    def get_data_labels_matrices(self, workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections, old_data_dictionary={}):
+    def get_data_labels_matrices(
+        self,
+        workflow_paths,
+        tool_usage_path,
+        cutoff_date,
+        compatible_next_tools,
+        standard_connections,
+        old_data_dictionary={},
+    ):
         """
         Convert the training and test paths into corresponding numpy matrices
         """
         processed_data, raw_paths = self.process_workflow_paths(workflow_paths)
-        dictionary, rev_dict = self.create_data_dictionary(processed_data, old_data_dictionary)
+        dictionary, rev_dict = self.create_data_dictionary(
+            processed_data, old_data_dictionary
+        )
         num_classes = len(dictionary)
 
         print("Raw paths: %d" % len(raw_paths))
@@ -247,18 +274,26 @@
         random.shuffle(all_unique_paths)
 
         print("Creating dictionaries...")
-        multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools)
+        multilabels_paths = self.prepare_paths_labels_dictionary(
+            dictionary, rev_dict, all_unique_paths, compatible_next_tools
+        )
 
         print("Complete data: %d" % len(multilabels_paths))
-        train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths)
+        train_paths_dict, test_paths_dict = self.split_test_train_data(
+            multilabels_paths
+        )
 
         print("Train data: %d" % len(train_paths_dict))
         print("Test data: %d" % len(test_paths_dict))
 
         print("Padding train and test data...")
         # pad training and test data with leading zeros
-        test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict)
-        train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict)
+        test_data, test_labels = self.pad_paths(
+            test_paths_dict, num_classes, standard_connections, rev_dict
+        )
+        train_data, train_labels = self.pad_paths(
+            train_paths_dict, num_classes, standard_connections, rev_dict
+        )
 
         print("Estimating sample frequency...")
         l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
@@ -274,4 +309,15 @@
         # get class weights using the predicted usage for each tool
         class_weights = self.assign_class_weights(num_classes, t_pred_usage)
 
-        return train_data, train_labels, test_data, test_labels, dictionary, rev_dict, class_weights, t_pred_usage, l_tool_freq, l_tool_tr_samples
+        return (
+            train_data,
+            train_labels,
+            test_data,
+            test_labels,
+            dictionary,
+            rev_dict,
+            class_weights,
+            t_pred_usage,
+            l_tool_freq,
+            l_tool_tr_samples,
+        )
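decompose_paths (reflowed above) turns every workflow path into all of its prefixes of length two or more, encoded as comma-joined dictionary indices. A small sketch of that windowing, assuming a comma-separated path of tool names and a name-to-index dictionary like the one built in create_data_dictionary:

    def decompose(path, dictionary, max_len=25):
        """Return every prefix of the path (length >= 2) as comma-joined indices."""
        tools = path.split(",")
        sub_paths = set()
        if len(tools) <= max_len:
            for window in range(1, len(tools)):
                prefix = tools[: window + 1]
                ids = [str(dictionary[t]) for t in prefix]
                if len(ids) > 1:
                    sub_paths.add(",".join(ids))
        return sorted(sub_paths)

    # e.g. decompose("fastqc,trimmomatic,hisat2",
    #                {"fastqc": 1, "trimmomatic": 2, "hisat2": 3})
    # -> ["1,2", "1,2,3"]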
diff -r afec8c595124 -r 4f7e6612906b test-data/test_tool_usage
--- a/test-data/test_tool_usage Tue Jul 07 03:25:49 2020 -0400
+++ b/test-data/test_tool_usage Fri May 06 09:05:18 2022 +0000
[The diff body for test-data/test_tool_usage is captured here only as a truncated byte string. Hunk @@ -1,500 +1,93 @@ replaces the 500 tab-separated rows (tool id, date 2020-04-01, usage count) with 93 rows dated 2022-02-01 whose tool-id column is padded with trailing whitespace. The middle of the diff is truncated in this capture.]
diff -r afec8c595124 -r 4f7e6612906b test-data/test_workflows
--- a/test-data/test_workflows Tue Jul 07 03:25:49 2020 -0400
+++ b/test-data/test_workflows Fri May 06 09:05:18 2022 +0000
[The diff body for test-data/test_workflows is captured here only as a truncated byte string. Hunk @@ -1,1000 +1,499 @@ replaces the 1000 tab-separated workflow-connection rows (workflow id, timestamp, in-step id/tool name/version, out-step id/tool name/version, three flags) with 499 rows in the same layout, with timestamps shortened to dates and the tool-name and version columns padded with whitespace. The middle of the diff is truncated in this capture.]
diff -r afec8c595124 -r 4f7e6612906b utils.py
--- a/utils.py Tue Jul 07 03:25:49 2020 -0400
+++ b/utils.py Fri May 06 09:05:18 2022 +0000
@@ -1,10 +1,11 @@
-import numpy as np
 import json
-import h5py
 import random
+
+import h5py
+import numpy as np
+import tensorflow as tf
 from numpy.random import choice
-
-from keras import backend as K
+from tensorflow.keras import backend
 
 
 def read_file(file_path):
@@ -29,10 +30,10 @@
     """
     Create an h5 file with the trained weights and associated dicts
     """
-    hf_file = h5py.File(dump_file, 'w')
+    hf_file = h5py.File(dump_file, "w")
     for key in model_values:
         value = model_values[key]
-        if key == 'model_weights':
+        if key == "model_weights":
             for idx, item in enumerate(value):
                 w_key = "weight_" + str(idx)
                 if w_key in hf_file:
@@ -54,14 +55,19 @@
     """
     weight_values = list(class_weights.values())
     weight_values.extend(weight_values)
+
     def weighted_binary_crossentropy(y_true, y_pred):
         # add another dimension to compute dot product
-        expanded_weights = K.expand_dims(weight_values, axis=-1)
-        return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights)
+        expanded_weights = tf.expand_dims(weight_values, axis=-1)
+        bce = backend.binary_crossentropy(y_true, y_pred)
+        return backend.dot(bce, expanded_weights)
+
     return weighted_binary_crossentropy
 
 
-def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary):
+def balanced_sample_generator(
+    train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary
+):
     while True:
         dimension = train_data.shape[1]
         n_classes = train_labels.shape[1]
@@ -80,7 +86,18 @@
         yield generator_batch_data, generator_batch_labels
 
 
-def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids):
+def compute_precision(
+    model,
+    x,
+    y,
+    reverse_data_dictionary,
+    usage_scores,
+    actual_classes_pos,
+    topk,
+    standard_conn,
+    last_tool_id,
+    lowest_tool_ids,
+):
     """
     Compute absolute and compatible precision
     """
@@ -137,7 +154,9 @@
                 else:
                     lowest_pub_prec = np.nan
                 if standard_topk_prediction_pos in usage_scores:
-                    usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0))
+                    usage_wt_score.append(
+                        np.log(usage_scores[standard_topk_prediction_pos] + 1.0)
+                    )
         else:
             # count precision only when there is actually true published tools
             # else set to np.nan. Set to 0 only when there is wrong prediction
@@ -148,7 +167,9 @@
         pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
         if pred_t_name in actual_next_tool_names:
             if normal_topk_prediction_pos in usage_scores:
-                usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0))
+                usage_wt_score.append(
+                    np.log(usage_scores[normal_topk_prediction_pos] + 1.0)
+                )
             top_precision = 1.0
             if last_tool_id in lowest_tool_ids:
                 lowest_norm_prec = 1.0
@@ -166,7 +187,16 @@
     return lowest_ids
 
 
-def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]):
+def verify_model(
+    model,
+    x,
+    y,
+    reverse_data_dictionary,
+    usage_scores,
+    standard_conn,
+    lowest_tool_ids,
+    topk_list=[1, 2, 3],
+):
     """
     Verify the model on test data
     """
@@ -187,7 +217,24 @@
         test_sample = x[i, :]
         last_tool_id = str(int(test_sample[-1]))
         for index, abs_topk in enumerate(topk_list):
-            usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids)
+            (
+                usg_wt_score,
+                absolute_precision,
+                pub_prec,
+                lowest_p_prec,
+                lowest_n_prec,
+            ) = compute_precision(
+                model,
+                test_sample,
+                y,
+                reverse_data_dictionary,
+                usage_scores,
+                actual_classes_pos,
+                abs_topk,
+                standard_conn,
+                last_tool_id,
+                lowest_tool_ids,
+            )
             precision[i][index] = absolute_precision
             usage_weights[i][index] = usg_wt_score
             epo_pub_prec[i][index] = pub_prec
@@ -202,22 +249,36 @@
     mean_pub_prec = np.nanmean(epo_pub_prec, axis=0)
     mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0)
     mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0)
-    return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter
+    return (
+        mean_usage,
+        mean_precision,
+        mean_pub_prec,
+        mean_lowest_pub_prec,
+        mean_lowest_norm_prec,
+        lowest_counter,
+    )
 
 
-def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections):
+def save_model(
+    results,
+    data_dictionary,
+    compatible_next_tools,
+    trained_model_path,
+    class_weights,
+    standard_connections,
+):
     # save files
     trained_model = results["model"]
     best_model_parameters = results["best_parameters"]
     model_config = trained_model.to_json()
     model_weights = trained_model.get_weights()
     model_values = {
-        'data_dictionary': data_dictionary,
-        'model_config': model_config,
-        'best_parameters': best_model_parameters,
-        'model_weights': model_weights,
+        "data_dictionary": data_dictionary,
+        "model_config": model_config,
+        "best_parameters": best_model_parameters,
+        "model_weights": model_weights,
         "compatible_tools": compatible_next_tools,
         "class_weights": class_weights,
-        "standard_connections": standard_connections
+        "standard_connections": standard_connections,
     }
     set_trained_model(trained_model_path, model_values)
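The rewritten weighted_loss closure is the piece that ties the per-class usage weights into training: element-wise binary cross-entropy dotted with the duplicated weight vector, now via tensorflow.keras instead of the standalone Keras backend. A short usage sketch with a toy weight dict (the real one comes from assign_class_weights, and the label width is 2 * dimensions as in the Dense output layer):

    import tensorflow as tf
    from tensorflow.keras import backend

    def weighted_loss(class_weights):
        """Custom loss: binary cross-entropy weighted per output class."""
        weight_values = list(class_weights.values())
        weight_values.extend(weight_values)  # labels span 2 * dimensions columns

        def weighted_binary_crossentropy(y_true, y_pred):
            expanded_weights = tf.expand_dims(weight_values, axis=-1)
            bce = backend.binary_crossentropy(y_true, y_pred)
            return backend.dot(bce, expanded_weights)

        return weighted_binary_crossentropy

    # e.g. model.compile(loss=weighted_loss({1: 1.0, 2: 2.5}), optimizer="rmsprop")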