changeset 3:f6a65e05d6ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
author goeckslab
date Wed, 09 Jul 2025 01:12:48 +0000
parents 77c88226bfde
children
files base_model_trainer.py feature_importance.py pycaret_train.py tabular_learner.xml test-data/expected_best_model_classification_customized_cross_off.csv test-data/expected_model_classification_customized_cross_off.h5
diffstat 6 files changed, 58 insertions(+), 10 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Wed Jul 02 18:59:39 2025 +0000
+++ b/base_model_trainer.py	Wed Jul 09 01:12:48 2025 +0000
@@ -127,9 +127,11 @@
             and self.cross_validation is not None
             and self.cross_validation is False
         ):
-            self.setup_params["cross_validation"] = self.cross_validation
+            logging.info(
+                "cross_validation is set to False. This will disable cross-validation."
+            )
 
-        if hasattr(self, "cross_validation") and self.cross_validation is not None:
+        if hasattr(self, "cross_validation") and self.cross_validation:
             if hasattr(self, "cross_validation_folds"):
                 self.setup_params["fold"] = self.cross_validation_folds
 
@@ -182,10 +184,11 @@
             )
 
         if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models)
+            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
         else:
-            self.best_model = self.exp.compare_models()
+            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
         self.results = self.exp.pull()
+
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
@@ -314,7 +317,7 @@
         html_content += (
             "</div>"
             '<div id="summary" class="tab-content">'
-            "<h2>Model Metrics from Cross-Validation Set</h2>"
+            f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
             f"<h2>Best Model: {model_name}</h2>"
             "<h5>The best model is selected by: Accuracy (Classification)"
             " or R2 (Regression).</h5>"
--- a/feature_importance.py	Wed Jul 02 18:59:39 2025 +0000
+++ b/feature_importance.py	Wed Jul 09 01:12:48 2025 +0000
@@ -120,6 +120,9 @@
             used_features = model.feature_name_
         elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
             used_features = model.booster_.feature_name()
+        elif hasattr(model, "feature_names_in_"):
+            # scikit‐learn's standard attribute for the names of features used during fit
+            used_features = list(model.feature_names_in_)
         else:
             used_features = X_transformed.columns
 
@@ -130,7 +133,14 @@
             plot_X = X_shap
             plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
         else:
-            sampled_X = X_transformed[used_features].sample(100, random_state=42)
+            logging.warning(f"len(X_transformed) = {len(X_transformed)}")
+            max_samples = 100
+            n_samples = min(max_samples, len(X_transformed))
+            sampled_X = X_transformed[used_features].sample(
+                n=n_samples,
+                replace=False,
+                random_state=42
+            )
             explainer = shap.KernelExplainer(model.predict, sampled_X)
             shap_values = explainer.shap_values(sampled_X)
             plot_X = sampled_X
--- a/pycaret_train.py	Wed Jul 02 18:59:39 2025 +0000
+++ b/pycaret_train.py	Wed Jul 09 01:12:48 2025 +0000
@@ -29,6 +29,9 @@
     parser.add_argument("--cross_validation", action="store_true",
                         default=None,
                         help="Perform cross-validation for PyCaret setup")
+    parser.add_argument("--no_cross_validation", action="store_true",
+                        default=None,
+                        help="Don't perform cross-validation for PyCaret setup")
     parser.add_argument("--cross_validation_folds", type=int,
                         default=None,
                         help="Number of cross-validation folds \
@@ -62,11 +65,15 @@
 
     args = parser.parse_args()
 
+    cross_validation = True
+    if args.no_cross_validation:
+        cross_validation = False
+
     model_kwargs = {
         "train_size": args.train_size,
         "normalize": args.normalize,
         "feature_selection": args.feature_selection,
-        "cross_validation": args.cross_validation,
+        "cross_validation": cross_validation,
         "cross_validation_folds": args.cross_validation_folds,
         "remove_outliers": args.remove_outliers,
         "remove_multicollinearity": args.remove_multicollinearity,
--- a/tabular_learner.xml	Wed Jul 02 18:59:39 2025 +0000
+++ b/tabular_learner.xml	Wed Jul 09 01:12:48 2025 +0000
@@ -28,10 +28,13 @@
                 --feature_selection
                 #end if
                 #if $enable_cross_validation == "true" 
-                --cross_validation 
+                    --cross_validation
+                    #if $cross_validation_folds
+                        --cross_validation_folds '$cross_validation_folds'
+                    #end if 
                 #end if
-                #if $cross_validation_folds
-                --cross_validation_folds '$cross_validation_folds'
+                #if $enable_cross_validation == "false"
+                    --no_cross_validation
                 #end if
                 #if $remove_outliers
                 --remove_outliers  
@@ -183,6 +186,28 @@
             <param name="target_feature" value="11"/> 
             <param name="model_type" value="classification"/>
             <param name="random_seed" value="42"/>
+            <param name="customize_defaults" value="true"/>
+            <param name="train_size" value="0.8"/>
+            <param name="normalize" value="true"/>
+            <param name="feature_selection" value="true"/>
+            <param name="enable_cross_validation" value="false"/>
+            <param name="remove_outliers" value="true"/>
+            <param name="remove_multicollinearity" value="true"/>
+            <output name="model" file="expected_model_classification_customized_cross_off.h5" compare="sim_size"/>
+            <output name="comparison_result">
+                <assert_contents>
+                    <has_text text="Validation Result Summary" />
+                    <has_text text="Test Results" />
+                    <has_text text="Feature Importance" />
+                </assert_contents>
+            </output>
+            <output name="best_model_csv" value="expected_best_model_classification_customized_cross_off.csv" />
+        </test>
+        <test>
+            <param name="input_file" value="pcr.tsv"/>
+            <param name="target_feature" value="11"/> 
+            <param name="model_type" value="classification"/>
+            <param name="random_seed" value="42"/>
             <output name="model" file="expected_model_classification.h5" compare="sim_size"/>
             <output name="comparison_result"> 
                 <assert_contents>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test-data/expected_best_model_classification_customized_cross_off.csv	Wed Jul 09 01:12:48 2025 +0000
@@ -0,0 +1,3 @@
+Parameter,Value
+priors,
+var_smoothing,1e-09
Binary file test-data/expected_model_classification_customized_cross_off.h5 has changed