# HG changeset patch
# User goeckslab
# Date 1748979076 0
# Node ID 4aa51153919969deaffe5c3d5d595c763b736c4e
# Parent 02f7746e77728b3b461968e03b1bd5ff76a30072
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit cf47efb521b91a9cb44ae5c5ade860627f9b9030
diff -r 02f7746e7772 -r 4aa511539199 Dockerfile
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/Dockerfile Tue Jun 03 19:31:16 2025 +0000
@@ -0,0 +1,19 @@
+FROM python:3.11-slim
+
+ARG VERSION=3.3.2
+
+# Install necessary dependencies, including libgomp1
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends git unzip libgomp1 && \
+ rm -rf /var/lib/apt/lists/*
+
+# Install Python packages
+RUN pip install -U pip && \
+ pip install --no-cache-dir --no-compile joblib && \
+ pip install --no-cache-dir --no-compile h5py && \
+ pip install --no-cache-dir --no-compile pycaret[analysis,models]==${VERSION} && \
+ pip install --no-cache-dir --no-compile explainerdashboard
+
+# Clean up unnecessary packages
+RUN apt-get -y autoremove && apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
diff -r 02f7746e7772 -r 4aa511539199 LICENSE
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/LICENSE Tue Jun 03 19:31:16 2025 +0000
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 JunhaoQiu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff -r 02f7746e7772 -r 4aa511539199 README.md
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/README.md Tue Jun 03 19:31:16 2025 +0000
@@ -0,0 +1,106 @@
+# Galaxy-Pycaret
+A library of Galaxy machine learning tools based on PyCaret — part of the Galaxy ML2 tools, aiming to provide simple, powerful, and robust machine learning capabilities for Galaxy users.
+
+# Install Galaxy-Pycaret into Galaxy
+
+* Update `tool_conf.xml` to include Galaxy-Pycaret tools. See [documentation](https://docs.galaxyproject.org/en/master/admin/tool_panel.html) for more details. This is an example:
+```
+<section id="pycaret" name="Pycaret Applications">
+ <tool file="galaxy-pycaret/tools/pycaret_train.xml" />
+</section>
+```
+
+* Configure the `job_conf.yml` under `lib/galaxy/config/sample` to enable Docker for the environment you want the PyCaret-related jobs to run in. This is an example:
+```
+execution:
+ default: local
+ environments:
+ local:
+ runner: local
+ docker_enabled: true
+```
+If you are using an older version of Galaxy, then `job_conf.xml` would be something you want to configure instead of `job_conf.yml`. Then you would want to configure destination instead of execution and environment.
+See [documentation](https://docs.galaxyproject.org/en/master/admin/jobs.html#running-jobs-in-containers) for job_conf configuration.
+* If you haven’t set `sanitize_all_html: false` in `galaxy.yml`, please set it to False to enable our HTML report functionality.
+* Should be good to go.
+
+# Make contributions
+
+## Getting Started
+
+To get started, you’ll need to fork the repository, clone it locally, and create a new branch for your contributions.
+
+1. **Fork the Repository**: Click the "Fork" button at the top right of this page.
+2. **Clone the Fork**:
+ ```bash
+ git clone https://github.com/<your-username>/Galaxy-Pycaret.git
+   cd Galaxy-Pycaret
+ ```
+3. **Create a Feature/hotfix/bugfix Branch**:
+ ```bash
+ git checkout -b feature/<feature-branch-name>
+ ```
+ or
+ ```bash
+   git checkout -b hotfix/<hotfix-branch-name>
+ ```
+ or
+ ```bash
+ git checkout -b bugfix/<bugfix-branch-name>
+ ```
+
+## How We Manage the Repo
+
+We follow a structured branching and merging strategy to ensure code quality and stability.
+
+1. **Main Branches**:
+ - **`main`**: Contains production-ready code.
+ - **`dev`**: Contains code that is ready for the next release.
+
+2. **Supporting Branches**:
+ - **Feature Branches**: Created from `dev` for new features.
+ - **Bugfix Branches**: Created from `dev` for bug fixes.
+ - **Release Branches**: Created from `dev` when preparing a new release.
+ - **Hotfix Branches**: Created from `main` for critical fixes in production.
+
+### Workflow
+
+- **Feature Development**:
+ - Branch from `dev`.
+ - Work on your feature.
+ - Submit a Pull Request (PR) to `dev`.
+- **Hotfixes**:
+ - Branch from `main`.
+ - Fix the issue.
+ - Merge back into both `main` and `dev`.
+
+## Contribution Guidelines
+
+We welcome contributions of all kinds. To make contributions easy and effective, please follow these guidelines:
+
+1. **Create an Issue**: Before starting work on a major change, create an issue to discuss it.
+2. **Fork and Branch**: Fork the repo and create a feature branch.
+3. **Write Tests**: Ensure your changes are well-tested if applicable.
+4. **Code Style**: Follow the project’s coding conventions.
+5. **Commit Messages**: Write clear and concise commit messages.
+6. **Pull Request**: Submit a PR to the `dev` branch. Ensure your PR description is clear and includes the issue number.
+
+### Submitting a Pull Request
+
+1. **Push your Branch**:
+ ```bash
+ git push origin feature/<feature-branch-name>
+ ```
+2. **Open a Pull Request**:
+ - Navigate to the original repository where you created your fork.
+ - Click on the "New Pull Request" button.
+ - Select `dev` as the base branch and your feature branch as the compare branch.
+ - Fill in the PR template with details about your changes.
+
+3. **Rebase or Merge `dev` into Your Feature Branch**:
+ - Before submitting your PR or when `dev` has been updated, rebase or merge `dev` into your feature branch to ensure your branch is up to date:
+
+4. **Resolve Conflicts**:
+ - If there are any conflicts during the rebase or merge, Git will pause and allow you to resolve the conflicts.
+
+5. **Review Process**: Your PR will be reviewed by a team member. Please address any feedback and update your PR as needed.
\ No newline at end of file
diff -r 02f7746e7772 -r 4aa511539199 base_model_trainer.py
--- a/base_model_trainer.py Wed Jan 01 03:19:40 2025 +0000
+++ b/base_model_trainer.py Tue Jun 03 19:31:16 2025 +0000
@@ -3,18 +3,12 @@
import os
import tempfile
-from feature_importance import FeatureImportanceAnalyzer
-
import h5py
-
import joblib
-
import numpy as np
-
import pandas as pd
-
+from feature_importance import FeatureImportanceAnalyzer
from sklearn.metrics import average_precision_score
-
from utils import get_html_closing, get_html_template
logging.basicConfig(level=logging.DEBUG)
@@ -31,8 +25,7 @@
task_type,
random_seed,
test_file=None,
- **kwargs
- ):
+ **kwargs):
self.exp = None # This will be set in the subclass
self.input_file = input_file
self.target_col = target_col
@@ -71,7 +64,7 @@
LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
names = self.data.columns.to_list()
- target_index = int(self.target_col)-1
+ target_index = int(self.target_col) - 1
self.target = names[target_index]
self.features_name = [name
for i, name in enumerate(names)
@@ -97,7 +90,7 @@
pd.to_numeric, errors='coerce')
self.test_data.columns = self.test_data.columns.str.replace(
'.', '_'
- )
+ )
def setup_pycaret(self):
LOG.info("Initializing PyCaret")
@@ -206,19 +199,22 @@
for k, v in self.setup_params.items() if k not in excluded_params
}
setup_params_table = pd.DataFrame(
- list(filtered_setup_params.items()),
- columns=['Parameter', 'Value'])
+ list(filtered_setup_params.items()), columns=['Parameter', 'Value']
+ )
best_model_params = pd.DataFrame(
self.best_model.get_params().items(),
- columns=['Parameter', 'Value'])
+ columns=['Parameter', 'Value']
+ )
best_model_params.to_csv(
- os.path.join(self.output_dir, 'best_model.csv'),
- index=False)
- self.results.to_csv(os.path.join(
- self.output_dir, "comparison_results.csv"))
- self.test_result_df.to_csv(os.path.join(
- self.output_dir, "test_results.csv"))
+ os.path.join(self.output_dir, "best_model.csv"), index=False
+ )
+ self.results.to_csv(
+ os.path.join(self.output_dir, "comparison_results.csv")
+ )
+ self.test_result_df.to_csv(
+ os.path.join(self.output_dir, "test_results.csv")
+ )
plots_html = ""
length = len(self.plots)
@@ -250,7 +246,8 @@
data=self.data,
target_col=self.target_col,
task_type=self.task_type,
- output_dir=self.output_dir)
+ output_dir=self.output_dir,
+ )
feature_importance_html = analyzer.run()
html_content = f"""
@@ -263,38 +260,37 @@
Best Model Plots</div>
<div class="tab" onclick="openTab(event, 'feature')">
Feature Importance</div>
- """
+ """
if self.plots_explainer_html:
html_content += """
- "<div class="tab" onclick="openTab(event, 'explainer')">"
+ <div class="tab" onclick="openTab(event, 'explainer')">
Explainer Plots</div>
"""
html_content += f"""
</div>
<div id="summary" class="tab-content">
<h2>Setup Parameters</h2>
- <table>
- <tr><th>Parameter</th><th>Value</th></tr>
- {setup_params_table.to_html(
- index=False, header=False, classes='table')}
- </table>
+ {setup_params_table.to_html(
+ index=False,
+ header=True,
+ classes='table sortable'
+ )}
<h5>If you want to know all the experiment setup parameters,
please check the PyCaret documentation for
the classification/regression <code>exp</code> function.</h5>
<h2>Best Model: {model_name}</h2>
- <table>
- <tr><th>Parameter</th><th>Value</th></tr>
- {best_model_params.to_html(
- index=False, header=False, classes='table')}
- </table>
+ {best_model_params.to_html(
+ index=False,
+ header=True,
+ classes='table sortable'
+ )}
<h2>Comparison Results on the Cross-Validation Set</h2>
- <table>
- {self.results.to_html(index=False, classes='table')}
- </table>
+ {self.results.to_html(index=False, classes='table sortable')}
<h2>Results on the Test Set for the best model</h2>
- <table>
- {self.test_result_df.to_html(index=False, classes='table')}
- </table>
+ {self.test_result_df.to_html(
+ index=False,
+ classes='table sortable'
+ )}
</div>
<div id="plots" class="tab-content">
<h2>Best Model Plots on the testing set</h2>
@@ -310,14 +306,66 @@
{self.plots_explainer_html}
{tree_plots}
</div>
- {get_html_closing()}
"""
- else:
- html_content += f"""
- {get_html_closing()}
- """
- with open(os.path.join(
- self.output_dir, "comparison_result.html"), "w") as file:
+ html_content += """
+ <script>
+ document.addEventListener("DOMContentLoaded", function() {
+ var tables = document.querySelectorAll("table.sortable");
+ tables.forEach(function(table) {
+ var headers = table.querySelectorAll("th");
+ headers.forEach(function(header, index) {
+ header.style.cursor = "pointer";
+ // Add initial arrow (up) to indicate sortability
+ header.innerHTML += '<span class="sort-arrow"> ↑</span>';
+ header.addEventListener("click", function() {
+ var direction = this.getAttribute(
+ "data-sort-direction"
+ ) || "asc";
+ // Reset arrows in all headers of this table
+ headers.forEach(function(h) {
+ var arrow = h.querySelector(".sort-arrow");
+ if (arrow) arrow.textContent = " ↑";
+ });
+ // Set arrow for clicked header
+ var arrow = this.querySelector(".sort-arrow");
+ arrow.textContent = direction === "asc" ? " ↓" : " ↑";
+ sortTable(table, index, direction);
+ this.setAttribute("data-sort-direction",
+ direction === "asc" ? "desc" : "asc");
+ });
+ });
+ });
+ });
+
+ function sortTable(table, colNum, direction) {
+ var tb = table.tBodies[0];
+ var tr = Array.prototype.slice.call(tb.rows, 0);
+ var multiplier = direction === "asc" ? 1 : -1;
+ tr = tr.sort(function(a, b) {
+ var aText = a.cells[colNum].textContent.trim();
+ var bText = b.cells[colNum].textContent.trim();
+ // Remove arrow from text comparison
+ aText = aText.replace(/[↑↓]/g, '').trim();
+ bText = bText.replace(/[↑↓]/g, '').trim();
+ if (!isNaN(aText) && !isNaN(bText)) {
+ return multiplier * (
+ parseFloat(aText) - parseFloat(bText)
+ );
+ } else {
+ return multiplier * aText.localeCompare(bText);
+ }
+ });
+ for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
+ }
+ </script>
+ """
+ html_content += f"""
+ {get_html_closing()}
+ """
+ with open(
+ os.path.join(self.output_dir, "comparison_result.html"),
+ "w"
+ ) as file:
file.write(html_content)
def save_dashboard(self):
diff -r 02f7746e7772 -r 4aa511539199 feature_importance.py
--- a/feature_importance.py Wed Jan 01 03:19:40 2025 +0000
+++ b/feature_importance.py Tue Jun 03 19:31:16 2025 +0000
@@ -3,9 +3,7 @@
import os
import matplotlib.pyplot as plt
-
import pandas as pd
-
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
diff -r 02f7746e7772 -r 4aa511539199 pycaret_classification.py
--- a/pycaret_classification.py Wed Jan 01 03:19:40 2025 +0000
+++ b/pycaret_classification.py Tue Jun 03 19:31:16 2025 +0000
@@ -1,11 +1,8 @@
import logging
from base_model_trainer import BaseModelTrainer
-
from dashboard import generate_classifier_explainer_dashboard
-
from pycaret.classification import ClassificationExperiment
-
from utils import add_hr_to_html, add_plot_to_html, predict_proba
LOG = logging.getLogger(__name__)
@@ -64,8 +61,7 @@
'macro': False,
'per_class': False,
'binary': True
- }
- )
+ })
self.plots[plot_name] = plot_path
continue
diff -r 02f7746e7772 -r 4aa511539199 pycaret_macros.xml
--- a/pycaret_macros.xml Wed Jan 01 03:19:40 2025 +0000
+++ b/pycaret_macros.xml Tue Jun 03 19:31:16 2025 +0000
@@ -1,6 +1,6 @@
<macros>
<token name="@PYCARET_VERSION@">3.3.2</token>
- <token name="@SUFFIX@">0</token>
+ <token name="@SUFFIX@">1</token>
<token name="@VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
<token name="@PROFILE@">21.05</token>
<xml name="python_requirements">
diff -r 02f7746e7772 -r 4aa511539199 pycaret_predict.py
--- a/pycaret_predict.py Wed Jan 01 03:19:40 2025 +0000
+++ b/pycaret_predict.py Tue Jun 03 19:31:16 2025 +0000
@@ -3,16 +3,11 @@
import tempfile
import h5py
-
import joblib
-
import pandas as pd
-
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
-
from sklearn.metrics import average_precision_score
-
from utils import encode_image_to_base64, get_html_closing, get_html_template
LOG = logging.getLogger(__name__)
@@ -49,7 +44,7 @@
exp = ClassificationExperiment()
names = data.columns.to_list()
LOG.error(f"Column names: {names}")
- target_index = int(self.target)-1
+ target_index = int(self.target) - 1
target_name = names[target_index]
exp.setup(data, target=target_name, test_data=data, index=False)
exp.add_metric(id='PR-AUC-Weighted',
@@ -73,8 +68,7 @@
'micro': False,
'macro': False,
'per_class': False,
- 'binary': True
- })
+ 'binary': True})
plot_paths[plot_name] = plot_path
continue
@@ -101,7 +95,7 @@
data = pd.read_csv(data_path, engine='python', sep=None)
if self.target:
names = data.columns.to_list()
- target_index = int(self.target)-1
+ target_index = int(self.target) - 1
target_name = names[target_index]
exp = RegressionExperiment()
exp.setup(data, target=target_name, test_data=data, index=False)
diff -r 02f7746e7772 -r 4aa511539199 pycaret_regression.py
--- a/pycaret_regression.py Wed Jan 01 03:19:40 2025 +0000
+++ b/pycaret_regression.py Tue Jun 03 19:31:16 2025 +0000
@@ -1,11 +1,8 @@
import logging
from base_model_trainer import BaseModelTrainer
-
from dashboard import generate_regression_explainer_dashboard
-
from pycaret.regression import RegressionExperiment
-
from utils import add_hr_to_html, add_plot_to_html
LOG = logging.getLogger(__name__)
diff -r 02f7746e7772 -r 4aa511539199 pycaret_train.py
--- a/pycaret_train.py Wed Jan 01 03:19:40 2025 +0000
+++ b/pycaret_train.py Tue Jun 03 19:31:16 2025 +0000
@@ -2,7 +2,6 @@
import logging
from pycaret_classification import ClassificationModelTrainer
-
from pycaret_regression import RegressionModelTrainer
logging.basicConfig(level=logging.DEBUG)
diff -r 02f7746e7772 -r 4aa511539199 pycaret_train.xml
--- a/pycaret_train.xml Wed Jan 01 03:19:40 2025 +0000
+++ b/pycaret_train.xml Tue Jun 03 19:31:16 2025 +0000
@@ -1,5 +1,5 @@
-<tool id="pycaret_compare" name="PyCaret Model Comparison" version="@VERSION@" profile="@PROFILE@">
- <description>compares different machine learning models on a dataset using PyCaret. Do feature analyses using Random Forest and LightGBM. </description>
+<tool id="pycaret_compare" name="Tabular Learner" version="@VERSION@" profile="@PROFILE@">
+ <description>applies and evaluates multiple machine learning models on a tabular dataset</description>
<macros>
<import>pycaret_macros.xml</import>
</macros>
@@ -53,12 +53,12 @@
]]>
</command>
<inputs>
- <param name="input_file" type="data" format="csv,tabular" label="Train Dataset (CSV or TSV)" />
- <param name="test_file" type="data" format="csv,tabular" optional="true" label="Test Dataset (CSV or TSV)"
- help="If a test set is not provided,
- the selected training set will be split into training, validation, and test sets.
- If a test set is provided, the training set will only be split into training and validation sets.
- BTW, cross-validation is always applied by default." />
+ <param name="input_file" type="data" format="csv,tabular" label="Tabular Input Dataset" />
+ <param name="test_file" type="data" format="csv,tabular" optional="true" label="Tabular Test Dataset"
+ help="If a test dataset is not provided,
+ the input dataset will be split into training, validation, and test sets.
+ If a test set is provided, the input dataset will be split into training and validation sets.
+ Cross-validation is applied by default during training." />
<param name="target_feature" multiple="false" type="data_column" use_header_names="true" data_ref="input_file" label="Select the target column:" />
<conditional name="model_selection">
<param name="model_type" type="select" label="Task">
@@ -124,25 +124,25 @@
<option value="true">Yes</option>
</param>
<when value="true">
- <param name="train_size" type="float" value="0.7" min="0.1" max="0.9" label="Train Size" help="Proportion of the dataset to include in the train split." />
+ <param name="train_size" type="float" value="0.7" min="0.1" max="0.9" label="Train Size" help="Proportion of the input dataset to include in the train split." />
<param name="normalize" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Normalize Data" help="Whether to normalize data before training." />
<param name="feature_selection" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Feature Selection" help="Whether to perform feature selection." />
<conditional name="cross_validation">
- <param name="enable_cross_validation" type="select" label="Enable Cross Validation?" help="Select whether to enable cross-validation. Default: Yes" >
+ <param name="enable_cross_validation" type="select" label="Enable Cross Validation?" help="Select whether to enable cross-validation." >
<option value="false" >No</option>
<option value="true" selected="true">Yes</option>
</param>
<when value="true">
- <param name="cross_validation_folds" type="integer" value="10" min="2" max="20" label="Cross Validation Folds" help="Number of folds to use for cross-validation. Default: 10" />
+ <param name="cross_validation_folds" type="integer" value="10" min="2" max="20" label="Cross Validation Folds" help="Number of folds to use for cross-validation." />
</when>
<when value="false">
<!-- No additional parameters to show if the user selects 'No' -->
</when>
</conditional>
- <param name="remove_outliers" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Outliers" help="Whether to remove outliers from the dataset before training. Default: False" />
- <param name="remove_multicollinearity" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Multicollinearity" help="Whether to remove multicollinear features before training. Default: False" />
- <param name="polynomial_features" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Polynomial Features" help="Whether to create polynomial features before training. Default: False" />
- <param name="fix_imbalance" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Fix Imbalance" help="ONLY for classfication! Whether to use SMOTE or similar methods to fix imbalance in the dataset. Default: False" />
+ <param name="remove_outliers" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Outliers" help="Whether to remove outliers from the input dataset before training." />
+ <param name="remove_multicollinearity" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Multicollinearity" help="Whether to remove multicollinear features before training." />
+ <param name="polynomial_features" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Polynomial Features" help="Whether to create polynomial features before training." />
+                        <param name="fix_imbalance" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Fix Imbalance" help="ONLY for classification! Whether to use SMOTE or similar methods to fix imbalance in the input dataset." />
</when>
<when value="false">
<!-- No additional parameters to show if the user selects 'No' -->
@@ -150,9 +150,9 @@
</conditional>
</inputs>
<outputs>
+ <data name="comparison_result" format="html" from_work_dir="comparison_result.html" label="${tool.name} analysis report on ${on_string}"/>
<data name="model" format="h5" from_work_dir="pycaret_model.h5" label="${tool.name} best model on ${on_string}" />
- <data name="comparison_result" format="html" from_work_dir="comparison_result.html" label="${tool.name} Comparison result on ${on_string}"/>
- <data name="best_model_csv" format="csv" from_work_dir="best_model.csv" label="${tool.name} The prams of the best model on ${on_string}" hidden="true" />
+ <data name="best_model_csv" format="csv" from_work_dir="best_model.csv" label="${tool.name} The parameters of the best model on ${on_string}" hidden="true" />
</outputs>
<tests>
<test>
diff -r 02f7746e7772 -r 4aa511539199 utils.py
--- a/utils.py Wed Jan 01 03:19:40 2025 +0000
+++ b/utils.py Tue Jun 03 19:31:16 2025 +0000
@@ -161,4 +161,4 @@
def predict_proba(self, X):
pred = self.predict(X)
- return np.array([1-pred, pred]).T
+ return np.array([1 - pred, pred]).T