comparison main.py @ 4:afec8c595124 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author bgruening
date Tue, 07 Jul 2020 03:25:49 -0400
parents 5b3c08710e47
children 4f7e6612906b
--- a/main.py	3:5b3c08710e47
+++ b/main.py	4:afec8c595124
@@ -29,37 +29,37 @@
             inter_op_parallelism_threads=num_cpus,
             allow_soft_placement=True
         )
         K.set_session(tf.Session(config=cpu_config))
 
-    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples):
+    def find_train_best_network(self, network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, tool_freq, tool_tr_samples):
         """
         Define recurrent neural network and train sequential data
         """
         # get tools with lowest representation
-        lowest_tool_ids = utils.get_lowest_tools(l_tool_freq)
+        lowest_tool_ids = utils.get_lowest_tools(tool_freq)
 
         print("Start hyperparameter optimisation...")
         hyper_opt = optimise_hyperparameters.HyperparameterOptimisation()
-        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, l_tool_tr_samples, class_weights)
+        best_params, best_model = hyper_opt.train_model(network_config, reverse_dictionary, train_data, train_labels, test_data, test_labels, tool_tr_samples, class_weights)
 
         # define callbacks
         early_stopping = callbacks.EarlyStopping(monitor='loss', mode='min', verbose=1, min_delta=1e-1, restore_best_weights=True)
         predict_callback_test = PredictCallback(test_data, test_labels, reverse_dictionary, n_epochs, usage_pred, standard_connections, lowest_tool_ids)
 
         callbacks_list = [predict_callback_test, early_stopping]
-
         batch_size = int(best_params["batch_size"])
 
         print("Start training on the best model...")
         train_performance = dict()
         trained_model = best_model.fit_generator(
             utils.balanced_sample_generator(
                 train_data,
                 train_labels,
                 batch_size,
-                l_tool_tr_samples
+                tool_tr_samples,
+                reverse_dictionary
             ),
             steps_per_epoch=len(train_data) // batch_size,
             epochs=n_epochs,
             callbacks=callbacks_list,
             validation_data=(test_data, test_labels),
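
The hunk above renames l_tool_tr_samples to tool_tr_samples and passes reverse_dictionary through to utils.balanced_sample_generator. A minimal sketch of such a generator, assuming tool_tr_samples maps each tool id to the training-row indices whose labels contain that tool (the real implementation lives in utils.py of this repository; reverse_dictionary, presumably an id-to-name mapping, is accepted here only to mirror the new signature):

import numpy as np

def balanced_sample_generator(train_data, train_labels, batch_size, tool_tr_samples, reverse_dictionary):
    # Oversample by tool: pick a tool uniformly at random, then a random
    # training row containing it, so rarely used tools appear in batches
    # about as often as frequent ones.
    tool_ids = list(tool_tr_samples.keys())
    while True:
        chosen = np.random.choice(tool_ids, batch_size)
        rows = [np.random.choice(tool_tr_samples[tool_id]) for tool_id in chosen]
        yield train_data[rows], train_labels[rows]

Each yielded pair feeds one step of fit_generator, which is why steps_per_epoch is set explicitly to len(train_data) // batch_size above.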
@@ -175,15 +175,14 @@
     connections = extract_workflow_connections.ExtractWorkflowConnections()
     workflow_paths, compatible_next_tools, standard_connections = connections.read_tabular_file(workflows_path)
     # Process the paths from workflows
     print("Dividing data...")
     data = prepare_data.PrepareData(maximum_path_length, test_share)
-    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, l_tool_freq, l_tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections)
+    train_data, train_labels, test_data, test_labels, data_dictionary, reverse_dictionary, class_weights, usage_pred, train_tool_freq, tool_tr_samples = data.get_data_labels_matrices(workflow_paths, tool_usage_path, cutoff_date, compatible_next_tools, standard_connections)
     # find the best model and start training
     predict_tool = PredictTool(num_cpus)
     # start training with weighted classes
     print("Training with weighted classes and samples ...")
-    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, l_tool_freq, l_tool_tr_samples)
+    results_weighted = predict_tool.find_train_best_network(config, reverse_dictionary, train_data, train_labels, test_data, test_labels, n_epochs, class_weights, usage_pred, standard_connections, train_tool_freq, tool_tr_samples)
     utils.save_model(results_weighted, data_dictionary, compatible_next_tools, trained_model_path, class_weights, standard_connections)
     end_time = time.time()
-    print()
     print("Program finished in %s seconds" % str(end_time - start_time))
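
For the renamed return values in this hunk: train_tool_freq is assumed to hold per-tool usage counts over the training split, which utils.get_lowest_tools uses to pick the least-represented tools for PredictCallback, while tool_tr_samples indexes training rows by tool. A hypothetical helper (not part of the diff) showing how such an index could be built from a multi-hot label matrix:

import numpy as np

def index_samples_by_tool(train_labels):
    # Hypothetical: map each tool id (label-matrix column, stored as a
    # string key) to the rows whose label vector contains that tool.
    tool_tr_samples = {}
    for row, label_vector in enumerate(train_labels):
        for tool_id in np.where(label_vector > 0)[0]:
            tool_tr_samples.setdefault(str(tool_id), []).append(row)
    return tool_tr_samples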