comparison main.py @ 6:e94dc7945639 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author bgruening
date Sun, 16 Oct 2022 11:52:10 +0000
parents 4f7e6612906b
children
5:4f7e6612906b 6:e94dc7945639
1 """ 1 """
2 Predict next tools in the Galaxy workflows 2 Predict next tools in the Galaxy workflows
3 using machine learning (recurrent neural network) 3 using deep learning (Transformers)
4 """ 4 """
5
6 import argparse 5 import argparse
7 import time 6 import time
8 7
9 import extract_workflow_connections 8 import extract_workflow_connections
10 import keras.callbacks as callbacks
11 import numpy as np
12 import optimise_hyperparameters
13 import prepare_data 9 import prepare_data
14 import utils 10 import train_transformer
15
16
17 class PredictTool:
18 def __init__(self, num_cpus):
19 """ Init method. """
20
21 def find_train_best_network(
22 self,
23 network_config,
24 reverse_dictionary,
25 train_data,
26 train_labels,
27 test_data,
28 test_labels,
29 n_epochs,
30 class_weights,
31 usage_pred,
32 standard_connections,
33 tool_freq,
34 tool_tr_samples,
35 ):
36 """
37 Define recurrent neural network and train sequential data
38 """
39 # get tools with lowest representation
40 lowest_tool_ids = utils.get_lowest_tools(tool_freq)
41
42 print("Start hyperparameter optimisation...")
43 hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
44 best_params, best_model = hyper_opt.train_model(
45 network_config,
46 reverse_dictionary,
47 train_data,
48 train_labels,
49 test_data,
50 test_labels,
51 tool_tr_samples,
52 class_weights,
53 )
54
55 # define callbacks
56 early_stopping = callbacks.EarlyStopping(
57 monitor="loss",
58 mode="min",
59 verbose=1,
60 min_delta=1e-1,
61 restore_best_weights=True,
62 )
63 predict_callback_test = PredictCallback(
64 test_data,
65 test_labels,
66 reverse_dictionary,
67 n_epochs,
68 usage_pred,
69 standard_connections,
70 lowest_tool_ids,
71 )
72
73 callbacks_list = [predict_callback_test, early_stopping]
74 batch_size = int(best_params["batch_size"])
75
76 print("Start training on the best model...")
77 train_performance = dict()
78 trained_model = best_model.fit_generator(
79 utils.balanced_sample_generator(
80 train_data,
81 train_labels,
82 batch_size,
83 tool_tr_samples,
84 reverse_dictionary,
85 ),
86 steps_per_epoch=len(train_data) // batch_size,
87 epochs=n_epochs,
88 callbacks=callbacks_list,
89 validation_data=(test_data, test_labels),
90 verbose=2,
91 shuffle=True,
92 )
93 train_performance["validation_loss"] = np.array(
94 trained_model.history["val_loss"]
95 )
96 train_performance["precision"] = predict_callback_test.precision
97 train_performance["usage_weights"] = predict_callback_test.usage_weights
98 train_performance[
99 "published_precision"
100 ] = predict_callback_test.published_precision
101 train_performance[
102 "lowest_pub_precision"
103 ] = predict_callback_test.lowest_pub_precision
104 train_performance[
105 "lowest_norm_precision"
106 ] = predict_callback_test.lowest_norm_precision
107 train_performance["train_loss"] = np.array(trained_model.history["loss"])
108 train_performance["model"] = best_model
109 train_performance["best_parameters"] = best_params
110 return train_performance
111
112
113 class PredictCallback(callbacks.Callback):
114 def __init__(
115 self,
116 test_data,
117 test_labels,
118 reverse_data_dictionary,
119 n_epochs,
120 usg_scores,
121 standard_connections,
122 lowest_tool_ids,
123 ):
124 self.test_data = test_data
125 self.test_labels = test_labels
126 self.reverse_data_dictionary = reverse_data_dictionary
127 self.precision = list()
128 self.usage_weights = list()
129 self.published_precision = list()
130 self.n_epochs = n_epochs
131 self.pred_usage_scores = usg_scores
132 self.standard_connections = standard_connections
133 self.lowest_tool_ids = lowest_tool_ids
134 self.lowest_pub_precision = list()
135 self.lowest_norm_precision = list()
136
137 def on_epoch_end(self, epoch, logs={}):
138 """
139 Compute absolute and compatible precision for test data
140 """
141 if len(self.test_data) > 0:
142 (
143 usage_weights,
144 precision,
145 precision_pub,
146 low_pub_prec,
147 low_norm_prec,
148 low_num,
149 ) = utils.verify_model(
150 self.model,
151 self.test_data,
152 self.test_labels,
153 self.reverse_data_dictionary,
154 self.pred_usage_scores,
155 self.standard_connections,
156 self.lowest_tool_ids,
157 )
158 self.precision.append(precision)
159 self.usage_weights.append(usage_weights)
160 self.published_precision.append(precision_pub)
161 self.lowest_pub_precision.append(low_pub_prec)
162 self.lowest_norm_precision.append(low_norm_prec)
163 print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights))
164 print("Epoch %d normal precision: %s" % (epoch + 1, precision))
165 print("Epoch %d published precision: %s" % (epoch + 1, precision_pub))
166 print("Epoch %d lowest published precision: %s" % (epoch + 1, low_pub_prec))
167 print("Epoch %d lowest normal precision: %s" % (epoch + 1, low_norm_prec))
168 print(
169 "Epoch %d number of test samples with lowest tool ids: %s"
170 % (epoch + 1, low_num)
171 )
172
173 11
174 if __name__ == "__main__": 12 if __name__ == "__main__":
175 start_time = time.time() 13 start_time = time.time()
176 14
177 arg_parser = argparse.ArgumentParser() 15 arg_parser = argparse.ArgumentParser()
178 arg_parser.add_argument( 16 arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file")
179 "-wf", "--workflow_file", required=True, help="workflows tabular file" 17 arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file")
180 )
181 arg_parser.add_argument(
182 "-tu", "--tool_usage_file", required=True, help="tool usage file"
183 )
184 arg_parser.add_argument(
185 "-om", "--output_model", required=True, help="trained model file"
186 )
187 # data parameters 18 # data parameters
188 arg_parser.add_argument( 19 arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage")
189 "-cd", 20 arg_parser.add_argument("-pl", "--maximum_path_length", required=True, help="maximum length of tool path")
190 "--cutoff_date", 21 arg_parser.add_argument("-om", "--output_model", required=True, help="trained model path")
191 required=True,
192 help="earliest date for taking tool usage",
193 )
194 arg_parser.add_argument(
195 "-pl",
196 "--maximum_path_length",
197 required=True,
198 help="maximum length of tool path",
199 )
200 arg_parser.add_argument(
201 "-ep",
202 "--n_epochs",
203 required=True,
204 help="number of iterations to run to create model",
205 )
206 arg_parser.add_argument(
207 "-oe",
208 "--optimize_n_epochs",
209 required=True,
210 help="number of iterations to run to find best model parameters",
211 )
212 arg_parser.add_argument(
213 "-me",
214 "--max_evals",
215 required=True,
216 help="maximum number of configuration evaluations",
217 )
218 arg_parser.add_argument(
219 "-ts",
220 "--test_share",
221 required=True,
222 help="share of data to be used for testing",
223 )
224 # neural network parameters 22 # neural network parameters
225 arg_parser.add_argument( 23 arg_parser.add_argument("-ti", "--n_train_iter", required=True, help="Number of training iterations run to create model")
226 "-bs", 24 arg_parser.add_argument("-nhd", "--n_heads", required=True, help="Number of head in transformer's multi-head attention")
227 "--batch_size", 25 arg_parser.add_argument("-ed", "--n_embed_dim", required=True, help="Embedding dimension")
228 required=True, 26 arg_parser.add_argument("-fd", "--n_feed_forward_dim", required=True, help="Feed forward network dimension")
229 help="size of the tranining batch i.e. the number of samples per batch", 27 arg_parser.add_argument("-dt", "--dropout", required=True, help="Percentage of neurons to be dropped")
230 ) 28 arg_parser.add_argument("-lr", "--learning_rate", required=True, help="Learning rate")
231 arg_parser.add_argument( 29 arg_parser.add_argument("-ts", "--te_share", required=True, help="Share of data to be used for testing")
232 "-ut", "--units", required=True, help="number of hidden recurrent units" 30 arg_parser.add_argument("-trbs", "--tr_batch_size", required=True, help="Train batch size")
233 ) 31 arg_parser.add_argument("-trlg", "--tr_logging_step", required=True, help="Train logging frequency")
234 arg_parser.add_argument( 32 arg_parser.add_argument("-telg", "--te_logging_step", required=True, help="Test logging frequency")
235 "-es", 33 arg_parser.add_argument("-tebs", "--te_batch_size", required=True, help="Test batch size")
236 "--embedding_size",
237 required=True,
238 help="size of the fixed vector learned for each tool",
239 )
240 arg_parser.add_argument(
241 "-dt", "--dropout", required=True, help="percentage of neurons to be dropped"
242 )
243 arg_parser.add_argument(
244 "-sd",
245 "--spatial_dropout",
246 required=True,
247 help="1d dropout used for embedding layer",
248 )
249 arg_parser.add_argument(
250 "-rd",
251 "--recurrent_dropout",
252 required=True,
253 help="dropout for the recurrent layers",
254 )
255 arg_parser.add_argument(
256 "-lr", "--learning_rate", required=True, help="learning rate"
257 )
258 34
259 # get argument values 35 # get argument values
260 args = vars(arg_parser.parse_args()) 36 args = vars(arg_parser.parse_args())
261 tool_usage_path = args["tool_usage_file"] 37 tool_usage_path = args["tool_usage_file"]
262 workflows_path = args["workflow_file"] 38 workflows_path = args["workflow_file"]
263 cutoff_date = args["cutoff_date"] 39 cutoff_date = args["cutoff_date"]
264 maximum_path_length = int(args["maximum_path_length"]) 40 maximum_path_length = int(args["maximum_path_length"])
41
42 n_train_iter = int(args["n_train_iter"])
43 te_share = float(args["te_share"])
44 tr_batch_size = int(args["tr_batch_size"])
45 te_batch_size = int(args["te_batch_size"])
46
47 n_heads = int(args["n_heads"])
48 feed_forward_dim = int(args["n_feed_forward_dim"])
49 embedding_dim = int(args["n_embed_dim"])
50 dropout = float(args["dropout"])
51 learning_rate = float(args["learning_rate"])
52 te_logging_step = int(args["te_logging_step"])
53 tr_logging_step = int(args["tr_logging_step"])
265 trained_model_path = args["output_model"] 54 trained_model_path = args["output_model"]
266 n_epochs = int(args["n_epochs"])
267 optimize_n_epochs = int(args["optimize_n_epochs"])
268 max_evals = int(args["max_evals"])
269 test_share = float(args["test_share"])
270 batch_size = args["batch_size"]
271 units = args["units"]
272 embedding_size = args["embedding_size"]
273 dropout = args["dropout"]
274 spatial_dropout = args["spatial_dropout"]
275 recurrent_dropout = args["recurrent_dropout"]
276 learning_rate = args["learning_rate"]
277 num_cpus = 16
278 55
279 config = { 56 config = {
280 "cutoff_date": cutoff_date, 57 'cutoff_date': cutoff_date,
281 "maximum_path_length": maximum_path_length, 58 'maximum_path_length': maximum_path_length,
282 "n_epochs": n_epochs, 59 'n_train_iter': n_train_iter,
283 "optimize_n_epochs": optimize_n_epochs, 60 'n_heads': n_heads,
284 "max_evals": max_evals, 61 'feed_forward_dim': feed_forward_dim,
285 "test_share": test_share, 62 'embedding_dim': embedding_dim,
286 "batch_size": batch_size, 63 'dropout': dropout,
287 "units": units, 64 'learning_rate': learning_rate,
288 "embedding_size": embedding_size, 65 'te_share': te_share,
289 "dropout": dropout, 66 'te_logging_step': te_logging_step,
290 "spatial_dropout": spatial_dropout, 67 'tr_logging_step': tr_logging_step,
291 "recurrent_dropout": recurrent_dropout, 68 'tr_batch_size': tr_batch_size,
292 "learning_rate": learning_rate, 69 'te_batch_size': te_batch_size,
70 'trained_model_path': trained_model_path
293 } 71 }
294 72 print("Preprocessing workflows...")
295 # Extract and process workflows 73 # Extract and process workflows
296 connections = extract_workflow_connections.ExtractWorkflowConnections() 74 connections = extract_workflow_connections.ExtractWorkflowConnections()
297 ( 75 # Process raw workflow file
298 workflow_paths, 76 wf_dataframe, usage_df = connections.process_raw_files(workflows_path, tool_usage_path, config)
299 compatible_next_tools, 77 workflow_paths, pub_conn = connections.read_tabular_file(wf_dataframe, config)
300 standard_connections,
301 ) = connections.read_tabular_file(workflows_path)
302 # Process the paths from workflows 78 # Process the paths from workflows
303 print("Dividing data...") 79 print("Dividing data...")
304 data = prepare_data.PrepareData(maximum_path_length, test_share) 80 data = prepare_data.PrepareData(maximum_path_length, te_share)
305 ( 81 train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, tr_tool_freq = data.get_data_labels_matrices(workflow_paths, usage_df, cutoff_date, pub_conn)
306 train_data, 82 print(train_data.shape, train_labels.shape, test_data.shape, test_labels.shape)
307 train_labels, 83 train_transformer.create_enc_transformer(train_data, train_labels, test_data, test_labels, f_dict, r_dict, c_wts, c_tools, pub_conn, tr_tool_freq, config)
308 test_data,
309 test_labels,
310 data_dictionary,
311 reverse_dictionary,
312 class_weights,
313 usage_pred,
314 train_tool_freq,
315 tool_tr_samples,
316 ) = data.get_data_labels_matrices(
317 workflow_paths,
318 tool_usage_path,
319 cutoff_date,
320 compatible_next_tools,
321 standard_connections,
322 )
323 # find the best model and start training
324 predict_tool = PredictTool(num_cpus)
325 # start training with weighted classes
326 print("Training with weighted classes and samples ...")
327 results_weighted = predict_tool.find_train_best_network(
328 config,
329 reverse_dictionary,
330 train_data,
331 train_labels,
332 test_data,
333 test_labels,
334 n_epochs,
335 class_weights,
336 usage_pred,
337 standard_connections,
338 train_tool_freq,
339 tool_tr_samples,
340 )
341 utils.save_model(
342 results_weighted,
343 data_dictionary,
344 compatible_next_tools,
345 trained_model_path,
346 class_weights,
347 standard_connections,
348 )
349 end_time = time.time() 84 end_time = time.time()
350 print("Program finished in %s seconds" % str(end_time - start_time)) 85 print("Program finished in %s seconds" % str(end_time - start_time))