comparison prepare_data.py @ 4:afec8c595124 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 65d36f271296a38deeceb0d0e8d471b2898ee8f4"
author bgruening
date Tue, 07 Jul 2020 03:25:49 -0400
parents 5b3c08710e47
children 4f7e6612906b
comparison
equal deleted inserted replaced
3:5b3c08710e47 4:afec8c595124
8 import collections 8 import collections
9 import numpy as np 9 import numpy as np
10 import random 10 import random
11 11
12 import predict_tool_usage 12 import predict_tool_usage
13 import utils
14 13
15 main_path = os.getcwd() 14 main_path = os.getcwd()
16 15
17 16
18 class PrepareData: 17 class PrepareData:
209 """ 208 """
210 Get the frequency of last tool of each tool sequence 209 Get the frequency of last tool of each tool sequence
211 to estimate the frequency of tool sequences 210 to estimate the frequency of tool sequences
212 """ 211 """
213 last_tool_freq = dict() 212 last_tool_freq = dict()
214 inv_freq = dict() 213 freq_dict_names = dict()
215 for path in train_paths: 214 for path in train_paths:
216 last_tool = path.split(",")[-1] 215 last_tool = path.split(",")[-1]
217 if last_tool not in last_tool_freq: 216 if last_tool not in last_tool_freq:
218 last_tool_freq[last_tool] = 0 217 last_tool_freq[last_tool] = 0
218 freq_dict_names[reverse_dictionary[int(last_tool)]] = 0
219 last_tool_freq[last_tool] += 1 219 last_tool_freq[last_tool] += 1
220 max_freq = max(last_tool_freq.values()) 220 freq_dict_names[reverse_dictionary[int(last_tool)]] += 1
221 for t in last_tool_freq: 221 return last_tool_freq
222 inv_freq[t] = int(np.round(max_freq / float(last_tool_freq[t]), 0))
223 return last_tool_freq, inv_freq
224 222
225 def get_toolid_samples(self, train_data, l_tool_freq): 223 def get_toolid_samples(self, train_data, l_tool_freq):
226 l_tool_tr_samples = dict() 224 l_tool_tr_samples = dict()
227 for tool_id in l_tool_freq: 225 for tool_id in l_tool_freq:
228 for index, tr_sample in enumerate(train_data): 226 for index, tr_sample in enumerate(train_data):
252 multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools) 250 multilabels_paths = self.prepare_paths_labels_dictionary(dictionary, rev_dict, all_unique_paths, compatible_next_tools)
253 251
254 print("Complete data: %d" % len(multilabels_paths)) 252 print("Complete data: %d" % len(multilabels_paths))
255 train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths) 253 train_paths_dict, test_paths_dict = self.split_test_train_data(multilabels_paths)
256 254
257 # get sample frequency
258 l_tool_freq, inv_last_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
259
260 print("Train data: %d" % len(train_paths_dict)) 255 print("Train data: %d" % len(train_paths_dict))
261 print("Test data: %d" % len(test_paths_dict)) 256 print("Test data: %d" % len(test_paths_dict))
262 257
263 print("Padding train and test data...") 258 print("Padding train and test data...")
264 # pad training and test data with leading zeros 259 # pad training and test data with leading zeros
265 test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict) 260 test_data, test_labels = self.pad_paths(test_paths_dict, num_classes, standard_connections, rev_dict)
266 train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict) 261 train_data, train_labels = self.pad_paths(train_paths_dict, num_classes, standard_connections, rev_dict)
267 262
263 print("Estimating sample frequency...")
264 l_tool_freq = self.get_train_last_tool_freq(train_paths_dict, rev_dict)
268 l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq) 265 l_tool_tr_samples = self.get_toolid_samples(train_data, l_tool_freq)
269 266
270 # Predict tools usage 267 # Predict tools usage
271 print("Predicting tools' usage...") 268 print("Predicting tools' usage...")
272 usage_pred = predict_tool_usage.ToolPopularity() 269 usage_pred = predict_tool_usage.ToolPopularity()