comparison predict_tool_usage.py @ 5:4f7e6612906b draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 5eebc0cb44e71f581d548b7e842002705dd155eb"
author bgruening
date Fri, 06 May 2022 09:05:18 +0000
parents 5b3c08710e47
children e94dc7945639
comparison
equal deleted inserted replaced
4:afec8c595124 5:4f7e6612906b
1 """ 1 """
2 Predict tool usage to weigh the predicted tools 2 Predict tool usage to weigh the predicted tools
3 """ 3 """
4 4
5 import collections
6 import csv
5 import os 7 import os
8 import warnings
9
6 import numpy as np 10 import numpy as np
7 import warnings 11 import utils
8 import csv
9 import collections
10
11 from sklearn.svm import SVR
12 from sklearn.model_selection import GridSearchCV 12 from sklearn.model_selection import GridSearchCV
13 from sklearn.pipeline import Pipeline 13 from sklearn.pipeline import Pipeline
14 14 from sklearn.svm import SVR
15 import utils
16 15
17 warnings.filterwarnings("ignore") 16 warnings.filterwarnings("ignore")
18 17
19 main_path = os.getcwd() 18 main_path = os.getcwd()
20 19
21 20
22 class ToolPopularity: 21 class ToolPopularity:
23
24 def __init__(self): 22 def __init__(self):
25 """ Init method. """ 23 """ Init method. """
26 24
27 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): 25 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary):
28 """ 26 """
29 Extract the tool usage over time for each tool 27 Extract the tool usage over time for each tool
30 """ 28 """
31 tool_usage_dict = dict() 29 tool_usage_dict = dict()
32 all_dates = list() 30 all_dates = list()
33 all_tool_list = list(dictionary.keys()) 31 all_tool_list = list(dictionary.keys())
34 with open(tool_usage_file, 'rt') as usage_file: 32 with open(tool_usage_file, "rt") as usage_file:
35 tool_usage = csv.reader(usage_file, delimiter='\t') 33 tool_usage = csv.reader(usage_file, delimiter="\t")
36 for index, row in enumerate(tool_usage): 34 for index, row in enumerate(tool_usage):
37 if (str(row[1]) > cutoff_date) is True: 35 row = [item.strip() for item in row]
36 if (str(row[1]).strip() > cutoff_date) is True:
38 tool_id = utils.format_tool_id(row[0]) 37 tool_id = utils.format_tool_id(row[0])
39 if tool_id in all_tool_list: 38 if tool_id in all_tool_list:
40 all_dates.append(row[1]) 39 all_dates.append(row[1])
41 if tool_id not in tool_usage_dict: 40 if tool_id not in tool_usage_dict:
42 tool_usage_dict[tool_id] = dict() 41 tool_usage_dict[tool_id] = dict()
65 """ 64 """
66 Fit a curve for the tool usage over time to predict future tool usage 65 Fit a curve for the tool usage over time to predict future tool usage
67 """ 66 """
68 epsilon = 0.0 67 epsilon = 0.0
69 cv = 5 68 cv = 5
70 s_typ = 'neg_mean_absolute_error' 69 s_typ = "neg_mean_absolute_error"
71 n_jobs = 4 70 n_jobs = 4
72 s_error = 1 71 s_error = 1
73 iid = True
74 tr_score = False 72 tr_score = False
75 try: 73 try:
76 pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))]) 74 pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
77 param_grid = { 75 param_grid = {
78 'regressor__kernel': ['rbf', 'poly', 'linear'], 76 "regressor__kernel": ["rbf", "poly", "linear"],
79 'regressor__degree': [2, 3] 77 "regressor__degree": [2, 3],
80 } 78 }
81 search = GridSearchCV(pipe, param_grid, iid=iid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score) 79 search = GridSearchCV(
80 pipe,
81 param_grid,
82 cv=cv,
83 scoring=s_typ,
84 n_jobs=n_jobs,
85 error_score=s_error,
86 return_train_score=tr_score,
87 )
82 search.fit(x_reshaped, y_reshaped.ravel()) 88 search.fit(x_reshaped, y_reshaped.ravel())
83 model = search.best_estimator_ 89 model = search.best_estimator_
84 # set the next time point to get prediction for 90 # set the next time point to get prediction for
85 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) 91 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1))
86 prediction = model.predict(prediction_point) 92 prediction = model.predict(prediction_point)
87 if prediction < epsilon: 93 if prediction < epsilon:
88 prediction = [epsilon] 94 prediction = [epsilon]
89 return prediction[0] 95 return prediction[0]
90 except Exception: 96 except Exception as e:
97 print(e)
91 return epsilon 98 return epsilon
92 99
93 def get_pupularity_prediction(self, tools_usage): 100 def get_pupularity_prediction(self, tools_usage):
94 """ 101 """
95 Get the popularity prediction for each tool 102 Get the popularity prediction for each tool