diff predict_tool_usage.py @ 6:e94dc7945639 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/recommendation_training/tools/tool_recommendation_model commit 24bab7a797f53fe4bcc668b18ee0326625486164
author: bgruening
date: Sun, 16 Oct 2022 11:52:10 +0000
parents: 4f7e6612906b
children: (none)
--- a/predict_tool_usage.py	Fri May 06 09:05:18 2022 +0000
+++ b/predict_tool_usage.py	Sun Oct 16 11:52:10 2022 +0000
@@ -3,9 +3,6 @@
 """
 
 import collections
-import csv
-import os
-import warnings
 
 import numpy as np
 import utils
@@ -13,40 +10,36 @@
 from sklearn.pipeline import Pipeline
 from sklearn.svm import SVR
 
-warnings.filterwarnings("ignore")
-
-main_path = os.getcwd()
-
 
 class ToolPopularity:
+
     def __init__(self):
         """ Init method. """
 
-    def extract_tool_usage(self, tool_usage_file, cutoff_date, dictionary):
+    def extract_tool_usage(self, tool_usage_df, cutoff_date, dictionary):
         """
         Extract the tool usage over time for each tool
         """
         tool_usage_dict = dict()
         all_dates = list()
         all_tool_list = list(dictionary.keys())
-        with open(tool_usage_file, "rt") as usage_file:
-            tool_usage = csv.reader(usage_file, delimiter="\t")
-            for index, row in enumerate(tool_usage):
-                row = [item.strip() for item in row]
-                if (str(row[1]).strip() > cutoff_date) is True:
-                    tool_id = utils.format_tool_id(row[0])
-                    if tool_id in all_tool_list:
-                        all_dates.append(row[1])
-                        if tool_id not in tool_usage_dict:
-                            tool_usage_dict[tool_id] = dict()
-                            tool_usage_dict[tool_id][row[1]] = int(row[2])
+        for index, row in tool_usage_df.iterrows():
+            row = row.tolist()
+            row = [str(item).strip() for item in row]
+            if row[1] > cutoff_date:
+                tool_id = utils.format_tool_id(row[0])
+                if tool_id in all_tool_list:
+                    all_dates.append(row[1])
+                    if tool_id not in tool_usage_dict:
+                        tool_usage_dict[tool_id] = dict()
+                        tool_usage_dict[tool_id][row[1]] = int(float(row[2]))
+                    else:
+                        curr_date = row[1]
+                        # merge the usage of different versions of a tool into one
+                        if curr_date in tool_usage_dict[tool_id]:
+                            tool_usage_dict[tool_id][curr_date] += int(float(row[2]))
                         else:
-                            curr_date = row[1]
-                            # merge the usage of different version of tools into one
-                            if curr_date in tool_usage_dict[tool_id]:
-                                tool_usage_dict[tool_id][curr_date] += int(row[2])
-                            else:
-                                tool_usage_dict[tool_id][curr_date] = int(row[2])
+                            tool_usage_dict[tool_id][curr_date] = int(float(row[2]))
         # get unique dates
         unique_dates = list(set(all_dates))
         for tool in tool_usage_dict:
@@ -66,25 +59,17 @@
         """
         epsilon = 0.0
         cv = 5
-        s_typ = "neg_mean_absolute_error"
+        s_typ = 'neg_mean_absolute_error'
         n_jobs = 4
         s_error = 1
         tr_score = False
         try:
-            pipe = Pipeline(steps=[("regressor", SVR(gamma="scale"))])
+            pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
             param_grid = {
-                "regressor__kernel": ["rbf", "poly", "linear"],
-                "regressor__degree": [2, 3],
+                'regressor__kernel': ['rbf', 'poly', 'linear'],
+                'regressor__degree': [2, 3]
             }
-            search = GridSearchCV(
-                pipe,
-                param_grid,
-                cv=cv,
-                scoring=s_typ,
-                n_jobs=n_jobs,
-                error_score=s_error,
-                return_train_score=tr_score,
-            )
+            search = GridSearchCV(pipe, param_grid, cv=cv, scoring=s_typ, n_jobs=n_jobs, error_score=s_error, return_train_score=tr_score)
             search.fit(x_reshaped, y_reshaped.ravel())
             model = search.best_estimator_
             # set the next time point to get prediction for
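
The last hunk only reverts black-style formatting: double quotes go back to single quotes and the GridSearchCV call is collapsed onto one line; the model-selection logic itself is unchanged. For reference, a self-contained sketch of that logic on synthetic data (the toy usage series and the final prediction step are illustrative additions, not part of the changeset):

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.svm import SVR

# Toy usage series: x = time points, y = usage counts at those points.
x_reshaped = np.arange(10, dtype=float).reshape(-1, 1)
y_reshaped = np.array([5, 7, 6, 9, 12, 11, 15, 14, 18, 21], dtype=float).reshape(-1, 1)

# Same pipeline and search space as in the hunk above.
pipe = Pipeline(steps=[('regressor', SVR(gamma='scale'))])
param_grid = {
    'regressor__kernel': ['rbf', 'poly', 'linear'],
    'regressor__degree': [2, 3]
}
search = GridSearchCV(pipe, param_grid, cv=5, scoring='neg_mean_absolute_error',
                      n_jobs=4, error_score=1, return_train_score=False)
search.fit(x_reshaped, y_reshaped.ravel())
model = search.best_estimator_

# Set the next time point to get a prediction for, as the surrounding code does.
prediction = model.predict(np.array([[len(x_reshaped)]]))
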