annotate base_model_trainer.py @ 2:009b18a75dc3 draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit 9497c4faca7063bcbb6b201ab6d0dd1570f22acb
author goeckslab
date Sat, 14 Dec 2024 23:18:02 +0000
parents 915447b14520
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
1 import base64
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
2 import logging
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
3 import os
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
4 import tempfile
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
5
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
6 from feature_importance import FeatureImportanceAnalyzer
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
7
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
8 import h5py
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
9
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
10 import joblib
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
11
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
12 import numpy as np
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
13
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
14 import pandas as pd
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
15
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
16 from sklearn.metrics import average_precision_score
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
17
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
18 from utils import get_html_closing, get_html_template
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
19
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
# Module-wide logging. basicConfig sets the root logger to DEBUG only when
# the root logger has no handlers yet; LOG is this module's own logger.
logging.basicConfig(level="DEBUG")
LOG = logging.getLogger(name=__name__)
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
22
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
23
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
24 class BaseModelTrainer:
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
25
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def __init__(
    self,
    input_file,
    target_col,
    output_dir,
    task_type,
    random_seed,
    test_file=None,
    **kwargs
):
    """Record the run configuration and reset every result holder.

    Args:
        input_file: path of the training table.
        target_col: 1-based number of the target column (may be a string).
        output_dir: directory for generated outputs.
        task_type: learning task name, e.g. "classification".
        random_seed: seed forwarded to PyCaret as ``session_id``.
        test_file: optional path of a separate test table.
        **kwargs: extra options; each one is stored as an attribute of the
            same name (and may override the placeholders set above it).
    """
    # The experiment object is populated by the concrete subclass.
    self.exp = None

    # Caller-supplied run configuration.
    self.input_file = input_file
    self.target_col = target_col
    self.output_dir = output_dir
    self.task_type = task_type
    self.random_seed = random_seed

    # Placeholders filled in by later pipeline steps. Chained assignment
    # binds left to right, so attribute creation order is unchanged.
    self.data = self.target = self.best_model = None
    self.results = self.features_name = None
    self.plots = {}
    # NOTE: "expaliner" spelling kept as-is; other code may rely on it.
    self.expaliner = self.plots_explainer_html = None
    self.trees = []

    # Arbitrary extra options become attributes.
    for option in kwargs:
        setattr(self, option, kwargs[option])

    # Reset after kwargs so a stray ``setup_params`` kwarg cannot leak in.
    self.setup_params = {}
    self.test_file, self.test_data = test_file, None

    LOG.info(f"Model kwargs: {self.__dict__}")
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
58
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def load_data(self):
    """Read the training table (and optional test table) into DataFrames.

    Steps:
      * read ``self.input_file`` letting pandas sniff the delimiter;
      * replace every literal '.' in column names with '_';
      * coerce the numeric columns with ``pd.to_numeric(errors='coerce')``;
      * resolve ``self.target_col`` (a 1-based column number, possibly a
        string) into ``self.target`` and list the remaining column names in
        ``self.features_name``;
      * if a ``missing_value_strategy`` attribute exists, impute missing
        values with the column mean ('mean') or median ('median'), or drop
        incomplete rows ('drop'); any other value falls back to the median;
      * if ``self.test_file`` is set, read it into ``self.test_data`` with
        the same column-name normalisation, coercing the columns that are
        numeric in the training table and keeping every other column.

    Raises:
        IndexError: if ``self.target_col`` is not between 1 and the number
            of columns in the training table.
    """
    def _normalise_columns(frame):
        # regex=False: '.' must be replaced literally, never treated as a
        # regex wildcard, whatever the installed pandas' default is.
        frame.columns = frame.columns.str.replace('.', '_', regex=False)

    LOG.info(f"Loading data from {self.input_file}")
    self.data = pd.read_csv(self.input_file, sep=None, engine='python')
    _normalise_columns(self.data)

    numeric_cols = self.data.select_dtypes(include=['number']).columns
    non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns

    if len(numeric_cols) > 0:
        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors='coerce')

    if len(non_numeric_cols) > 0:
        LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

    names = self.data.columns.to_list()
    target_number = int(self.target_col)
    # Guard against 0 / negative numbers, which would otherwise silently
    # pick a column counted from the end of the table.
    if not 1 <= target_number <= len(names):
        raise IndexError(
            f"target column {target_number} is out of range; "
            f"the input has {len(names)} columns")
    target_index = target_number - 1
    self.target = names[target_index]
    self.features_name = [name
                          for i, name in enumerate(names)
                          if i != target_index]

    if hasattr(self, 'missing_value_strategy'):
        if self.missing_value_strategy == 'mean':
            self.data = self.data.fillna(
                self.data.mean(numeric_only=True))
        elif self.missing_value_strategy == 'median':
            self.data = self.data.fillna(
                self.data.median(numeric_only=True))
        elif self.missing_value_strategy == 'drop':
            self.data = self.data.dropna()
        else:
            # Default strategy if not specified
            self.data = self.data.fillna(
                self.data.median(numeric_only=True))

    if self.test_file:
        LOG.info(f"Loading test data from {self.test_file}")
        self.test_data = pd.read_csv(
            self.test_file, sep=None, engine='python')
        # Rename BEFORE indexing so the names line up with the (already
        # renamed) training columns held in numeric_cols.
        _normalise_columns(self.test_data)
        # Coerce in place instead of selecting: the test table must keep
        # the same columns as the training table (including a
        # non-numeric target) for downstream use.
        shared = [col for col in numeric_cols
                  if col in self.test_data.columns]
        missing = [col for col in numeric_cols
                   if col not in self.test_data.columns]
        if missing:
            LOG.warning(
                f"Test data is missing training columns: {missing}")
        if shared:
            self.test_data[shared] = self.test_data[shared].apply(
                pd.to_numeric, errors='coerce')
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
101
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def setup_pycaret(self):
    """Build the keyword arguments for PyCaret ``setup()`` and call it.

    Always present: target, session_id (the random seed), html,
    log_experiment, system_log and index. Optional trainer attributes are
    forwarded only when they exist and are not None, in this order:

      * ``test_data``; otherwise ``train_size``;
      * ``normalize``, ``feature_selection``;
      * ``cross_validation`` (forwarded only when it is exactly False);
      * ``cross_validation_folds`` as ``fold``, whenever
        ``cross_validation`` is set and that attribute exists;
      * ``remove_outliers``, ``remove_multicollinearity``,
        ``polynomial_features``, ``fix_imbalance``.

    The dict is stored on ``self.setup_params``, logged, and then passed to
    ``self.exp.setup(self.data, ...)``.
    """
    LOG.info("Initializing PyCaret")
    # ``params`` and ``self.setup_params`` are the same dict object.
    params = self.setup_params = {
        'target': self.target,
        'session_id': self.random_seed,
        'html': True,
        'log_experiment': False,
        'system_log': False,
        'index': False,
    }

    # An explicit test table takes precedence over a train/test split size.
    if self.test_data is not None:
        params['test_data'] = self.test_data
    elif getattr(self, 'train_size', None) is not None:
        params['train_size'] = self.train_size

    for option in ('normalize', 'feature_selection'):
        value = getattr(self, option, None)
        if value is not None:
            params[option] = value

    cv_flag = getattr(self, 'cross_validation', None)
    # NOTE(review): only an explicit False is forwarded; confirm that the
    # installed PyCaret setup() accepts a `cross_validation` keyword.
    if cv_flag is False:
        params['cross_validation'] = cv_flag
    if cv_flag is not None and hasattr(self, 'cross_validation_folds'):
        params['fold'] = self.cross_validation_folds

    for option in ('remove_outliers', 'remove_multicollinearity',
                   'polynomial_features', 'fix_imbalance'):
        value = getattr(self, option, None)
        if value is not None:
            params[option] = value

    LOG.info(params)
    self.exp.setup(self.data, **params)
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
156
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def train_model(self):
    """Compare candidate estimators with PyCaret and record their metrics.

    For classification tasks a weighted PR-AUC metric (backed by
    ``average_precision_score`` on predicted probabilities) is registered
    before the comparison. The candidate pool is restricted to
    ``self.models`` when that attribute exists and is not None; otherwise
    ``compare_models`` runs with its defaults.

    Side effects:
        self.best_model      -- estimator returned by ``compare_models``.
        self.results         -- table pulled right after the comparison.
        self.test_result_df  -- table pulled after ``predict_model`` on the
                                best model.
    For classification, the ``AUC`` column of both tables is renamed to
    ``ROC-AUC``.
    """
    LOG.info("Training and selecting the best model")
    is_classification = self.task_type == "classification"
    auc_rename = {'AUC': 'ROC-AUC'}

    if is_classification:
        average_displayed = "Weighted"
        metric_label = f'PR-AUC-{average_displayed}'
        self.exp.add_metric(
            id=metric_label,
            name=metric_label,
            target='pred_proba',
            score_func=average_precision_score,
            average='weighted',
        )

    candidate_models = getattr(self, 'models', None)
    if candidate_models is None:
        self.best_model = self.exp.compare_models()
    else:
        self.best_model = self.exp.compare_models(include=candidate_models)

    self.results = self.exp.pull()
    if is_classification:
        self.results.rename(columns=auc_rename, inplace=True)

    # The prediction frame itself is discarded; only the metrics table that
    # pull() returns afterwards is kept (presumably hold-out/test metrics,
    # which the report labels "Results on the Test Set").
    self.exp.predict_model(self.best_model)
    self.test_result_df = self.exp.pull()
    if is_classification:
        self.test_result_df.rename(columns=auc_rename, inplace=True)
182
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def save_model(self):
    """Serialize ``self.best_model`` with joblib and store it in HDF5.

    The raw joblib byte stream is written as an opaque ``np.void`` scalar
    into the ``model`` dataset of ``pycaret_model.h5`` (a relative path, so
    it lands in the current working directory).
    """
    import io

    hdf5_model_path = "pycaret_model.h5"

    # Serialize in memory. The previous NamedTemporaryFile(delete=False)
    # was never removed, leaking one temp file per call, and re-opening a
    # still-open named temp file by path is not possible on Windows.
    buffer = io.BytesIO()
    joblib.dump(self.best_model, buffer)
    model_bytes = buffer.getvalue()

    # Only create/truncate the HDF5 file once serialization has succeeded,
    # so a failing dump does not leave an empty .h5 file behind.
    with h5py.File(hdf5_model_path, 'w') as f:
        f.create_dataset('model', data=np.void(model_bytes))
191
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def generate_plots(self):
    """Abstract hook for building the best-model plots.

    Concrete trainers must override this; presumably it fills ``self.plots``
    (plot name -> image path), which ``save_html_report`` embeds.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses should implement this method"
    raise NotImplementedError(message)
194
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def encode_image_to_base64(self, img_path):
    """Read the file at *img_path* and return its bytes as base64 text.

    The result is a plain UTF-8 ``str`` that can be dropped straight into a
    ``data:image/png;base64,...`` URI.
    """
    with open(img_path, 'rb') as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
198
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def save_html_report(self):
    """Write the tabbed HTML training report and its companion CSV files.

    Files written into ``self.output_dir``:
      * ``best_model.csv``        - ``get_params()`` of ``self.best_model``;
      * ``comparison_results.csv`` - ``self.results`` (CV comparison table);
      * ``test_results.csv``      - ``self.test_result_df``;
      * ``comparison_result.html`` - report with four tabs: setup & best
        model, best-model plots, feature importance, explainer.

    Reads ``self.setup_params``, ``self.plots`` (name -> image path),
    ``self.trees`` (base64 PNG payloads; falsy entries are skipped) and
    ``self.plots_explainer_html``, which must be populated beforehand.
    """
    LOG.info("Saving HTML report")

    model_name = type(self.best_model).__name__
    excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
    filtered_setup_params = {
        k: v
        for k, v in self.setup_params.items() if k not in excluded_params
    }
    setup_params_table = pd.DataFrame(
        list(filtered_setup_params.items()),
        columns=['Parameter', 'Value'])

    best_model_params = pd.DataFrame(
        list(self.best_model.get_params().items()),
        columns=['Parameter', 'Value'])
    best_model_params.to_csv(
        os.path.join(self.output_dir, 'best_model.csv'),
        index=False)
    self.results.to_csv(os.path.join(
        self.output_dir, "comparison_results.csv"))
    self.test_result_df.to_csv(os.path.join(
        self.output_dir, "test_results.csv"))

    # One section per plot, separated by <hr> (no rule after the last one).
    plot_sections = []
    for plot_name, plot_path in self.plots.items():
        encoded_image = self.encode_image_to_base64(plot_path)
        plot_sections.append(f"""
        <div class="plot">
        <h3>{plot_name.capitalize()}</h3>
        <img src="data:image/png;base64,{encoded_image}"
        alt="{plot_name}">
        </div>
        """)
    plots_html = "<hr>".join(plot_sections)

    # Tree entries are embedded as-is as base64 PNG payloads. Falsy entries
    # are skipped but still consume a tree number. The data URI is kept on
    # one line so no whitespace ends up inside the base64 payload.
    tree_plots = "".join(
        f"""
        <div class="plot">
        <h3>Tree {i + 1}</h3>
        <img src="data:image/png;base64,{tree}"
        alt="tree {i + 1}">
        </div>
        """
        for i, tree in enumerate(self.trees) if tree
    )

    analyzer = FeatureImportanceAnalyzer(
        data=self.data,
        target_col=self.target_col,
        task_type=self.task_type,
        output_dir=self.output_dir)
    feature_importance_html = analyzer.run()

    # DataFrame.to_html emits a complete <table> including its header row,
    # so it is not wrapped in another <table>: a table nested directly in a
    # table is invalid HTML, and browsers split the hand-written header off
    # into a separate, misaligned table.
    html_content = f"""
    {get_html_template()}
    <h1>PyCaret Model Training Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'summary')">
        Setup & Best Model</div>
        <div class="tab" onclick="openTab(event, 'plots')">
        Best Model Plots</div>
        <div class="tab" onclick="openTab(event, 'feature')">
        Feature Importance</div>
        <div class="tab" onclick="openTab(event, 'explainer')">
        Explainer
        </div>
    </div>
    <div id="summary" class="tab-content">
        <h2>Setup Parameters</h2>
        {setup_params_table.to_html(index=False, classes='table')}
        <h5>If you want to know all the experiment setup parameters,
        please check the PyCaret documentation for
        the classification/regression <code>setup</code> function.</h5>
        <h2>Best Model: {model_name}</h2>
        {best_model_params.to_html(index=False, classes='table')}
        <h2>Comparison Results on the Cross-Validation Set</h2>
        {self.results.to_html(index=False, classes='table')}
        <h2>Results on the Test Set for the best model</h2>
        {self.test_result_df.to_html(index=False, classes='table')}
    </div>
    <div id="plots" class="tab-content">
        <h2>Best Model Plots on the testing set</h2>
        {plots_html}
    </div>
    <div id="feature" class="tab-content">
        {feature_importance_html}
    </div>
    <div id="explainer" class="tab-content">
        {self.plots_explainer_html}
        {tree_plots}
    </div>
    {get_html_closing()}
    """

    # Explicit UTF-8 so parameter values with non-ASCII text do not depend
    # on the platform's default encoding.
    report_path = os.path.join(self.output_dir, "comparison_result.html")
    with open(report_path, "w", encoding="utf-8") as file:
        file.write(html_content)
312
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def save_dashboard(self):
    """Abstract hook for exporting a dashboard; subclasses must override.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses should implement this method"
    raise NotImplementedError(message)
315
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def generate_plots_explainer(self):
    """Abstract hook for building the explainer plots.

    Concrete trainers must override this; presumably it fills
    ``self.plots_explainer_html``, which ``save_html_report`` embeds in the
    Explainer tab.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    message = "Subclasses should implement this method"
    raise NotImplementedError(message)
318
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
# not working now: tree plotting is experimental and strictly best-effort.
def generate_tree_plots(self):
    """Append one encoded decision-tree figure per tree to ``self.trees``.

    Supported best models are sklearn ``RandomForestClassifier`` /
    ``RandomForestRegressor`` and XGBoost ``XGBClassifier`` /
    ``XGBRegressor``. Any other model type is skipped with an info message
    rather than being reported as an error.

    The step is best-effort: a missing optional dependency, or any error
    raised while building the plots, is logged and swallowed so the rest
    of the pipeline keeps running.
    """
    try:
        # Imported lazily and inside the try: these are optional,
        # heavyweight dependencies that only this experimental step needs.
        from explainerdashboard.explainers import RandomForestExplainer
        from sklearn.ensemble import (
            RandomForestClassifier,
            RandomForestRegressor,
        )
        from xgboost import XGBClassifier, XGBRegressor
    except ImportError as e:
        LOG.error(f"Error generating tree plots: {e}")
        return

    LOG.info("Generating tree plots")
    model = self.best_model

    if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
        is_xgb = False
    elif isinstance(model, (XGBClassifier, XGBRegressor)):
        is_xgb = True
    else:
        # Without this guard the tree count would be unbound and the step
        # would log a misleading error for every non-tree model.
        LOG.info(
            "Best model is not a random forest or XGBoost model; "
            "skipping tree plots"
        )
        return

    try:
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        if is_xgb:
            # One dump entry per boosted tree; for multi-class models this
            # presumably counts one tree per class per round — TODO confirm.
            num_trees = len(model.get_booster().get_dump())
        else:
            num_trees = model.n_estimators

        # NOTE(review): RandomForestExplainer is also used for XGBoost
        # models here; confirm explainerdashboard supports that, or switch
        # to its XGBoost-specific explainer.
        explainer = RandomForestExplainer(model, X_test, y_test)
        for i in range(num_trees):
            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
            LOG.info(f"Tree {i+1}")
            LOG.info(fig)
            self.trees.append(fig)
    except Exception as e:
        LOG.error(f"Error generating tree plots: {e}")
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
349
915447b14520 planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
goeckslab
parents:
diff changeset
def run(self):
    """Execute every trainer stage in a fixed order.

    Each stage is looked up on ``self`` by name just before it is called.
    A stage that raises therefore stops the pipeline before any later stage
    is resolved or executed.
    """
    pipeline = (
        "load_data",
        "setup_pycaret",
        "train_model",
        "save_model",
        "generate_plots",
        "generate_plots_explainer",
        "generate_tree_plots",
        "save_html_report",
        # "save_dashboard" is deliberately not run (currently disabled).
    )
    for stage_name in pipeline:
        stage = getattr(self, stage_name)
        stage()