Mercurial > repos > goeckslab > tabular_learner
changeset 3:f6a65e05d6ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
| author | goeckslab |
|---|---|
| date | Wed, 09 Jul 2025 01:12:48 +0000 |
| parents | 77c88226bfde |
| children | |
| files | base_model_trainer.py, feature_importance.py, pycaret_train.py, tabular_learner.xml, test-data/expected_best_model_classification_customized_cross_off.csv, test-data/expected_model_classification_customized_cross_off.h5 |
| diffstat | 6 files changed, 58 insertions(+), 10 deletions(-) |
--- a/base_model_trainer.py	Wed Jul 02 18:59:39 2025 +0000
+++ b/base_model_trainer.py	Wed Jul 09 01:12:48 2025 +0000
@@ -127,9 +127,11 @@
             and self.cross_validation is not None
             and self.cross_validation is False
         ):
-            self.setup_params["cross_validation"] = self.cross_validation
+            logging.info(
+                "cross_validation is set to False. This will disable cross-validation."
+            )

-        if hasattr(self, "cross_validation") and self.cross_validation is not None:
+        if hasattr(self, "cross_validation") and self.cross_validation:
             if hasattr(self, "cross_validation_folds"):
                 self.setup_params["fold"] = self.cross_validation_folds

@@ -182,10 +184,11 @@
         )

         if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models)
+            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
         else:
-            self.best_model = self.exp.compare_models()
+            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
         self.results = self.exp.pull()
+
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

@@ -314,7 +317,7 @@
         html_content += (
             "</div>"
             '<div id="summary" class="tab-content">'
-            "<h2>Model Metrics from Cross-Validation Set</h2>"
+            f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
             f"<h2>Best Model: {model_name}</h2>"
             "<h5>The best model is selected by: Accuracy (Classification)"
             " or R2 (Regression).</h5>"
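The change above forwards the tool's cross-validation switch straight to PyCaret's `compare_models`. A minimal sketch of that behaviour, assuming PyCaret 3.x and using the scikit-learn breast-cancer data as a stand-in dataset (neither the dataset nor this snippet is part of the tool): with `cross_validation=False`, the leaderboard returned by `pull()` is scored on the hold-out split, which is why the report heading switches from "Cross-Validation Set" to "Validation set".

```python
# Illustrative only: stand-in data, not the tool's own training pipeline.
from sklearn.datasets import load_breast_cancer
from pycaret.classification import ClassificationExperiment

df = load_breast_cancer(as_frame=True).frame  # includes a "target" column

exp = ClassificationExperiment()
exp.setup(data=df, target="target", session_id=42)

# cross_validation=False skips k-fold CV; candidates are scored on the hold-out split.
best = exp.compare_models(cross_validation=False)
leaderboard = exp.pull()  # metrics table shown in the report's summary tab
print(leaderboard.head())
```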
--- a/feature_importance.py	Wed Jul 02 18:59:39 2025 +0000
+++ b/feature_importance.py	Wed Jul 09 01:12:48 2025 +0000
@@ -120,6 +120,9 @@
             used_features = model.feature_name_
         elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
             used_features = model.booster_.feature_name()
+        elif hasattr(model, "feature_names_in_"):
+            # scikit-learn's standard attribute for the names of features used during fit
+            used_features = list(model.feature_names_in_)
         else:
             used_features = X_transformed.columns

@@ -130,7 +133,14 @@
             plot_X = X_shap
             plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
         else:
-            sampled_X = X_transformed[used_features].sample(100, random_state=42)
+            logging.warning(f"len(X_transformed) = {len(X_transformed)}")
+            max_samples = 100
+            n_samples = min(max_samples, len(X_transformed))
+            sampled_X = X_transformed[used_features].sample(
+                n=n_samples,
+                replace=False,
+                random_state=42
+            )
             explainer = shap.KernelExplainer(model.predict, sampled_X)
             shap_values = explainer.shap_values(sampled_X)
             plot_X = sampled_X
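The sampling guard above stops `DataFrame.sample` from raising when the transformed data has fewer than 100 rows. A standalone sketch of the same pattern, with an iris `LogisticRegression` standing in for the trained model (purely illustrative, not the tool's code):

```python
# Illustrative stand-ins: iris data and a LogisticRegression instead of the PyCaret model.
import shap
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True, as_frame=True)
model = LogisticRegression(max_iter=1000).fit(X, y)

# Cap the KernelExplainer background data at 100 rows, but never request more
# rows than the frame actually holds.
n_samples = min(100, len(X))
background = X.sample(n=n_samples, replace=False, random_state=42)

explainer = shap.KernelExplainer(model.predict, background)
shap_values = explainer.shap_values(background)
```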
--- a/pycaret_train.py	Wed Jul 02 18:59:39 2025 +0000
+++ b/pycaret_train.py	Wed Jul 09 01:12:48 2025 +0000
@@ -29,6 +29,9 @@
     parser.add_argument("--cross_validation", action="store_true",
                         default=None,
                         help="Perform cross-validation for PyCaret setup")
+    parser.add_argument("--no_cross_validation", action="store_true",
+                        default=None,
+                        help="Don't perform cross-validation for PyCaret setup")
     parser.add_argument("--cross_validation_folds", type=int,
                         default=None,
                         help="Number of cross-validation folds \
@@ -62,11 +65,15 @@

     args = parser.parse_args()

+    cross_validation = True
+    if args.no_cross_validation:
+        cross_validation = False
+
     model_kwargs = {
         "train_size": args.train_size,
         "normalize": args.normalize,
         "feature_selection": args.feature_selection,
-        "cross_validation": args.cross_validation,
+        "cross_validation": cross_validation,
         "cross_validation_folds": args.cross_validation_folds,
         "remove_outliers": args.remove_outliers,
         "remove_multicollinearity": args.remove_multicollinearity,
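The new flag pairs with the existing `--cross_validation` so the Galaxy wrapper can state either choice explicitly, and the script folds both into one boolean that defaults to True. A minimal, self-contained sketch of that resolution logic (argument names match the script; everything else is illustrative):

```python
# Sketch of the flag resolution only; the real script defines many more arguments.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cross_validation", action="store_true", default=None)
parser.add_argument("--no_cross_validation", action="store_true", default=None)

args = parser.parse_args(["--no_cross_validation"])

# Default is True; only an explicit --no_cross_validation turns it off.
cross_validation = True
if args.no_cross_validation:
    cross_validation = False

print(cross_validation)  # False
```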
--- a/tabular_learner.xml	Wed Jul 02 18:59:39 2025 +0000
+++ b/tabular_learner.xml	Wed Jul 09 01:12:48 2025 +0000
@@ -28,10 +28,13 @@
             --feature_selection
         #end if
         #if $enable_cross_validation == "true"
-            --cross_validation
+            --cross_validation
+            #if $cross_validation_folds
+                --cross_validation_folds '$cross_validation_folds'
+            #end if
         #end if
-        #if $cross_validation_folds
-            --cross_validation_folds '$cross_validation_folds'
+        #if $enable_cross_validation == "false"
+            --no_cross_validation
         #end if
         #if $remove_outliers
             --remove_outliers
@@ -183,6 +186,28 @@
             <param name="target_feature" value="11"/>
             <param name="model_type" value="classification"/>
             <param name="random_seed" value="42"/>
+            <param name="customize_defaults" value="true"/>
+            <param name="train_size" value="0.8"/>
+            <param name="normalize" value="true"/>
+            <param name="feature_selection" value="true"/>
+            <param name="enable_cross_validation" value="false"/>
+            <param name="remove_outliers" value="true"/>
+            <param name="remove_multicollinearity" value="true"/>
+            <output name="model" file="expected_model_classification_customized_cross_off.h5" compare="sim_size"/>
+            <output name="comparison_result">
+                <assert_contents>
+                    <has_text text="Validation Result Summary" />
+                    <has_text text="Test Results" />
+                    <has_text text="Feature Importance" />
+                </assert_contents>
+            </output>
+            <output name="best_model_csv" value="expected_best_model_classification_customized_cross_off.csv" />
+        </test>
+        <test>
+            <param name="input_file" value="pcr.tsv"/>
+            <param name="target_feature" value="11"/>
+            <param name="model_type" value="classification"/>
+            <param name="random_seed" value="42"/>
             <output name="model" file="expected_model_classification.h5" compare="sim_size"/>
             <output name="comparison_result">
                 <assert_contents>
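For reference, a hedged Python illustration of what the reworked Cheetah block emits for each value of `enable_cross_validation`; the `cli_flags` helper is hypothetical and exists only to mirror the template logic:

```python
# Hypothetical helper mirroring the <command> template above, not part of the tool.
def cli_flags(enable_cross_validation: str, cross_validation_folds=None):
    """Return the flags the Cheetah block would pass to pycaret_train.py."""
    flags = []
    if enable_cross_validation == "true":
        flags.append("--cross_validation")
        if cross_validation_folds:
            flags += ["--cross_validation_folds", str(cross_validation_folds)]
    elif enable_cross_validation == "false":
        flags.append("--no_cross_validation")
    return flags

print(cli_flags("false"))    # ['--no_cross_validation']
print(cli_flags("true", 5))  # ['--cross_validation', '--cross_validation_folds', '5']
```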