comparison utils.py @ 6:e94dc7945639 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author bgruening
date Sun, 16 Oct 2022 11:52:10 +0000
parents 4f7e6612906b
children
comparison: 5:4f7e6612906b -> 6:e94dc7945639
 import json
+import os
 import random
 
 import h5py
 import numpy as np
+import pandas as pd
 import tensorflow as tf
-from numpy.random import choice
-from tensorflow.keras import backend
+
+binary_ce = tf.keras.losses.BinaryCrossentropy()
+binary_acc = tf.keras.metrics.BinaryAccuracy()
+categorical_ce = tf.keras.metrics.CategoricalCrossentropy(from_logits=True)
 
 
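Aside (not part of the diff): the three new module-level objects are ordinary Keras loss/metric instances. binary_ce keeps the default from_logits=False, so it expects probabilities, while categorical_ce is configured with from_logits=True. A minimal sketch of how they behave on toy tensors, assuming only TensorFlow 2.x:

    import tensorflow as tf

    binary_ce = tf.keras.losses.BinaryCrossentropy()   # expects probabilities
    binary_acc = tf.keras.metrics.BinaryAccuracy()
    y_true = tf.constant([[0.0, 1.0, 0.0]])
    y_pred = tf.constant([[0.1, 0.8, 0.2]])
    print(float(binary_ce(y_true, y_pred)))   # mean binary cross-entropy over the labels
    print(float(binary_acc(y_true, y_pred)))  # fraction of labels correct at a 0.5 threshold
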
 def read_file(file_path):
     """
     Read a file
     """
     with open(file_path, "r") as json_file:
         file_content = json.loads(json_file.read())
     return file_content
 
 
+def write_file(file_path, content):
+    """
+    Write a file
+    """
+    remove_file(file_path)
+    with open(file_path, "w") as json_file:
+        json_file.write(json.dumps(content))
+
+
+def save_h5_data(inp, tar, filename):
+    hf_file = h5py.File(filename, 'w')
+    hf_file.create_dataset("input", data=inp)
+    hf_file.create_dataset("target", data=tar)
+    hf_file.close()
+
+
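A small round-trip sketch (not from the diff): save_h5_data writes the encoded input/target matrices and read_train_test, added further down in this file, loads them back. It assumes utils.py is importable as `utils` and uses a made-up file name:

    import numpy as np
    from utils import read_train_test, save_h5_data

    inp = np.zeros((4, 25))    # 4 padded tool-id sequences
    tar = np.zeros((4, 100))   # 4 multi-hot target vectors
    save_h5_data(inp, tar, "train_test_data.h5")   # hypothetical file name
    x, y = read_train_test("train_test_data.h5")
    assert x.shape == (4, 25) and y.shape == (4, 100)
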
+def get_low_freq_te_samples(te_data, te_target, tr_freq_dict):
+    lowest_tool_te_ids = list()
+    lowest_t_ids = get_lowest_tools(tr_freq_dict)
+    for i, te_labels in enumerate(te_target):
+        tools_pos = np.where(te_labels > 0)[0]
+        tools_pos = [str(int(item)) for item in tools_pos]
+        intersection = list(set(tools_pos).intersection(set(lowest_t_ids)))
+        if len(intersection) > 0:
+            lowest_tool_te_ids.append(i)
+            lowest_t_ids = [item for item in lowest_t_ids if item not in intersection]
+    return lowest_tool_te_ids
+
+
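A worked toy example (made-up frequencies, not from the diff): get_lowest_tools keeps the last 25% of tool ids after sorting by decreasing training frequency, and get_low_freq_te_samples then returns the indices of test samples whose targets hit one of those rare tools:

    import numpy as np
    from utils import get_low_freq_te_samples   # assuming utils.py is importable

    tr_freq_dict = {"1": 50, "2": 40, "3": 5, "4": 2}   # tool id -> training frequency
    te_target = np.array([
        [0, 1, 0, 0, 0],   # sample 0 targets tool 1 (frequent)
        [0, 0, 0, 1, 0],   # sample 1 targets tool 3
        [0, 0, 0, 0, 1],   # sample 2 targets tool 4 (rarest quarter)
    ])
    # te_data is not used by this version of the function
    print(get_low_freq_te_samples(None, te_target, tr_freq_dict))   # [2]
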
+def save_processed_workflows(file_path, unique_paths):
+    workflow_paths_unique = ""
+    for path in unique_paths:
+        workflow_paths_unique += path + "\n"
+    with open(file_path, "w") as workflows_file:
+        workflows_file.write(workflow_paths_unique)
+
+
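Usage note (an assumption, not from the diff): save_processed_workflows simply writes one workflow path per line; the strings and file name below are illustrative only:

    from utils import save_processed_workflows   # assuming utils.py is importable

    save_processed_workflows("workflow-connections-processed.txt",
                             ["tool_a,tool_b,tool_c", "tool_a,tool_d"])
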
 def format_tool_id(tool_link):
     """
     Extract tool id from tool link
     """
     tool_id_split = tool_link.split("/")
     tool_id = tool_id_split[-2] if len(tool_id_split) > 1 else tool_link
     return tool_id
 
 
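Example (illustrative tool link, not from the diff): for a full Galaxy tool id the second-to-last path segment is the short tool name, while ids without a "/" are returned unchanged:

    from utils import format_tool_id   # assuming utils.py is importable

    print(format_tool_id("toolshed.g2.bx.psu.edu/repos/iuc/featurecounts/featurecounts/2.0.3"))
    # featurecounts
    print(format_tool_id("upload1"))
    # upload1
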
-def set_trained_model(dump_file, model_values):
-    """
-    Create an h5 file with the trained weights and associated dicts
-    """
-    hf_file = h5py.File(dump_file, "w")
-    for key in model_values:
-        value = model_values[key]
-        if key == "model_weights":
-            for idx, item in enumerate(value):
-                w_key = "weight_" + str(idx)
-                if w_key in hf_file:
-                    hf_file.modify(w_key, item)
-                else:
-                    hf_file.create_dataset(w_key, data=item)
-        else:
-            if key in hf_file:
-                hf_file.modify(key, json.dumps(value))
-            else:
-                hf_file.create_dataset(key, data=json.dumps(value))
-    hf_file.close()
-
-
+def save_model_file(model, r_dict, c_wts, c_tools, s_conn, model_file):
+    model.save_weights(model_file, save_format="h5")
+    hf_file = h5py.File(model_file, 'r+')
+    model_values = {
+        "reverse_dict": r_dict,
+        "class_weights": c_wts,
+        "compatible_tools": c_tools,
+        "standard_connections": s_conn
+    }
+    for k in model_values:
+        hf_file.create_dataset(k, data=json.dumps(model_values[k]))
+    hf_file.close()
+
+
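For orientation (a sketch under assumptions, not project code): save_model_file writes the Keras weights in HDF5 format and then reopens the same file to append the dictionaries as JSON-encoded string datasets, so they can be read back independently of the model:

    import json

    import h5py

    # "tool_recommendation_model.h5" is a hypothetical file name
    with h5py.File("tool_recommendation_model.h5", "r") as hf:
        class_weights = json.loads(hf["class_weights"][()])
        standard_connections = json.loads(hf["standard_connections"][()])
    # model.load_weights("tool_recommendation_model.h5") would restore the weights
    # on a model built with the same architecture code.
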
+def remove_file(file_path):
+    if os.path.exists(file_path):
+        os.remove(file_path)
+
+
+def verify_oversampling_freq(oversampled_tr_data, rev_dict):
+    """
+    Compute the frequency of tool sequences after oversampling
+    """
+    freq_dict = dict()
+    freq_dict_names = dict()
+    for tr_data in oversampled_tr_data:
+        t_pos = np.where(tr_data > 0)[0]
+        last_tool_id = str(int(tr_data[t_pos[-1]]))
+        if last_tool_id not in freq_dict:
+            freq_dict[last_tool_id] = 0
+            freq_dict_names[rev_dict[int(last_tool_id)]] = 0
+        freq_dict[last_tool_id] += 1
+        freq_dict_names[rev_dict[int(last_tool_id)]] += 1
+    s_freq = dict(sorted(freq_dict_names.items(), key=lambda kv: kv[1], reverse=True))
+    return s_freq
+
+
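A tiny toy example (made-up reverse dictionary, not from the diff): the function counts how often each tool appears as the last non-zero entry of the oversampled, padded sequences and returns the counts sorted by decreasing frequency:

    import numpy as np
    from utils import verify_oversampling_freq   # assuming utils.py is importable

    rev_dict = {1: "toolA", 2: "toolB"}
    oversampled = np.array([
        [0, 1, 2],   # last tool: 2 -> toolB
        [0, 0, 2],   # last tool: 2 -> toolB
        [1, 0, 0],   # last tool: 1 -> toolA
    ])
    print(verify_oversampling_freq(oversampled, rev_dict))   # {'toolB': 2, 'toolA': 1}
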
+def collect_sampled_tool_freq(collected_dict, c_freq):
+    for t in c_freq:
+        if t not in collected_dict:
+            collected_dict[t] = int(c_freq[t])
+        else:
+            collected_dict[t] += int(c_freq[t])
+    return collected_dict
+
+
+def save_data_as_dict(f_dict, r_dict, inp, tar, save_path):
+    inp_tar = dict()
+    for index, (i, t) in enumerate(zip(inp, tar)):
+        i_pos = np.where(i > 0)[0]
+        i_seq = ",".join([str(int(item)) for item in i[1:i_pos[-1] + 1]])
+        t_pos = np.where(t > 0)[0]
+        t_seq = ",".join([str(int(item)) for item in t[1:t_pos[-1] + 1]])
+        if i_seq not in inp_tar:
+            inp_tar[i_seq] = list()
+        inp_tar[i_seq].append(t_seq)
+    size = 0
+    for item in inp_tar:
+        size += len(inp_tar[item])
+    print("Size saved file: ", size)
+    write_file(save_path, inp_tar)
+
+
+def read_train_test(datapath):
+    file_obj = h5py.File(datapath, 'r')
+    data_input = np.array(file_obj["input"])
+    data_target = np.array(file_obj["target"])
+    return data_input, data_target
+
+
-def weighted_loss(class_weights):
-    """
-    Create a weighted loss function. Penalise the misclassification
-    of classes more with the higher usage
-    """
-    weight_values = list(class_weights.values())
-    weight_values.extend(weight_values)
-
-    def weighted_binary_crossentropy(y_true, y_pred):
-        # add another dimension to compute dot product
-        expanded_weights = tf.expand_dims(weight_values, axis=-1)
-        bce = backend.binary_crossentropy(y_true, y_pred)
-        return backend.dot(bce, expanded_weights)
-
-    return weighted_binary_crossentropy
-
-
-def balanced_sample_generator(
-    train_data, train_labels, batch_size, l_tool_tr_samples, reverse_dictionary
-):
-    while True:
-        dimension = train_data.shape[1]
-        n_classes = train_labels.shape[1]
-        tool_ids = list(l_tool_tr_samples.keys())
-        random.shuffle(tool_ids)
-        generator_batch_data = np.zeros([batch_size, dimension])
-        generator_batch_labels = np.zeros([batch_size, n_classes])
-        generated_tool_ids = choice(tool_ids, batch_size)
-        for i in range(batch_size):
-            random_toolid = generated_tool_ids[i]
-            sample_indices = l_tool_tr_samples[str(random_toolid)]
-            random_index = random.sample(range(0, len(sample_indices)), 1)[0]
-            random_tr_index = sample_indices[random_index]
-            generator_batch_data[i] = train_data[random_tr_index]
-            generator_batch_labels[i] = train_labels[random_tr_index]
-        yield generator_batch_data, generator_batch_labels
-
-
+def sample_balanced_tr_y(x_seqs, y_labels, ulabels_tr_y_dict, b_size, tr_t_freq, prev_sel_tools):
+    batch_y_tools = list(ulabels_tr_y_dict.keys())
+    random.shuffle(batch_y_tools)
+    label_tools = list()
+    rand_batch_indices = list()
+    sel_tools = list()
+
+    unselected_tools = [t for t in batch_y_tools if t not in prev_sel_tools]
+    rand_selected_tools = unselected_tools[:b_size]
+
+    for l_tool in rand_selected_tools:
+        seq_indices = ulabels_tr_y_dict[l_tool]
+        random.shuffle(seq_indices)
+        rand_s_index = np.random.randint(0, len(seq_indices), 1)[0]
+        rand_sample = seq_indices[rand_s_index]
+        sel_tools.append(l_tool)
+        rand_batch_indices.append(rand_sample)
+        label_tools.append(l_tool)
+
+    x_batch_train = x_seqs[rand_batch_indices]
+    y_batch_train = y_labels[rand_batch_indices]
+
+    unrolled_x = tf.convert_to_tensor(x_batch_train, dtype=tf.int64)
+    unrolled_y = tf.convert_to_tensor(y_batch_train, dtype=tf.int64)
+    return unrolled_x, unrolled_y, sel_tools
+
+
+def sample_balanced_te_y(x_seqs, y_labels, ulabels_tr_y_dict, b_size):
+    batch_y_tools = list(ulabels_tr_y_dict.keys())
+    random.shuffle(batch_y_tools)
+    label_tools = list()
+    rand_batch_indices = list()
+    sel_tools = list()
+    for l_tool in batch_y_tools:
+        seq_indices = ulabels_tr_y_dict[l_tool]
+        random.shuffle(seq_indices)
+        rand_s_index = np.random.randint(0, len(seq_indices), 1)[0]
+        rand_sample = seq_indices[rand_s_index]
+        sel_tools.append(l_tool)
+        if rand_sample not in rand_batch_indices:
+            rand_batch_indices.append(rand_sample)
+            label_tools.append(l_tool)
+        if len(rand_batch_indices) == b_size:
+            break
+    x_batch_train = x_seqs[rand_batch_indices]
+    y_batch_train = y_labels[rand_batch_indices]
+
+    unrolled_x = tf.convert_to_tensor(x_batch_train, dtype=tf.int64)
+    unrolled_y = tf.convert_to_tensor(y_batch_train, dtype=tf.int64)
+    return unrolled_x, unrolled_y, sel_tools
+
+
-def compute_precision(
-    model,
-    x,
-    y,
-    reverse_data_dictionary,
-    usage_scores,
-    actual_classes_pos,
-    topk,
-    standard_conn,
-    last_tool_id,
-    lowest_tool_ids,
-):
-    """
-    Compute absolute and compatible precision
-    """
-    pred_t_name = ""
-    top_precision = 0.0
-    mean_usage = 0.0
-    usage_wt_score = list()
-    pub_precision = 0.0
-    lowest_pub_prec = 0.0
-    lowest_norm_prec = 0.0
-    pub_tools = list()
-    actual_next_tool_names = list()
-    test_sample = np.reshape(x, (1, len(x)))
-
-    # predict next tools for a test path
-    prediction = model.predict(test_sample, verbose=0)
-
-    # divide the predicted vector into two halves - one for published and
-    # another for normal workflows
-    nw_dimension = prediction.shape[1]
-    half_len = int(nw_dimension / 2)
-
-    # predict tools
-    prediction = np.reshape(prediction, (nw_dimension,))
-    # get predictions of tools from published workflows
-    standard_pred = prediction[:half_len]
-    # get predictions of tools from normal workflows
-    normal_pred = prediction[half_len:]
-
-    standard_prediction_pos = np.argsort(standard_pred, axis=-1)
-    standard_topk_prediction_pos = standard_prediction_pos[-topk]
-
-    normal_prediction_pos = np.argsort(normal_pred, axis=-1)
-    normal_topk_prediction_pos = normal_prediction_pos[-topk]
-
-    # get true tools names
-    for a_t_pos in actual_classes_pos:
-        if a_t_pos > half_len:
-            t_name = reverse_data_dictionary[int(a_t_pos - half_len)]
-        else:
-            t_name = reverse_data_dictionary[int(a_t_pos)]
-        actual_next_tool_names.append(t_name)
-    last_tool_name = reverse_data_dictionary[x[-1]]
-    # compute scores for published recommendations
-    if standard_topk_prediction_pos in reverse_data_dictionary:
-        pred_t_name = reverse_data_dictionary[int(standard_topk_prediction_pos)]
-        if last_tool_name in standard_conn:
-            pub_tools = standard_conn[last_tool_name]
-            if pred_t_name in pub_tools:
-                pub_precision = 1.0
-                # count precision only when there is actually true published tools
-                if last_tool_id in lowest_tool_ids:
-                    lowest_pub_prec = 1.0
-                else:
-                    lowest_pub_prec = np.nan
-                if standard_topk_prediction_pos in usage_scores:
-                    usage_wt_score.append(
-                        np.log(usage_scores[standard_topk_prediction_pos] + 1.0)
-                    )
-        else:
-            # count precision only when there is actually true published tools
-            # else set to np.nan. Set to 0 only when there is wrong prediction
-            pub_precision = np.nan
-            lowest_pub_prec = np.nan
-    # compute scores for normal recommendations
-    if normal_topk_prediction_pos in reverse_data_dictionary:
-        pred_t_name = reverse_data_dictionary[int(normal_topk_prediction_pos)]
-        if pred_t_name in actual_next_tool_names:
-            if normal_topk_prediction_pos in usage_scores:
-                usage_wt_score.append(
-                    np.log(usage_scores[normal_topk_prediction_pos] + 1.0)
-                )
-            top_precision = 1.0
-            if last_tool_id in lowest_tool_ids:
-                lowest_norm_prec = 1.0
-            else:
-                lowest_norm_prec = np.nan
-    if len(usage_wt_score) > 0:
-        mean_usage = np.mean(usage_wt_score)
-    return mean_usage, top_precision, pub_precision, lowest_pub_prec, lowest_norm_prec
+def get_u_tr_labels(y_tr):
+    labels = list()
+    labels_pos_dict = dict()
+    for i, item in enumerate(y_tr):
+        label_pos = np.where(item > 0)[0]
+        labels.extend(label_pos)
+        for label in label_pos:
+            if label not in labels_pos_dict:
+                labels_pos_dict[label] = list()
+            labels_pos_dict[label].append(i)
+    u_labels = list(set(labels))
+    for item in labels_pos_dict:
+        labels_pos_dict[item] = list(set(labels_pos_dict[item]))
+    return u_labels, labels_pos_dict
+
+
+def compute_loss(y_true, y_pred, class_weights=None):
+    y_true = tf.cast(y_true, dtype=tf.float32)
+    loss = binary_ce(y_true, y_pred)
+    categorical_loss = categorical_ce(y_true, y_pred)
+    if class_weights is None:
+        return tf.reduce_mean(loss), categorical_loss
+    return tf.tensordot(loss, class_weights, axes=1), categorical_loss
+
+
+def compute_acc(y_true, y_pred):
+    return binary_acc(y_true, y_pred)
+
+
+def validate_model(te_x, te_y, te_batch_size, model, f_dict, r_dict, ulabels_te_dict, tr_labels, lowest_t_ids):
+    te_x_batch, y_train_batch, _ = sample_balanced_te_y(te_x, te_y, ulabels_te_dict, te_batch_size)
+    print("Total test data size: ", te_x.shape, te_y.shape)
+    print("Batch test data size: ", te_x_batch.shape, y_train_batch.shape)
+    te_pred_batch, _ = model(te_x_batch, training=False)
+    test_err, _ = compute_loss(y_train_batch, te_pred_batch)
+    print("Test loss:")
+    print(test_err.numpy())
+    print("Test finished")
 
 
 def get_lowest_tools(l_tool_freq, fraction=0.25):
     l_tool_freq = dict(sorted(l_tool_freq.items(), key=lambda kv: kv[1], reverse=True))
     tool_ids = list(l_tool_freq.keys())
     lowest_ids = tool_ids[-int(len(tool_ids) * fraction):]
     return lowest_ids
 
 
-def verify_model(
-    model,
-    x,
-    y,
-    reverse_data_dictionary,
-    usage_scores,
-    standard_conn,
-    lowest_tool_ids,
-    topk_list=[1, 2, 3],
-):
-    """
-    Verify the model on test data
-    """
-    print("Evaluating performance on test data...")
-    print("Test data size: %d" % len(y))
-    size = y.shape[0]
-    precision = np.zeros([len(y), len(topk_list)])
-    usage_weights = np.zeros([len(y), len(topk_list)])
-    epo_pub_prec = np.zeros([len(y), len(topk_list)])
-    epo_lowest_tools_pub_prec = list()
-    epo_lowest_tools_norm_prec = list()
-    lowest_counter = 0
-    # loop over all the test samples and find prediction precision
-    for i in range(size):
-        lowest_pub_topk = list()
-        lowest_norm_topk = list()
-        actual_classes_pos = np.where(y[i] > 0)[0]
-        test_sample = x[i, :]
-        last_tool_id = str(int(test_sample[-1]))
-        for index, abs_topk in enumerate(topk_list):
-            (
-                usg_wt_score,
-                absolute_precision,
-                pub_prec,
-                lowest_p_prec,
-                lowest_n_prec,
-            ) = compute_precision(
-                model,
-                test_sample,
-                y,
-                reverse_data_dictionary,
-                usage_scores,
-                actual_classes_pos,
-                abs_topk,
-                standard_conn,
-                last_tool_id,
-                lowest_tool_ids,
-            )
-            precision[i][index] = absolute_precision
-            usage_weights[i][index] = usg_wt_score
-            epo_pub_prec[i][index] = pub_prec
-            lowest_pub_topk.append(lowest_p_prec)
-            lowest_norm_topk.append(lowest_n_prec)
-        epo_lowest_tools_pub_prec.append(lowest_pub_topk)
-        epo_lowest_tools_norm_prec.append(lowest_norm_topk)
-        if last_tool_id in lowest_tool_ids:
-            lowest_counter += 1
-    mean_precision = np.mean(precision, axis=0)
-    mean_usage = np.mean(usage_weights, axis=0)
-    mean_pub_prec = np.nanmean(epo_pub_prec, axis=0)
-    mean_lowest_pub_prec = np.nanmean(epo_lowest_tools_pub_prec, axis=0)
-    mean_lowest_norm_prec = np.nanmean(epo_lowest_tools_norm_prec, axis=0)
-    return (
-        mean_usage,
-        mean_precision,
-        mean_pub_prec,
-        mean_lowest_pub_prec,
-        mean_lowest_norm_prec,
-        lowest_counter,
-    )
-
-
-def save_model(
-    results,
-    data_dictionary,
-    compatible_next_tools,
-    trained_model_path,
-    class_weights,
-    standard_connections,
-):
-    # save files
-    trained_model = results["model"]
-    best_model_parameters = results["best_parameters"]
-    model_config = trained_model.to_json()
-    model_weights = trained_model.get_weights()
-    model_values = {
-        "data_dictionary": data_dictionary,
-        "model_config": model_config,
-        "best_parameters": best_model_parameters,
-        "model_weights": model_weights,
-        "compatible_tools": compatible_next_tools,
-        "class_weights": class_weights,
-        "standard_connections": standard_connections,
-    }
-    set_trained_model(trained_model_path, model_values)
+def remove_pipe(file_path):
+    dataframe = pd.read_csv(file_path, sep="|", header=None)
+    dataframe = dataframe[1:len(dataframe.index) - 1]
+    return dataframe[1:]
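A final sketch (assumptions only, not the project's training script) of how the new helpers could fit together in a custom training step. The real model and training loop live in other files, and judging by validate_model the real model returns a tuple of outputs; a throwaway Keras network stands in for it here, and all data, sizes and names are illustrative:

    import numpy as np
    import tensorflow as tf
    from utils import compute_acc, compute_loss, get_u_tr_labels, sample_balanced_tr_y

    # Toy data: 3 padded tool-id sequences and their multi-hot next-tool targets
    x_tr = np.array([[0, 1, 2, 0], [0, 2, 3, 0], [0, 1, 3, 4]])
    y_tr = np.array([[0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]])
    _, ulabels_tr_y_dict = get_u_tr_labels(y_tr)

    # Stand-in model: embedding + pooling + sigmoid multi-label head
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(5, 8),
        tf.keras.layers.GlobalAveragePooling1D(),
        tf.keras.layers.Dense(5, activation="sigmoid"),
    ])
    optimizer = tf.keras.optimizers.Adam()

    prev_sel_tools = list()
    for step in range(4):
        if len(prev_sel_tools) >= len(ulabels_tr_y_dict):
            prev_sel_tools = list()   # start a new pass once every label has been sampled
        # tr_t_freq is accepted but unused by this version of sample_balanced_tr_y
        bx, by, sel_tools = sample_balanced_tr_y(x_tr, y_tr, ulabels_tr_y_dict, 2, None, prev_sel_tools)
        prev_sel_tools.extend(sel_tools)
        with tf.GradientTape() as tape:
            pred = model(bx, training=True)
            loss, _ = compute_loss(by, pred)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        print(step, float(loss), float(compute_acc(tf.cast(by, tf.float32), pred)))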