diff main.py @ 2:76251d1ccdcc draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 6fa2a0294d615c9f267b766337dca0b2d3637219"
author bgruening
date Fri, 11 Oct 2019 18:24:54 -0400
parents 12764915e1c5
children 5b3c08710e47
line wrap: on
line diff
--- a/main.py	Wed Sep 25 06:42:40 2019 -0400
+++ b/main.py	Fri Oct 11 18:24:54 2019 -0400
@@ -8,6 +8,8 @@
 import time
 
 # machine learning library
+import tensorflow as tf
+from keras import backend as K
 import keras.callbacks as callbacks
 
 import extract_workflow_connections
@@ -19,8 +21,16 @@
 class PredictTool:
 
     @classmethod
-    def __init__(self):
+    def __init__(self, num_cpus):
         """ Init method. """
+        # cap TensorFlow at num_cpus (CPU device count, intra-/inter-op threads) and install that session in Keras
+        cpu_config = tf.ConfigProto(
+            device_count={"CPU": num_cpus},
+            intra_op_parallelism_threads=num_cpus,
+            inter_op_parallelism_threads=num_cpus,
+            allow_soft_placement=True
+        )
+        K.set_session(tf.Session(config=cpu_config))
 
     @classmethod
     def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools):
@@ -29,39 +39,43 @@
         """
         print("Start hyperparameter optimisation...")
         hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
-        best_params = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, class_weights)
-
-        # retrieve the model and train on complete dataset without validation set
-        model, best_params = utils.set_recurrent_network(best_params, reverse_dictionary, class_weights)
+        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, class_weights)
 
         # define callbacks
+        early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-4, restore_best_weights=True)
         predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, compatible_next_tools, usage_pred)
-        # tensor_board = callbacks.TensorBoard(log_dir=log_directory, histogram_freq=0, write_graph=True, write_images=True)
-        callbacks_list = [predict_callback_test]
+
+        callbacks_list = [predict_callback_test, early_stopping]
 
         print("Start training on the best model...")
-        model_fit = model.fit(
-            train_data,
-            train_labels,
-            batch_size=int(best_params["batch_size"]),
-            epochs=n_epochs,
-            verbose=2,
-            callbacks=callbacks_list,
-            shuffle="batch",
-            validation_data=(test_data, test_labels)
-        )
-
-        train_performance = {
-            "train_loss": np.array(model_fit.history["loss"]),
-            "model": model,
-            "best_parameters": best_params
-        }
-
-        # if there is test data, add more information
+        train_performance = dict()
         if len(test_data) > 0:
-            train_performance["validation_loss"] = np.array(model_fit.history["val_loss"])
+            trained_model = best_model.fit(
+                train_data,
+                train_labels,
+                batch_size=int(best_params["batch_size"]),
+                epochs=n_epochs,
+                verbose=2,
+                callbacks=callbacks_list,
+                shuffle="batch",
+                validation_data=(test_data, test_labels)
+            )
+            train_performance["validation_loss"] = np.array(trained_model.history["val_loss"])
             train_performance["precision"] = predict_callback_test.precision
             train_performance["usage_weights"] = predict_callback_test.usage_weights
+        else:
+            trained_model = best_model.fit(
+                train_data,
+                train_labels,
+                batch_size=int(best_params["batch_size"]),
+                epochs=n_epochs,
+                verbose=2,
+                callbacks=callbacks_list,
+                shuffle="batch"
+            )
+        train_performance["train_loss"] = np.array(trained_model.history["loss"])
+        train_performance["model"] = best_model
+        train_performance["best_parameters"] = best_params
         return train_performance
 
 
@@ -90,6 +104,7 @@
 
 if __name__ == "__main__":
     start_time = time.time()
+
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("-wf", "--workflow_file", required=True, help="workflows tabular file")
     arg_parser.add_argument("-tu", "--tool_usage_file", required=True, help="tool usage file")
@@ -112,6 +127,7 @@
     arg_parser.add_argument("-lr", "--learning_rate", required=True, help="learning rate")
     arg_parser.add_argument("-ar", "--activation_recurrent", required=True, help="activation function for recurrent layers")
     arg_parser.add_argument("-ao", "--activation_output", required=True, help="activation function for output layers")
+
     # get argument values
     args = vars(arg_parser.parse_args())
     tool_usage_path = args["tool_usage_file"]
@@ -133,6 +149,7 @@
     learning_rate = args["learning_rate"]
     activation_recurrent = args["activation_recurrent"]
     activation_output = args["activation_output"]
+    num_cpus = 16
 
     config = {
         'cutoff_date': cutoff_date,
@@ -161,7 +178,7 @@
     data = prepare_data.PrepareData(maximum_path_length, test_share)
     train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools)
     # find the best model and start training
-    predict_tool = PredictTool()
+    predict_tool = PredictTool(num_cpus)
     # start training with weighted classes
     print("Training with weighted classes and samples ...")
     results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, compatible_next_tools)