comparison main.py @ 0:9bf25dbe00ad draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
author bgruening
date Wed, 28 Aug 2019 07:19:38 -0400
parents
children 12764915e1c5
comparison
equal deleted inserted replaced
-1:000000000000 0:9bf25dbe00ad
1 """
2 Predict next tools in the Galaxy workflows
3 using machine learning (recurrent neural network)
4 """
5
6 import numpy as np
7 import argparse
8 import time
9
10 # machine learning library
11 import keras.callbacks as callbacks
12
13 import extract_workflow_connections
14 import prepare_data
15 import optimise_hyperparameters
16 import utils
17
18
class PredictTool:
    """ Search for the best recurrent network configuration and train it. """

    def __init__(self):
        """ Init method. The tool keeps no instance state. """

    @classmethod
    def find_train_best_network(cls, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools):
        """
        Optimise the hyperparameters, then train a network built from the best ones.

        :param network_config: configuration dict forwarded to the hyperparameter optimisation
        :param reverse_dictionary: mapping forwarded to the optimisation, the network builder
            and the per-epoch precision callback
        :param train_data: training samples
        :param train_labels: training labels
        :param test_data: test samples; may be empty, in which case no validation is done
        :param test_labels: test labels
        :param n_epochs: number of epochs for the final training run
        :param class_weights: forwarded to the optimisation and to utils.set_recurrent_network
        :param usage_pred: usage scores forwarded to the precision callback
        :param compatible_next_tools: compatible-tools mapping forwarded to the precision callback
        :returns: dict with "train_loss" (per-epoch losses), "model" and "best_parameters";
            when test data is present also "validation_loss", "precision" and
            "usage_weights" (one entry per epoch)
        """
        print("Start hyperparameter optimisation...")
        hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
        best_params = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, class_weights)

        # rebuild the network from the best hyperparameters and train it on the full training set
        # NOTE(review): class_weights are not passed to fit(); presumably they are consumed by
        # utils.set_recurrent_network (e.g. in the loss) - confirm
        model, best_params = utils.set_recurrent_network(best_params, reverse_dictionary, class_weights)

        has_test_data = len(test_data) > 0

        # the callback records precision/usage weights per epoch (it skips work without test data)
        predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, compatible_next_tools, usage_pred)
        callbacks_list = [predict_callback_test]

        # only validate when there is test data; an empty validation set must not reach Keras
        validation_data = (test_data, test_labels) if has_test_data else None

        print("Start training on the best model...")
        model_fit = model.fit(
            train_data,
            train_labels,
            batch_size=int(best_params["batch_size"]),
            epochs=n_epochs,
            verbose=2,
            callbacks=callbacks_list,
            shuffle="batch",
            validation_data=validation_data
        )

        train_performance = {
            "train_loss": np.array(model_fit.history["loss"]),
            "model": model,
            "best_parameters": best_params
        }

        # validation history and callback metrics only exist when test data was used
        if has_test_data:
            train_performance["validation_loss"] = np.array(model_fit.history["val_loss"])
            train_performance["precision"] = predict_callback_test.precision
            train_performance["usage_weights"] = predict_callback_test.usage_weights
        return train_performance
66
67
class PredictCallback(callbacks.Callback):
    """ Keras callback that evaluates the model on test data at the end of every epoch. """

    def __init__(self, test_data, test_labels, reverse_data_dictionary, n_epochs, next_compatible_tools, usg_scores):
        """
        :param test_data: test samples; when empty, on_epoch_end does nothing
        :param test_labels: test labels
        :param reverse_data_dictionary: mapping forwarded to utils.verify_model
        :param n_epochs: number of training epochs (stored, not used by this class)
        :param next_compatible_tools: compatible-tools mapping forwarded to utils.verify_model
        :param usg_scores: usage scores forwarded to utils.verify_model
        """
        # let the Keras base class initialise its own attributes (e.g. model)
        super(PredictCallback, self).__init__()
        self.test_data = test_data
        self.test_labels = test_labels
        self.reverse_data_dictionary = reverse_data_dictionary
        # one entry appended per epoch by on_epoch_end
        self.precision = list()
        self.usage_weights = list()
        self.n_epochs = n_epochs
        self.next_compatible_tools = next_compatible_tools
        self.pred_usage_scores = usg_scores

    def on_epoch_end(self, epoch, logs=None):
        """
        Compute absolute and compatible precision for test data.

        :param epoch: zero-based epoch index supplied by Keras
        :param logs: Keras metric logs (unused); defaults to None instead of a shared dict
        """
        if len(self.test_data) > 0:
            precision, usage_weights = utils.verify_model(self.model, self.test_data, self.test_labels, self.reverse_data_dictionary, self.next_compatible_tools, self.pred_usage_scores)
            self.precision.append(precision)
            self.usage_weights.append(usage_weights)
            print("Epoch %d precision: %s" % (epoch + 1, precision))
            print("Epoch %d usage weights: %s" % (epoch + 1, usage_weights))
89
90
def _build_arg_parser():
    """
    Return the command-line parser for training the tool recommendation model.

    Integer and float data parameters are converted by argparse, so malformed values are
    reported as a usage error. Network hyperparameters are kept as raw strings.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file")
    arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file")
    arg_parser.add_argument("-om", "--output_model", required=True, help="trained model file")
    # data parameters
    arg_parser.add_argument("-cd", "--cutoff_date", required=True, help="earliest date for taking tool usage")
    arg_parser.add_argument("-pl", "--maximum_path_length", required=True, type=int, help="maximum length of tool path")
    arg_parser.add_argument("-ep", "--n_epochs", required=True, type=int, help="number of iterations to run to create model")
    arg_parser.add_argument("-oe", "--optimize_n_epochs", required=True, type=int, help="number of iterations to run to find best model parameters")
    arg_parser.add_argument("-me", "--max_evals", required=True, type=int, help="maximum number of configuration evaluations")
    arg_parser.add_argument("-ts", "--test_share", required=True, type=float, help="share of data to be used for testing")
    arg_parser.add_argument("-vs", "--validation_share", required=True, type=float, help="share of data to be used for validation")
    # neural network parameters (kept as strings; presumably parsed further by the
    # hyperparameter optimisation, e.g. as ranges - confirm)
    arg_parser.add_argument("-bs", "--batch_size", required=True, help="size of the training batch i.e. the number of samples per batch")
    arg_parser.add_argument("-ut", "--units", required=True, help="number of hidden recurrent units")
    arg_parser.add_argument("-es", "--embedding_size", required=True, help="size of the fixed vector learned for each tool")
    arg_parser.add_argument("-dt", "--dropout", required=True, help="percentage of neurons to be dropped")
    arg_parser.add_argument("-sd", "--spatial_dropout", required=True, help="1d dropout used for embedding layer")
    arg_parser.add_argument("-rd", "--recurrent_dropout", required=True, help="dropout for the recurrent layers")
    arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate")
    arg_parser.add_argument("-ar", "--activation_recurrent", required=True, help="activation function for recurrent layers")
    arg_parser.add_argument("-ao", "--activation_output", required=True, help="activation function for output layers")
    arg_parser.add_argument("-lt", "--loss_type", required=True, help="type of the loss/error function")
    return arg_parser


if __name__ == "__main__":
    start_time = time.time()
    # get argument values (numeric conversion already done by argparse)
    args = vars(_build_arg_parser().parse_args())
    tool_usage_path = args["tool_usage_file"]
    workflows_path = args["workflow_file"]
    trained_model_path = args["output_model"]
    cutoff_date = args["cutoff_date"]
    maximum_path_length = args["maximum_path_length"]
    n_epochs = args["n_epochs"]
    test_share = args["test_share"]

    # settings forwarded to the hyperparameter optimisation
    config = {
        'cutoff_date': cutoff_date,
        'maximum_path_length': maximum_path_length,
        'n_epochs': n_epochs,
        'optimize_n_epochs': args["optimize_n_epochs"],
        'max_evals': args["max_evals"],
        'test_share': test_share,
        'validation_share': args["validation_share"],
        'batch_size': args["batch_size"],
        'units': args["units"],
        'embedding_size': args["embedding_size"],
        'dropout': args["dropout"],
        'spatial_dropout': args["spatial_dropout"],
        'recurrent_dropout': args["recurrent_dropout"],
        'learning_rate': args["learning_rate"],
        'activation_recurrent': args["activation_recurrent"],
        'activation_output': args["activation_output"],
        'loss_type': args["loss_type"]
    }

    # Extract and process workflows
    connections = extract_workflow_connections.ExtractWorkflowConnections()
    workflow_paths, compatible_next_tools = connections.read_tabular_file(workflows_path)
    # Process the paths from workflows
    print("Dividing data...")
    data = prepare_data.PrepareData(maximum_path_length, test_share)
    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools)
    # find the best model and start training
    predict_tool = PredictTool()
    # start training with weighted classes
    print("Training with weighted classes and samples ...")
    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools)
    print()
    print("Best parameters \n")
    print(results_weighted["best_parameters"])
    print()
    utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights)
    end_time = time.time()
    print()
    print("Program finished in %s seconds" % str(end_time - start_time))