comparison predict_tool_usage.py @ 0:9bf25dbe00ad draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 7fac577189d01cedd01118a77fc2baaefe7d5cad"
author bgruening
date Wed, 28 Aug 2019 07:19:38 -0400
parents
children 5b3c08710e47
comparison
equal deleted inserted replaced
-1:000000000000 0:9bf25dbe00ad
1 """
2 Predict tool usage to weigh the predicted tools
3 """
4
5 import os
6 import numpy as np
7 import warnings
8 import csv
9 import collections
10
11 from sklearn.svm import SVR
12 from sklearn.model_selection import GridSearchCV
13 from sklearn.pipeline import Pipeline
14
15 import utils
16
17 warnings.filterwarnings("ignore")
18
19 main_path = os.getcwd()
20
21
class ToolPopularity:

    @classmethod
    def __init__(self):
        """ Init method. """
        # NOTE(review): all methods (including __init__) are declared as
        # classmethods with a `self` first parameter — kept as-is because
        # callers may invoke them directly on the class.

    @classmethod
    def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary):
        """
        Extract the tool usage over time for each tool.

        :param tool_usage_file: path to a tab-separated file whose rows are
            (tool id, date string, usage count)
        :param cutoff_date: date string; only rows whose date compares
            lexicographically greater are kept (assumes ISO-formatted
            dates — TODO confirm against the data source)
        :param dictionary: mapping whose keys are the tool ids of interest
        :return: dict of {tool_id: OrderedDict {date: count}} sorted by
            date, with dates missing for a tool imputed as 0 so every tool
            covers the same set of dates
        """
        tool_usage_dict = dict()
        all_dates = list()
        all_tool_list = list(dictionary.keys())
        with open(tool_usage_file, 'rt') as usage_file:
            tool_usage = csv.reader(usage_file, delimiter='\t')
            for row in tool_usage:
                # row[1] is already a string; lexicographic comparison of
                # date strings stands in for a chronological comparison
                if row[1] > cutoff_date:
                    tool_id = utils.format_tool_id(row[0])
                    if tool_id in all_tool_list:
                        all_dates.append(row[1])
                        # merge the usage of different versions of a tool
                        # into one entry per (tool, date)
                        usage = tool_usage_dict.setdefault(tool_id, dict())
                        usage[row[1]] = usage.get(row[1], 0) + int(row[2])
        # impute usage for dates a tool has no record of with 0 so all
        # tools share the same timeline, then sort each series by date
        unique_dates = set(all_dates)
        for tool, usage in tool_usage_dict.items():
            for dt in unique_dates - set(usage.keys()):
                usage[dt] = 0
            tool_usage_dict[tool] = collections.OrderedDict(sorted(usage.items()))
        return tool_usage_dict

    @classmethod
    def learn_tool_popularity(self, x_reshaped, y_reshaped):
        """
        Fit a curve for the tool usage over time to predict future tool usage.

        A grid search over SVR kernels/degrees is fitted to the series and
        the next (unseen) time point is predicted.

        :param x_reshaped: (n, 1) array of time indices
        :param y_reshaped: (n, 1) array of usage counts
        :return: predicted usage at index x[-1] + 1, clamped to >= 0;
            falls back to 0.0 if fitting fails for any reason
        """
        epsilon = 0.0
        cv = 5
        s_typ = 'neg_mean_absolute_error'
        n_jobs = 4
        s_error = 1
        iid = True
        tr_score = False
        try:
            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
            param_grid = {
                'regressor__kernel': ['rbf', 'poly', 'linear'],
                'regressor__degree': [2, 3]
            }
            # NOTE(review): the `iid` parameter was deprecated in
            # scikit-learn 0.22 and removed in 0.24 — drop it if the
            # pinned scikit-learn version is ever upgraded.
            search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv,
                                  scoring=s_typ, n_jobs=n_jobs,
                                  error_score=s_error,
                                  return_train_score=tr_score)
            search.fit(x_reshaped, y_reshaped.ravel())
            model = search.best_estimator_
            # set the next time point to get a prediction for
            prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1))
            prediction = model.predict(prediction_point)
            # clamp negative predictions to zero (usage cannot be negative)
            return max(prediction[0], epsilon)
        except Exception:
            # best-effort: any failure in the grid search or fit falls
            # back to zero predicted usage rather than aborting the run
            return epsilon

    @classmethod
    def get_pupularity_prediction(self, tools_usage):
        """
        Get the popularity prediction for each tool.

        :param tools_usage: {tool_name: {date: count}} as produced by
            extract_tool_usage (dates pre-sorted)
        :return: {tool_name: predicted future usage, rounded to 8 decimals}
        """
        usage_prediction = dict()
        for tool_name, usage in tools_usage.items():
            # the time axis is simply the positional index of each date
            counts = list(usage.values())
            x_pos = np.arange(len(counts))
            x_reshaped = x_pos.reshape(len(x_pos), 1)
            y_reshaped = np.reshape(counts, (len(x_pos), 1))
            prediction = np.round(self.learn_tool_popularity(x_reshaped, y_reshaped), 8)
            usage_prediction[tool_name] = prediction
        return usage_prediction