comparison predict_tool_usage.py @ 6:e94dc7945639 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author bgruening
date Sun, 16 Oct 2022 11:52:10 +0000
parents 4f7e6612906b
children
comparison
equal deleted inserted replaced
5:4f7e6612906b 6:e94dc7945639
1 """ 1 """
2 Predict tool usage to weigh the predicted tools 2 Predict tool usage to weigh the predicted tools
3 """ 3 """
4 4
5 import collections 5 import collections
6 import csv
7 import os
8 import warnings
9 6
10 import numpy as np 7 import numpy as np
11 import utils 8 import utils
12 from sklearn.model_selection import GridSearchCV 9 from sklearn.model_selection import GridSearchCV
13 from sklearn.pipeline import Pipeline 10 from sklearn.pipeline import Pipeline
14 from sklearn.svm import SVR 11 from sklearn.svm import SVR
15 12
16 warnings.filterwarnings("ignore")
17
18 main_path = os.getcwd()
19
20 13
21 class ToolPopularity: 14 class ToolPopularity:
15
22 def __init__(self): 16 def __init__(self):
23 """ Init method. """ 17 """ Init method. """
24 18
25 def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary): 19 def extract_tool_usage(self, tool_usage_df, cutoff_date, dictionary):
26 """ 20 """
27 Extract the tool usage over time for each tool 21 Extract the tool usage over time for each tool
28 """ 22 """
29 tool_usage_dict = dict() 23 tool_usage_dict = dict()
30 all_dates = list() 24 all_dates = list()
31 all_tool_list = list(dictionary.keys()) 25 all_tool_list = list(dictionary.keys())
32 with open(tool_usage_file, "rt") as usage_file: 26 for index, row in tool_usage_df.iterrows():
33 tool_usage = csv.reader(usage_file, delimiter="\t") 27 row = row.tolist()
34 for index, row in enumerate(tool_usage): 28 row = [str(item).strip() for item in row]
35 row = [item.strip() for item in row] 29 if (row[1] > cutoff_date) is True:
36 if (str(row[1]).strip() > cutoff_date) is True: 30 tool_id = utils.format_tool_id(row[0])
37 tool_id = utils.format_tool_id(row[0]) 31 if tool_id in all_tool_list:
38 if tool_id in all_tool_list: 32 all_dates.append(row[1])
39 all_dates.append(row[1]) 33 if tool_id not in tool_usage_dict:
40 if tool_id not in tool_usage_dict: 34 tool_usage_dict[tool_id] = dict()
41 tool_usage_dict[tool_id] = dict() 35 tool_usage_dict[tool_id][row[1]] = int(float(row[2]))
42 tool_usage_dict[tool_id][row[1]] = int(row[2]) 36 else:
37 curr_date = row[1]
38 # merge the usage of different version of tools into one
39 if curr_date in tool_usage_dict[tool_id]:
40 tool_usage_dict[tool_id][curr_date] += int(float(row[2]))
43 else: 41 else:
44 curr_date = row[1] 42 tool_usage_dict[tool_id][curr_date] = int(float(row[2]))
45 # merge the usage of different version of tools into one
46 if curr_date in tool_usage_dict[tool_id]:
47 tool_usage_dict[tool_id][curr_date] += int(row[2])
48 else:
49 tool_usage_dict[tool_id][curr_date] = int(row[2])
50 # get unique dates 43 # get unique dates
51 unique_dates = list(set(all_dates)) 44 unique_dates = list(set(all_dates))
52 for tool in tool_usage_dict: 45 for tool in tool_usage_dict:
53 usage = tool_usage_dict[tool] 46 usage = tool_usage_dict[tool]
54 # extract those dates for which tool's usage is not present in raw data 47 # extract those dates for which tool's usage is not present in raw data
64 """ 57 """
65 Fit a curve for the tool usage over time to predict future tool usage 58 Fit a curve for the tool usage over time to predict future tool usage
66 """ 59 """
67 epsilon = 0.0 60 epsilon = 0.0
68 cv = 5 61 cv = 5
69 s_typ = "neg_mean_absolute_error" 62 s_typ = 'neg_mean_absolute_error'
70 n_jobs = 4 63 n_jobs = 4
71 s_error = 1 64 s_error = 1
72 tr_score = False 65 tr_score = False
73 try: 66 try:
74 pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))]) 67 pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
75 param_grid = { 68 param_grid = {
76 "regressor__kernel": ["rbf", "poly", "linear"], 69 'regressor__kernel': ['rbf', 'poly', 'linear'],
77 "regressor__degree": [2, 3], 70 'regressor__degree': [2, 3]
78 } 71 }
79 search = GridSearchCV( 72 search = GridSearchCV(pipe, param_grid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score)
80 pipe,
81 param_grid,
82 cv=cv,
83 scoring=s_typ,
84 n_jobs=n_jobs,
85 error_score=s_error,
86 return_train_score=tr_score,
87 )
88 search.fit(x_reshaped, y_reshaped.ravel()) 73 search.fit(x_reshaped, y_reshaped.ravel())
89 model = search.best_estimator_ 74 model = search.best_estimator_
90 # set the next time point to get prediction for 75 # set the next time point to get prediction for
91 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1)) 76 prediction_point = np.reshape([x_reshaped[-1][0] + 1], (1, 1))
92 prediction = model.predict(prediction_point) 77 prediction = model.predict(prediction_point)