diff main.py @ 4:afec8c595124 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author bgruening
date Tue, 07 Jul 2020 03:25:49 -0400
parents 5b3c08710e47
children 4f7e6612906b
line wrap: on
line diff
--- a/main.py	Sat May 09 05:38:23 2020 -0400
+++ b/main.py	Tue Jul 07 03:25:49 2020 -0400
@@ -31,23 +31,22 @@
         )
         K.set_session(tf.Session(config=cpu_config))
 
-    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples):
+    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples):
         """
         Define recurrent neural network and train sequential data
         """
         # get tools with lowest representation
-        lowest_tool_ids = utils.get_lowest_tools(l_tool_freq)
+        lowest_tool_ids = utils.get_lowest_tools(tool_freq)
 
         print("Start hyperparameter optimisation...")
         hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
-        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights)
+        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights)
 
         # define callbacks
         early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True)
         predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids)
 
         callbacks_list = [predict_callback_test, early_stopping]
-
         batch_size = int(best_params["batch_size"])
 
         print("Start training on the best model...")
@@ -57,7 +56,8 @@
                 train_data,
                 train_labels,
                 batch_size,
-                l_tool_tr_samples
+                tool_tr_samples,
+                reverse_dictionary
             ),
             steps_per_epoch=len(train_data) // batch_size,
             epochs=n_epochs,
@@ -177,13 +177,12 @@
     # Process the paths from workflows
     print("Dividing data...")
     data = prepare_data.PrepareData(maximum_path_length, test_share)
-    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, l_tool_freq, l_tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections)
+    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections)
     # find the best model and start training
     predict_tool = PredictTool(num_cpus)
     # start training with weighted classes
     print("Training with weighted classes and samples ...")
-    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples)
+    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples)
     utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections)
     end_time = time.time()
-    print()
     print("Program finished in %s seconds" % str(end_time - start_time))