Previous changeset: 4:afec8c595124 (2020-07-07) | Next changeset: 6:e94dc7945639 (2022-10-16)
Commit message:
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb" |
modified:
create_tool_recommendation_model.xml extract_workflow_connections.py main.py optimise_hyperparameters.py predict_tool_usage.py prepare_data.py test-data/test_tool_usage test-data/test_workflows utils.py
diff -r afec8c595124 -r 4f7e6612906b create_tool_recommendation_model.xml
--- a/create_tool_recommendation_model.xml	Tue Jul 07 03:25:49 2020 -0400
+++ b/create_tool_recommendation_model.xml	Fri May 06 09:05:18 2022 +0000
@@ -1,13 +1,13 @@ -<tool id="create_tool_recommendation_model" name="Create a model to recommend tools" version="0.0.3"> +<tool id="create_tool_recommendation_model" name="Create a model to recommend tools" version="0.0.4"> <description>using deep learning</description> <requirements> - <requirement type="package" version="3.6">python</requirement> - <requirement type="package" version="1.13.1">tensorflow</requirement> - <requirement type="package" version="2.3.0">keras</requirement> - <requirement type="package" version="0.21.3">scikit-learn</requirement> - <requirement type="package" version="2.9.0">h5py</requirement> + <requirement type="package" version="3.9.7">python</requirement> + <requirement type="package" version="2.7.0">tensorflow</requirement> + <requirement type="package" version="2.7.0">keras</requirement> + <requirement type="package" version="1.0.2">scikit-learn</requirement> + <requirement type="package" version="3.6.0">h5py</requirement> <requirement type="package" version="1.0.4">csvkit</requirement> - <requirement type="package" version="0.1.2">hyperopt</requirement> + <requirement type="package" version="0.2.5">hyperopt</requirement> </requirements> <version_command>echo "@VERSION@"</version_command> <command detect_errors="aggressive"> |
diff -r afec8c595124 -r 4f7e6612906b extract_workflow_connections.py
--- a/extract_workflow_connections.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/extract_workflow_connections.py	Fri May 06 09:05:18 2022 +0000
@@ -10,7 +10,6 @@ class ExtractWorkflowConnections: - def __init__(self): """ Init method. """ @@ -33,12 +32,12 @@ workflow_paths = list() unique_paths = dict() standard_connections = dict() - with open(raw_file_path, 'rt') as workflow_connections_file: - workflow_connections = csv.reader(workflow_connections_file, delimiter='\t') + with open(raw_file_path, "rt") as workflow_connections_file: + workflow_connections = csv.reader(workflow_connections_file, delimiter="\t") for index, row in enumerate(workflow_connections): wf_id = str(row[0]) - in_tool = row[3] - out_tool = row[6] + in_tool = row[3].strip() + out_tool = row[6].strip() if wf_id not in workflows: workflows[wf_id] = list() if out_tool and in_tool and out_tool != in_tool: @@ -144,7 +143,9 @@ if end in graph: for node in graph[end]: if node not in path: - new_tools_paths = self.find_tool_paths_workflow(graph, start, node, path) + new_tools_paths = self.find_tool_paths_workflow( + graph, start, node, path + ) for tool_path in new_tools_paths: path_list.append(tool_path) return path_list |
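Note on the hunk above: find_tool_paths_workflow enumerates tool-to-tool paths through a workflow graph, and the reformatted call does not change its behaviour. The snippet below is a simplified, generic all-simple-paths walk illustrating the same idea; it is not a line-for-line copy of the helper, and the graph literal is hypothetical example data.

# Sketch: enumerate all simple paths between two tools in a workflow
# connection graph (hypothetical data; simplified analogue of
# find_tool_paths_workflow in extract_workflow_connections.py).
def find_tool_paths(graph, start, end, path=None):
    path = (path or []) + [start]
    if start == end:
        return [path]
    paths = []
    for nxt in graph.get(start, []):
        if nxt not in path:  # avoid revisiting a tool (no cycles)
            paths.extend(find_tool_paths(graph, nxt, end, path))
    return paths


workflow_graph = {
    "fastqc": ["trimmomatic"],
    "trimmomatic": ["bowtie2", "hisat2"],
    "bowtie2": ["featurecounts"],
}
print(find_tool_paths(workflow_graph, "fastqc", "featurecounts"))
# [['fastqc', 'trimmomatic', 'bowtie2', 'featurecounts']]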
diff -r afec8c595124 -r 4f7e6612906b main.py
--- a/main.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/main.py	Fri May 06 09:05:18 2022 +0000
@@ -3,35 +3,36 @@
 using machine learning (recurrent neural network)
 """
 
-import numpy as np
 import argparse
 import time
 
-# machine learning library
-import tensorflow as tf
-from keras import backend as K
+import extract_workflow_connections
 import keras.callbacks as callbacks
-
-import extract_workflow_connections
+import numpy as np
+import optimise_hyperparameters
 import prepare_data
-import optimise_hyperparameters
 import utils
 
 
 class PredictTool:
-
     def __init__(self, num_cpus):
         """ Init method. """
-        # set the number of cpus
-        cpu_config = tf.ConfigProto(
-            device_count={"CPU": num_cpus},
-            intra_op_parallelism_threads=num_cpus,
-            inter_op_parallelism_threads=num_cpus,
-            allow_soft_placement=True
-        )
-        K.set_session(tf.Session(config=cpu_config))
 
-    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples):
+    def find_train_best_network(
+        self,
+        network_config,
+        reverse_dictionary,
+        train_data,
+        train_labels,
+        test_data,
+        test_labels,
+        n_epochs,
+        class_weights,
+        usage_pred,
+        standard_connections,
+        tool_freq,
+        tool_tr_samples,
+    ):
         """
         Define recurrent neural network and train sequential data
         """
@@ -40,11 +41,34 @@
 
         print("Start hyperparameter optimisation...")
         hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
-        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights)
+        best_params, best_model = hyper_opt.train_model(
+            network_config,
+            reverse_dictionary,
+            train_data,
+            train_labels,
+            test_data,
+            test_labels,
+            tool_tr_samples,
+            class_weights,
+        )
 
         # define callbacks
-        early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True)
-        predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids)
+        early_stopping = callbacks.EarlyStopping(
+            monitor="loss",
+            mode="min",
+            verbose=1,
+            min_delta=1e-1,
+            restore_best_weights=True,
+        )
+        predict_callback_test = PredictCallback(
+            test_data,
+            test_labels,
+            reverse_dictionary,
+            n_epochs,
+            usage_pred,
+            standard_connections,
+            lowest_tool_ids,
+        )
 
         callbacks_list = [predict_callback_test, early_stopping]
         batch_size = int(best_params["batch_size"])
@@ -57,21 +81,29 @@
                 train_labels,
                 batch_size,
                 tool_tr_samples,
-                reverse_dictionary
+                reverse_dictionary,
             ),
             steps_per_epoch=len(train_data) // batch_size,
             epochs=n_epochs,
             callbacks=callbacks_list,
             validation_data=(test_data, test_labels),
            verbose=2,
-            shuffle=True
+            shuffle=True,
         )
-        train_performance["validation_loss"] = np.array(trained_model.history["val_loss"])
+        train_performance["validation_loss"] = np.array(
+            trained_model.history["val_loss"]
+        )
         train_performance["precision"] = predict_callback_test.precision
         train_performance["usage_weights"] = predict_callback_test.usage_weights
-        train_performance["published_precision"] = predict_callback_test.published_precision
-        train_performance["lowest_p
[...]
_parser.add_argument(
+        "-sd",
+        "--spatial_dropout",
+        required=True,
+        help="1d dropout used for embedding layer",
+    )
+    arg_parser.add_argument(
+        "-rd",
+        "--recurrent_dropout",
+        required=True,
+        help="dropout for the recurrent layers",
+    )
+    arg_parser.add_argument(
+        "-lr", "--learning_rate", required=True, help="learning rate"
+    )
 
     # get argument values
     args = vars(arg_parser.parse_args())
@@ -156,33 +277,74 @@
     num_cpus = 16
 
     config = {
-        'cutoff_date': cutoff_date,
-        'maximum_path_length': maximum_path_length,
-        'n_epochs': n_epochs,
-        'optimize_n_epochs': optimize_n_epochs,
-        'max_evals': max_evals,
-        'test_share': test_share,
-        'batch_size': batch_size,
-        'units': units,
-        'embedding_size': embedding_size,
-        'dropout': dropout,
-        'spatial_dropout': spatial_dropout,
-        'recurrent_dropout': recurrent_dropout,
-        'learning_rate': learning_rate
+        "cutoff_date": cutoff_date,
+        "maximum_path_length": maximum_path_length,
+        "n_epochs": n_epochs,
+        "optimize_n_epochs": optimize_n_epochs,
+        "max_evals": max_evals,
+        "test_share": test_share,
+        "batch_size": batch_size,
+        "units": units,
+        "embedding_size": embedding_size,
+        "dropout": dropout,
+        "spatial_dropout": spatial_dropout,
+        "recurrent_dropout": recurrent_dropout,
+        "learning_rate": learning_rate,
     }
 
     # Extract and process workflows
     connections = extract_workflow_connections.ExtractWorkflowConnections()
-    workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path)
+    (
+        workflow_paths,
+        compatible_next_tools,
+        standard_connections,
+    ) = connections.read_tabular_file(workflows_path)
     # Process the paths from workflows
     print("Dividing data...")
     data = prepare_data.PrepareData(maximum_path_length, test_share)
-    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections)
+    (
+        train_data,
+        train_labels,
+        test_data,
+        test_labels,
+        data_dictionary,
+        reverse_dictionary,
+        class_weights,
+        usage_pred,
+        train_tool_freq,
+        tool_tr_samples,
+    ) = data.get_data_labels_matrices(
+        workflow_paths,
+        tool_usage_path,
+        cutoff_date,
+        compatible_next_tools,
+        standard_connections,
+    )
     # find the best model and start training
     predict_tool = PredictTool(num_cpus)
     # start training with weighted classes
     print("Training with weighted classes and samples ...")
-    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples)
-    utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections)
+    results_weighted = predict_tool.find_train_best_network(
+        config,
+        reverse_dictionary,
+        train_data,
+        train_labels,
+        test_data,
+        test_labels,
+        n_epochs,
+        class_weights,
+        usage_pred,
+        standard_connections,
+        train_tool_freq,
+        tool_tr_samples,
+    )
+    utils.save_model(
+        results_weighted,
+        data_dictionary,
+        compatible_next_tools,
+        trained_model_path,
+        class_weights,
+        standard_connections,
+    )
     end_time = time.time()
     print("Program finished in %s seconds" % str(end_time - start_time))
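Note on the deleted __init__ body above: tf.ConfigProto and K.set_session belong to the TensorFlow 1.x API and no longer exist in TensorFlow 2.7, so the changeset simply drops the explicit CPU configuration. If an explicit thread limit were still wanted, a TF2-era equivalent would look roughly like the sketch below; this is an assumption about a possible replacement, not code from the commit.

# Sketch of a TF2 replacement for the removed ConfigProto block;
# these calls must run before any TensorFlow op is executed.
import tensorflow as tf

num_cpus = 16
tf.config.threading.set_intra_op_parallelism_threads(num_cpus)
tf.config.threading.set_inter_op_parallelism_threads(num_cpus)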
diff -r afec8c595124 -r 4f7e6612906b optimise_hyperparameters.py
--- a/optimise_hyperparameters.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/optimise_hyperparameters.py	Fri May 06 09:05:18 2022 +0000
@@ -3,24 +3,29 @@ """ import numpy as np -from hyperopt import fmin, tpe, hp, STATUS_OK, Trials - -from keras.models import Sequential -from keras.layers import Dense, GRU, Dropout -from keras.layers.embeddings import Embedding -from keras.layers.core import SpatialDropout1D -from keras.optimizers import RMSprop -from keras.callbacks import EarlyStopping - import utils +from hyperopt import fmin, hp, STATUS_OK, tpe, Trials +from tensorflow.keras.callbacks import EarlyStopping +from tensorflow.keras.layers import Dense, Dropout, Embedding, GRU, SpatialDropout1D +from tensorflow.keras.models import Sequential +from tensorflow.keras.optimizers import RMSprop class HyperparameterOptimisation: - def __init__(self): """ Init method. """ - def train_model(self, config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights): + def train_model( + self, + config, + reverse_dictionary, + train_data, + train_labels, + test_data, + test_labels, + tool_tr_samples, + class_weights, + ): """ Train a model and report accuracy """ @@ -40,52 +45,101 @@ # get dimensions dimensions = len(reverse_dictionary) + 1 best_model_params = dict() - early_stopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True) + early_stopping = EarlyStopping( + monitor="val_loss", + mode="min", + verbose=1, + min_delta=1e-1, + restore_best_weights=True, + ) # specify the search space for finding the best combination of parameters using Bayesian optimisation params = { - "embedding_size": hp.quniform("embedding_size", l_embedding_size[0], l_embedding_size[1], 1), + "embedding_size": hp.quniform( + "embedding_size", l_embedding_size[0], l_embedding_size[1], 1 + ), "units": hp.quniform("units", l_units[0], l_units[1], 1), - "batch_size": hp.quniform("batch_size", l_batch_size[0], l_batch_size[1], 1), - "learning_rate": hp.loguniform("learning_rate", np.log(l_learning_rate[0]), np.log(l_learning_rate[1])), + "batch_size": hp.quniform( + "batch_size", l_batch_size[0], l_batch_size[1], 1 + ), + "learning_rate": hp.loguniform( + "learning_rate", np.log(l_learning_rate[0]), np.log(l_learning_rate[1]) + ), "dropout": hp.uniform("dropout", l_dropout[0], l_dropout[1]), - "spatial_dropout": hp.uniform("spatial_dropout", l_spatial_dropout[0], l_spatial_dropout[1]), - "recurrent_dropout": hp.uniform("recurrent_dropout", l_recurrent_dropout[0], l_recurrent_dropout[1]) + "spatial_dropout": hp.uniform( + "spatial_dropout", l_spatial_dropout[0], l_spatial_dropout[1] + ), + "recurrent_dropout": hp.uniform( + "recurrent_dropout", l_recurrent_dropout[0], l_recurrent_dropout[1] + ), } def create_model(params): model = Sequential() - model.add(Embedding(dimensions, int(params["embedding_size"]), mask_zero=True)) + model.add( + Embedding(dimensions, int(params["embedding_size"]), mask_zero=True) + ) model.add(SpatialDropout1D(params["spatial_dropout"])) - model.add(GRU(int(params["units"]), dropout=params["dropout"], recurrent_dropout=params["recurrent_dropout"], return_sequences=True, activation="elu")) + model.add( + GRU( + int(params["units"]), + dropout=params["dropout"], + recurrent_dropout=params["recurrent_dropout"], + return_sequences=True, + activation="elu", + ) + ) model.add(Dropout(params["dropout"])) - model.add(GRU(int(params["units"]), dropout=params["dropout"], recurrent_dropout=params["recurrent_dropout"], return_sequences=False, activation="elu")) + model.add( + GRU( + int(params["units"]), + dropout=params["dropout"], + 
recurrent_dropout=params["recurrent_dropout"], + return_sequences=False, + activation="elu", + ) + ) model.add(Dropout(params["dropout"])) model.add(Dense(2 * dimensions, activation="sigmoid")) optimizer_rms = RMSprop(lr=params["learning_rate"]) batch_size = int(params["batch_size"]) - model.compile(loss=utils.weighted_loss(class_weights), optimizer=optimizer_rms) + model.compile( + loss=utils.weighted_loss(class_weights), optimizer=optimizer_rms + ) print(model.summary()) - model_fit = model.fit_generator( + model_fit = model.fit( utils.balanced_sample_generator( train_data, train_labels, batch_size, tool_tr_samples, - reverse_dictionary + reverse_dictionary, ), steps_per_epoch=len(train_data) // batch_size, epochs=optimize_n_epochs, callbacks=[early_stopping], validation_data=(test_data, test_labels), verbose=2, - shuffle=True + shuffle=True, ) - return {'loss': model_fit.history["val_loss"][-1], 'status': STATUS_OK, 'model': model} + return { + "loss": model_fit.history["val_loss"][-1], + "status": STATUS_OK, + "model": model, + } + # minimize the objective function using the set of parameters above trials = Trials() - learned_params = fmin(create_model, params, trials=trials, algo=tpe.suggest, max_evals=int(config["max_evals"])) - best_model = trials.results[np.argmin([r['loss'] for r in trials.results])]['model'] + learned_params = fmin( + create_model, + params, + trials=trials, + algo=tpe.suggest, + max_evals=int(config["max_evals"]), + ) + best_model = trials.results[np.argmin([r["loss"] for r in trials.results])][ + "model" + ] # set the best params with respective values for item in learned_params: item_val = learned_params[item] |
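The hunk above keeps the same hyperopt workflow while moving the Keras imports to tensorflow.keras and switching fit_generator to fit. The search pattern itself can be isolated; below is a self-contained sketch with a toy objective standing in for the Keras training run (the space bounds are illustrative, not the tool's defaults).

# Sketch of the hyperopt pattern used by HyperparameterOptimisation:
# define a search space, minimise an objective that returns
# {"loss", "status"}, then read the best trial back from Trials.
import numpy as np
from hyperopt import STATUS_OK, Trials, fmin, hp, tpe

space = {
    "units": hp.quniform("units", 32, 256, 1),
    "learning_rate": hp.loguniform("learning_rate", np.log(1e-4), np.log(1e-1)),
    "dropout": hp.uniform("dropout", 0.0, 0.5),
}


def objective(params):
    # toy stand-in for training a GRU model and reporting validation loss
    loss = (params["learning_rate"] - 0.01) ** 2 + (params["dropout"] - 0.2) ** 2
    return {"loss": loss, "status": STATUS_OK}


trials = Trials()
best = fmin(objective, space, algo=tpe.suggest, max_evals=20, trials=trials)
best_trial = trials.results[np.argmin([r["loss"] for r in trials.results])]
print(best, best_trial["loss"])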
diff -r afec8c595124 -r 4f7e6612906b predict_tool_usage.py
--- a/predict_tool_usage.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/predict_tool_usage.py	Fri May 06 09:05:18 2022 +0000
@@ -2,17 +2,16 @@ Predict tool usage to weigh the predicted tools """ -import os -import numpy as np -import warnings +import collections import csv -import collections +import os +import warnings -from sklearn.svm import SVR +import numpy as np +import utils from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline - -import utils +from sklearn.svm import SVR warnings.filterwarnings("ignore") @@ -20,7 +19,6 @@ class ToolPopularity: - def __init__(self): """ Init method. """ @@ -31,10 +29,11 @@ tool_usage_dict = dict() all_dates = list() all_tool_list = list(dictionary.keys()) - with open(tool_usage_file, 'rt') as usage_file: - tool_usage = csv.reader(usage_file, delimiter='\t') + with open(tool_usage_file, "rt") as usage_file: + tool_usage = csv.reader(usage_file, delimiter="\t") for index, row in enumerate(tool_usage): - if (str(row[1]) > cutoff_date) is True: + row = [item.strip() for item in row] + if (str(row[1]).strip() > cutoff_date) is True: tool_id = utils.format_tool_id(row[0]) if tool_id in all_tool_list: all_dates.append(row[1]) @@ -67,18 +66,25 @@ """ epsilon = 0.0 cv = 5 - s_typ = 'neg_mean_absolute_error' + s_typ = "neg_mean_absolute_error" n_jobs = 4 s_error = 1 - iid = True tr_score = False try: - pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) + pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) param_grid = { - 'regressor__kernel': ['rbf', 'poly', 'linear'], - 'regressor__degree': [2, 3] + "regressor__kernel": ["rbf", "poly", "linear"], + "regressor__degree": [2, 3], } - search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) + search = GridSearchCV( + pipe, + param_grid, + cv=cv, + scoring=s_typ, + n_jobs=n_jobs, + error_score=s_error, + return_train_score=tr_score, + ) search.fit(x_reshaped, y_reshaped.ravel()) model = search.best_estimator_ # set the next time point to get prediction for @@ -87,7 +93,8 @@ if prediction < epsilon: prediction = [epsilon] return prediction[0] - except Exception: + except Exception as e: + print(e) return epsilon def get_pupularity_prediction(self, tools_usage): |
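The main functional change above is dropping the iid argument, which was removed from GridSearchCV in scikit-learn 0.24, plus printing the caught exception. A minimal, self-contained version of the retained grid search is sketched below; the x/y values are synthetic monthly usage counts, not data from the test files.

# Sketch of the SVR usage-forecast fit after the removal of `iid`.
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

x = np.arange(10).reshape(-1, 1)  # consecutive time points
y = np.array([5, 7, 6, 9, 12, 15, 14, 18, 21, 25], dtype=float)  # usage counts

pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
param_grid = {
    "regressor__kernel": ["rbf", "poly", "linear"],
    "regressor__degree": [2, 3],
}
search = GridSearchCV(
    pipe, param_grid, cv=5, scoring="neg_mean_absolute_error", error_score=1
)
search.fit(x, y.ravel())
print(search.best_estimator_.predict(np.array([[10]])))  # forecast the next time point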
diff -r afec8c595124 -r 4f7e6612906b prepare_data.py
--- a/prepare_data.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/prepare_data.py	Fri May 06 09:05:18 2022 +0000
@@ -4,18 +4,17 @@ into the test and training sets """ +import collections import os -import collections -import numpy as np import random +import numpy as np import predict_tool_usage main_path = os.getcwd() class PrepareData: - def __init__(self, max_seq_length, test_data_share): """ Init method. """ self.max_tool_sequence_len = max_seq_length @@ -27,15 +26,20 @@ """ tokens = list() raw_paths = workflow_paths - raw_paths = [x.replace("\n", '') for x in raw_paths] + raw_paths = [x.replace("\n", "") for x in raw_paths] for item in raw_paths: split_items = item.split(",") for token in split_items: - if token is not "": + if token != "": tokens.append(token) tokens = list(set(tokens)) tokens = np.array(tokens) - tokens = np.reshape(tokens, [-1, ]) + tokens = np.reshape( + tokens, + [ + -1, + ], + ) return tokens, raw_paths def create_new_dict(self, new_data_dict): @@ -60,7 +64,10 @@ dictionary = dict() for word, _ in count: dictionary[word] = len(dictionary) + 1 - dictionary, reverse_dictionary = self.assemble_dictionary(dictionary, old_data_dictionary) + word = word.strip() + dictionary, reverse_dictionary = self.assemble_dictionary( + dictionary, old_data_dictionary + ) return dictionary, reverse_dictionary def decompose_paths(self, paths, dictionary): @@ -74,13 +81,17 @@ if len_tools <= self.max_tool_sequence_len: for window in range(1, len_tools): sequence = tools[0: window + 1] - tools_pos = [str(dictionary[str(tool_item)]) for tool_item in sequence] + tools_pos = [ + str(dictionary[str(tool_item)]) for tool_item in sequence + ] if len(tools_pos) > 1: sub_paths_pos.append(",".join(tools_pos)) sub_paths_pos = list(set(sub_paths_pos)) return sub_paths_pos - def prepare_paths_labels_dictionary(self, dictionary, reverse_dictionary, paths, compatible_next_tools): + def prepare_paths_labels_dictionary( + self, dictionary, reverse_dictionary, paths, compatible_next_tools + ): """ Create a dictionary of sequences with their labels for training and test paths """ @@ -90,14 +101,18 @@ if item and item not in "": tools = item.split(",") label = tools[-1] - train_tools = tools[:len(tools) - 1] + train_tools = tools[: len(tools) - 1] last_but_one_name = reverse_dictionary[int(train_tools[-1])] try: - compatible_tools = compatible_next_tools[last_but_one_name].split(",") + compatible_tools = compatible_next_tools[last_but_one_name].split( + "," + ) except Exception: continue if len(compatible_tools) > 0: - compatible_tools_ids = [str(dictionary[x]) for x in compatible_tools] + compatible_tools_ids = [ + str(dictionary[x]) for x in compatible_tools + ] compatible_tools_ids.append(label) composite_labels = ",".join(compatible_tools_ids) train_tools = ",".join(train_tools) @@ -127,7 +142,9 @@ train_counter += 1 return data_mat, label_mat - def pad_paths(self, paths_dictionary, num_classes, standard_connections, reverse_dictionary): + def pad_paths( + self, paths_dictionary, num_classes, standard_connections, reverse_dictionary + ): """ Add padding to the tools sequences and create multi-hot encoded labels """ @@ -231,12 +248,22 @@ l_tool_tr_samples[last_tool_id].append(index) return l_tool_tr_samples - def get_data_labels_matrices(self, workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections, old_data_dictionary={}): + def get_data_labels_matrices( + self, + workflow_paths, + tool_usage_path, + cutoff_date, + compatible_next_tools, + standard_connections, + old_data_dictionary={}, + ): """ Convert the training and test paths into corresponding numpy matrices """ 
processed_data, raw_paths = self.process_workflow_paths(workflow_paths) - dictionary, rev_dict = self.create_data_dictionary(processed_data, old_data_dictionary) + dictionary, rev_dict = self.create_data_dictionary( + processed_data, old_data_dictionary + ) num_classes = len(dictionary) print("Raw paths: %d" % len(raw_paths)) @@ -247,18 +274,26 @@ random.shuffle(all_unique_paths) print("Creating dictionaries...") - multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools) + multilabels_paths = self.prepare_paths_labels_dictionary( + dictionary, rev_dict, all_unique_paths, compatible_next_tools + ) print("Complete data: %d" % len(multilabels_paths)) - train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths) + train_paths_dict, test_paths_dict = self.split_test_train_data( + multilabels_paths + ) print("Train data: %d" % len(train_paths_dict)) print("Test data: %d" % len(test_paths_dict)) print("Padding train and test data...") # pad training and test data with leading zeros - test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict) - train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict) + test_data, test_labels = self.pad_paths( + test_paths_dict, num_classes, standard_connections, rev_dict + ) + train_data, train_labels = self.pad_paths( + train_paths_dict, num_classes, standard_connections, rev_dict + ) print("Estimating sample frequency...") l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict) @@ -274,4 +309,15 @@ # get class weights using the predicted usage for each tool class_weights = self.assign_class_weights(num_classes, t_pred_usage) - return train_data, train_labels, test_data, test_labels, dictionary, rev_dict, class_weights, t_pred_usage, l_tool_freq, l_tool_tr_samples + return ( + train_data, + train_labels, + test_data, + test_labels, + dictionary, + rev_dict, + class_weights, + t_pred_usage, + l_tool_freq, + l_tool_tr_samples, + ) |
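Most of the prepare_data.py hunk is black-style reformatting; the underlying path decomposition is unchanged. As a quick illustration of what decompose_paths produces, here is the same prefix-window idea on a hypothetical four-tool path and dictionary.

# Sketch: expand one tool path into all prefixes of length >= 2,
# encoded through the tool dictionary (hypothetical values).
dictionary = {"fastqc": 1, "trimmomatic": 2, "bowtie2": 3, "featurecounts": 4}
path = ["fastqc", "trimmomatic", "bowtie2", "featurecounts"]

sub_paths = set()
for window in range(1, len(path)):
    sequence = path[0: window + 1]
    sub_paths.add(",".join(str(dictionary[t]) for t in sequence))

print(sorted(sub_paths))  # ['1,2', '1,2,3', '1,2,3,4']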
diff -r afec8c595124 -r 4f7e6612906b test-data/test_tool_usage
--- a/test-data/test_tool_usage	Tue Jul 07 03:25:49 2020 -0400
+++ b/test-data/test_tool_usage	Fri May 06 09:05:18 2022 +0000
@@ -1,500 +1,93 @@
-toolshed.g2.bx.psu.edu/repos/bgruening/rdock_rbdock/rdock_rbdock/0.1.1	2020-04-01	34568
-upload1	2020-04-01	18321
-toolshed.g2.bx.psu.edu/repos/devteam/fastqc/fastqc/0.72+galaxy1	2020-04-01	2839
-toolshed.g2.bx.psu.edu/repos/bgruening/xchem_transfs_scoring/xchem_transfs_scoring/0.2.0	2020-04-01	1220
-toolshed.g2.bx.psu.edu/repos/devteam/bowtie2/bowtie2/2.3.4.3+galaxy0	2020-04-01	958
-CONVERTER_gz_to_uncompressed	2020-04-01	919
-Filter1	2020-04-01	800
-Cut1	2020-04-01	787
-__SET_METADATA__	2020-04-01	685
[...]
+ toolshed.g2.bx.psu.edu/repos/devteam/column_maker/Add_a_column1/1.3.0 	2022-02-01	1019
+ toolshed.g2.bx.psu.edu/repos/iuc/multiqc/multiqc/1.11+galaxy0 	2022-02-01	973
+ toolshed.g2.bx.psu.edu/repos/bgruening/text_processing/tp_cat/0.1.1 	2022-02-01	962
+ toolshed.g2.bx.psu.edu/repos/iuc/sra_tools/fasterq_dump/2.11.0+galaxy1 	2022-02-01	926
+ toolshed.g2.bx.psu.edu/repos/devteam/fastq_groomer/fastq_groomer/1.1.5 	2022-02-01	907
+ cat1 	2022-02-01	795
+ toolshed.g2.bx.psu.edu/repos/lparsons/cutadapt/cutadapt/3.5+galaxy1 	2022-02-01	772
[...]
+ toolshed.g2.bx.psu.edu/repos/lparsons/htseq_count/htseq_count/0.9.1+galaxy1 	2022-02-01	544
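The replacement rows above carry stray spaces around the tool ids, which is why predict_tool_usage.py now strips every field. Below is a short sketch of how such a row is consumed; the cutoff date is an arbitrary example value.

# Sketch: parse a tool-usage row (tool id, date, count) as
# predict_tool_usage.py does, stripping the padded whitespace.
import csv
import io

sample = " toolshed.g2.bx.psu.edu/repos/iuc/multiqc/multiqc/1.11+galaxy0 \t2022-02-01\t973\n"
cutoff_date = "2021-12-01"

for row in csv.reader(io.StringIO(sample), delimiter="\t"):
    row = [item.strip() for item in row]
    if row[1] > cutoff_date:
        tool_id, date, count = row[0], row[1], int(row[2])
        print(tool_id, date, count)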
diff -r afec8c595124 -r 4f7e6612906b test-data/test_workflows
--- a/test-data/test_workflows	Tue Jul 07 03:25:49 2020 -0400
+++ b/test-data/test_workflows	Fri May 06 09:05:18 2022 +0000
@@ -1,1000 +1,499 @@
-3	2013-02-07 16:48:46.721866	5	Grep1	1.0.1	7	Remove beginning1	1.0.0	f	t	f
-3	2013-02-07 16:48:46.721866	6	Cut1	1.0.1	8	addValue	1.0.0	f	t	f
-3	2013-02-07 16:48:46.721866	7	Remove beginning1	1.0.0	9	Cut1	1.0.1	f	t	f
-4	2013-02-07 16:48:55.340018	13	cat1	1.0.0	22	barchart_gnuplot	1.0.0	f	t	f
-4	2013-02-07 16:48:55.340018	14	bedtools_intersectBed		16	wc_gnu	1.0.0	f	t	f
-5	2013-02-07 16:49:04.367628	25	addValue	1.0.0	26	cat1	1.0.0	f	t	f
[...]
+31	2013-02-18	274	 	 	275	 addValue 	 1.0.0 	 f 	 f 	 f
+31	2013-02-18	275	 addValue 	 1.0.0 	276	 mergeCols1 	 1.0.1 	 f 	 f 	 f
+31	2013-02-18	276	 mergeCols1 	 1.0.1 	277	 Cut1 	 1.0.1 	 f 	 f 	 f
[...]
+31	2013-02-18	275	 addValue 	 1.0.0 	276	 mergeCols1 	 1.0.1 	 f 	 f 	 f
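The new workflow rows are likewise padded with spaces, matching the .strip() calls added in extract_workflow_connections.py. A sketch of how one row (adapted from the data above) becomes a tool connection; columns 3 and 6 carry the two connected tool names, called in_tool and out_tool in the script.

# Sketch: read one workflow-connection row; column 0 is the workflow id.
import csv
import io

row_text = "31\t2013-02-18\t275\t addValue \t 1.0.0 \t276\t mergeCols1 \t 1.0.1 \t f \t f \t f\n"

for row in csv.reader(io.StringIO(row_text), delimiter="\t"):
    wf_id = str(row[0])
    in_tool = row[3].strip()
    out_tool = row[6].strip()
    if out_tool and in_tool and out_tool != in_tool:
        print(wf_id, in_tool, "->", out_tool)  # 31 addValue -> mergeCols1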
diff -r afec8c595124 -r 4f7e6612906b utils.py
--- a/utils.py	Tue Jul 07 03:25:49 2020 -0400
+++ b/utils.py	Fri May 06 09:05:18 2022 +0000
@@ -1,10 +1,11 @@ -import numpy as np import json -import h5py import random + +import h5py +import numpy as np +import tensorflow as tf from numpy.random import choice - -from keras import backend as K +from tensorflow.keras import backend def read_file(file_path): @@ -29,10 +30,10 @@ """ Create an h5 file with the trained weights and associated dicts """ - hf_file = h5py.File(dump_file, 'w') + hf_file = h5py.File(dump_file, "w") for key in model_values: value = model_values[key] - if key == 'model_weights': + if key == "model_weights": for idx, item in enumerate(value): w_key = "weight_" + str(idx) if w_key in hf_file: @@ -54,14 +55,19 @@ """ weight_values = list(class_weights.values()) weight_values.extend(weight_values) + def weighted_binary_crossentropy(y_true, y_pred): # add another dimension to compute dot product - expanded_weights = K.expand_dims(weight_values, axis=-1) - return K.dot(K.binary_crossentropy(y_true, y_pred), expanded_weights) + expanded_weights = tf.expand_dims(weight_values, axis=-1) + bce = backend.binary_crossentropy(y_true, y_pred) + return backend.dot(bce, expanded_weights) + return weighted_binary_crossentropy -def balanced_sample_generator(train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary): +def balanced_sample_generator( + train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary +): while True: dimension = train_data.shape[1] n_classes = train_labels.shape[1] @@ -80,7 +86,18 @@ yield generator_batch_data, generator_batch_labels -def compute_precision(model, x, y, reverse_data_dictionary, usage_scores, actual_classes_pos, topk, standard_conn, last_tool_id, lowest_tool_ids): +def compute_precision( + model, + x, + y, + reverse_data_dictionary, + usage_scores, + actual_classes_pos, + topk, + standard_conn, + last_tool_id, + lowest_tool_ids, +): """ Compute absolute and compatible precision """ @@ -137,7 +154,9 @@ else: lowest_pub_prec = np.nan if standard_topk_prediction_pos in usage_scores: - usage_wt_score.append(np.log(usage_scores[standard_topk_prediction_pos] + 1.0)) + usage_wt_score.append( + np.log(usage_scores[standard_topk_prediction_pos] + 1.0) + ) else: # count precision only when there is actually true published tools # else set to np.nan. 
Set to 0 only when there is wrong prediction @@ -148,7 +167,9 @@ pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)] if pred_t_name in actual_next_tool_names: if normal_topk_prediction_pos in usage_scores: - usage_wt_score.append(np.log(usage_scores[normal_topk_prediction_pos] + 1.0)) + usage_wt_score.append( + np.log(usage_scores[normal_topk_prediction_pos] + 1.0) + ) top_precision = 1.0 if last_tool_id in lowest_tool_ids: lowest_norm_prec = 1.0 @@ -166,7 +187,16 @@ return lowest_ids -def verify_model(model, x, y, reverse_data_dictionary, usage_scores, standard_conn, lowest_tool_ids, topk_list=[1, 2, 3]): +def verify_model( + model, + x, + y, + reverse_data_dictionary, + usage_scores, + standard_conn, + lowest_tool_ids, + topk_list=[1, 2, 3], +): """ Verify the model on test data """ @@ -187,7 +217,24 @@ test_sample = x[i, :] last_tool_id = str(int(test_sample[-1])) for index, abs_topk in enumerate(topk_list): - usg_wt_score, absolute_precision, pub_prec, lowest_p_prec, lowest_n_prec = compute_precision(model, test_sample, y, reverse_data_dictionary, usage_scores, actual_classes_pos, abs_topk, standard_conn, last_tool_id, lowest_tool_ids) + ( + usg_wt_score, + absolute_precision, + pub_prec, + lowest_p_prec, + lowest_n_prec, + ) = compute_precision( + model, + test_sample, + y, + reverse_data_dictionary, + usage_scores, + actual_classes_pos, + abs_topk, + standard_conn, + last_tool_id, + lowest_tool_ids, + ) precision[i][index] = absolute_precision usage_weights[i][index] = usg_wt_score epo_pub_prec[i][index] = pub_prec @@ -202,22 +249,36 @@ mean_pub_prec = np.nanmean(epo_pub_prec, axis=0) mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0) mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0) - return mean_usage, mean_precision, mean_pub_prec, mean_lowest_pub_prec, mean_lowest_norm_prec, lowest_counter + return ( + mean_usage, + mean_precision, + mean_pub_prec, + mean_lowest_pub_prec, + mean_lowest_norm_prec, + lowest_counter, + ) -def save_model(results, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections): +def save_model( + results, + data_dictionary, + compatible_next_tools, + trained_model_path, + class_weights, + standard_connections, +): # save files trained_model = results["model"] best_model_parameters = results["best_parameters"] model_config = trained_model.to_json() model_weights = trained_model.get_weights() model_values = { - 'data_dictionary': data_dictionary, - 'model_config': model_config, - 'best_parameters': best_model_parameters, - 'model_weights': model_weights, + "data_dictionary": data_dictionary, + "model_config": model_config, + "best_parameters": best_model_parameters, + "model_weights": model_weights, "compatible_tools": compatible_next_tools, "class_weights": class_weights, - "standard_connections": standard_connections + "standard_connections": standard_connections, } set_trained_model(trained_model_path, model_values) |
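The weighted_loss closure now builds on tensorflow.keras backend functions instead of the stand-alone keras backend. Below is a self-contained sketch of the same idea with toy class weights and predictions; the explicit tf.constant cast of the weight list is added here for clarity and is not in the changeset.

# Sketch of the weighted binary cross-entropy closure after the move to
# tensorflow.keras; weights and tensors are toy values.
import tensorflow as tf
from tensorflow.keras import backend


def weighted_loss(class_weights):
    weight_values = list(class_weights.values())
    weight_values.extend(weight_values)  # duplicated to cover 2 * dimensions outputs

    def weighted_binary_crossentropy(y_true, y_pred):
        # add another dimension so the dot product reduces over classes
        expanded_weights = tf.expand_dims(
            tf.constant(weight_values, dtype=tf.float32), axis=-1
        )
        bce = backend.binary_crossentropy(y_true, y_pred)
        return backend.dot(bce, expanded_weights)

    return weighted_binary_crossentropy


loss_fn = weighted_loss({0: 1.0, 1: 2.0})
y_true = tf.constant([[0.0, 1.0, 1.0, 0.0]])
y_pred = tf.constant([[0.1, 0.8, 0.6, 0.2]])
print(loss_fn(y_true, y_pred))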