Mercurial > repos > goeckslab > multimodal_learner
changeset 0:375c36923da1 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam.git commit 1c6c1ad7a1b2bd3645aa0eafa2167784820b52e0
| author | goeckslab |
|---|---|
| date | Tue, 09 Dec 2025 23:49:47 +0000 |
| parents | |
| children | |
| files | Dockerfile LICENSE README.md feature_help_modal.py feature_importance.py metrics_logic.py multimodal_learner.py multimodal_learner.xml plot_logic.py report_utils.py split_logic.py test-data/images.zip test-data/sample_output.html test-data/test.csv test-data/train.csv test_pipeline.py training_pipeline.py utils.py |
| diffstat | 18 files changed, 6826 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Dockerfile Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,89 @@ +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive +ENV NVIDIA_VISIBLE_DEVICES=all +ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility + +# Install system dependencies, Python, and Chromium bits needed for Kaleido +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + python3-venv \ + ca-certificates \ + build-essential \ + gnupg \ + libblas-dev \ + liblapack-dev \ + libgomp1 \ + libopenblas-dev \ + unzip \ + pkg-config \ + libfreetype6-dev \ + libpng-dev \ + libqhull-dev \ + fonts-liberation \ + libasound2 \ + libatk-bridge2.0-0 \ + libatk1.0-0 \ + libcairo2 \ + libcups2 \ + libdbus-1-3 \ + libexpat1 \ + libfontconfig1 \ + libgbm1 \ + libgcc1 \ + libglib2.0-0 \ + libgtk-3-0 \ + libnspr4 \ + libnss3 \ + libpango-1.0-0 \ + libpangocairo-1.0-0 \ + libstdc++6 \ + libx11-6 \ + libx11-xcb1 \ + libxcb1 \ + libxcomposite1 \ + libxcursor1 \ + libxdamage1 \ + libxext6 \ + libxfixes3 \ + libxi6 \ + libxrandr2 \ + libxrender1 \ + libxtst6 \ + lsb-release \ + wget \ + xdg-utils \ + && ln -sf /usr/bin/python3 /usr/bin/python \ + && ln -sf /usr/bin/pip3 /usr/bin/pip \ + && rm -rf /var/lib/apt/lists/* + +# Pin setuptools <81 to avoid pkg_resources warnings and upgrade pip/wheel +RUN pip install --no-cache-dir 'setuptools<81.0.0' && \ + pip install --no-cache-dir --upgrade pip wheel + +# Install GPU-enabled PyTorch stack before AutoGluon so it picks up CUDA +RUN pip install --no-cache-dir \ + --index-url https://download.pytorch.org/whl/cu118 \ + --extra-index-url https://pypi.org/simple \ + torch==2.1.2 \ + torchvision==0.16.2 \ + torchaudio==2.1.2 + +# Core Python dependencies +RUN pip install --no-cache-dir \ + autogluon==1.4.0 \ + pyarrow==20.0.0 \ + matplotlib \ + seaborn \ + shap + +# Update kaleido for plotting support +RUN pip install --no-cache-dir --upgrade kaleido + 
+RUN apt-get update && apt-get install -y curl \ + && curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \ + && apt-get install -y nodejs \ + && npm install -g yarn \ + && rm -rf /var/lib/apt/lists/*
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/LICENSE Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. 
+ + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. 
+ + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. 
Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. 
If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>.
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/README.md Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,2 @@ +# Galaxy-AutoGluon +This repository provides tools to integrate AutoGluon, an automated machine learning framework, into the Galaxy workflow environment.
"""HTML helpers for the report's "metrics help" modal dialog."""

import base64

# Markup for the modal dialog itself (hidden until revealed by the script below).
_MODAL_MARKUP = """
<div id="metricsHelpModal" class="modal">
  <div class="modal-content">
    <span class="close">×</span>
    <h2>How to read this Multimodal Learner report</h2>
    <div class="metrics-guide">
      <h3>Tabs & layout</h3>
      <p><strong>Model Metric Summary and Config:</strong> Top-level metrics and the key run settings (target column, backbones, presets).</p>
      <p><strong>Train and Validation Summary:</strong> Learning curves plus combined ROC/PR/Calibration (binary), and any remaining diagnostics.</p>
      <p><strong>Test Summary:</strong> Test metrics table followed by the ROC/PR charts with your chosen threshold marked, and the Prediction Confidence histogram.</p>

      <h3>Dataset Overview</h3>
      <p>Shows label counts across Train/Validation/Test so you can quickly spot imbalance or missing splits.</p>

      <h3>Learning curves</h3>
      <p><strong>Label Accuracy & Loss:</strong> Train (blue) and Validation (orange) trends. Parallel curves that plateau suggest stable training; large gaps can indicate overfitting.</p>

      <h3>Binary diagnostics (Train vs Validation)</h3>
      <p><strong>ROC Curve:</strong> Both splits on one plot. Higher and leftward is better. The red “x” marks the decision threshold when provided.</p>
      <p><strong>Precision–Recall:</strong> Both splits on one plot; more informative on imbalance. Red marker shows the threshold point.</p>
      <p><strong>Calibration:</strong> Ideally near the diagonal; deviations show over/under-confidence.</p>
      <p><strong>Threshold Plot (Validation):</strong> Explore precision/recall/F1 vs threshold; use to pick a balanced operating point.</p>

      <h3>Test tab highlights</h3>
      <p><strong>Metrics table:</strong> Thresholded metrics for the test set.</p>
      <p><strong>ROC & PR:</strong> Thick lines, red marker and annotation for the selected threshold.</p>
      <p><strong>Prediction Confidence:</strong> Histogram of max predicted probabilities (as % of samples) to spot over/under-confidence.</p>

      <h3>Threshold tips</h3>
      <ul>
        <li>Use the Validation curves to choose a threshold that balances precision/recall for your use case.</li>
        <li>Threshold marker/annotation appears on ROC/PR plots when you pass <code>--threshold</code> (binary tasks).</li>
      </ul>

      <h3>When to worry</h3>
      <ul>
        <li>Huge train/val gaps on learning curves → possible overfitting.</li>
        <li>Calibration far from diagonal → predicted probabilities may be poorly calibrated.</li>
        <li>Very imbalanced label counts → focus on PR curves and per-class metrics (if enabled).</li>
      </ul>
    </div>
  </div>
</div>
"""

# Styling that keeps the modal hidden by default and centers its content.
_MODAL_STYLE = """
<style>
.modal {
    display: none;
    position: fixed;
    z-index: 1;
    left: 0;
    top: 0;
    width: 100%;
    height: 100%;
    overflow: auto;
    background-color: rgba(0,0,0,0.4);
}
.modal-content {
    background-color: #fefefe;
    margin: 15% auto;
    padding: 20px;
    border: 1px solid #888;
    width: 80%;
    max-width: 800px;
}
.close {
    color: #aaa;
    float: right;
    font-size: 28px;
    font-weight: bold;
}
.close:hover,
.close:focus {
    color: black;
    text-decoration: none;
    cursor: pointer;
}
.metrics-guide h3 {
    margin-top: 20px;
}
.metrics-guide p {
    margin: 5px 0;
}
.metrics-guide ul {
    margin: 10px 0;
    padding-left: 20px;
}
</style>
"""

# Behavior: wires up the open button, the close "x", and click-outside-to-close.
_MODAL_SCRIPT = """
<script>
document.addEventListener("DOMContentLoaded", function() {
    var modal = document.getElementById("metricsHelpModal");
    var openBtn = document.getElementById("openMetricsHelp");
    var span = document.getElementsByClassName("close")[0];
    if (openBtn && modal) {
        openBtn.onclick = function() {
            modal.style.display = "block";
        };
    }
    if (span && modal) {
        span.onclick = function() {
            modal.style.display = "none";
        };
    }
    window.onclick = function(event) {
        if (event.target == modal) {
            modal.style.display = "none";
        }
    }
});
</script>
"""


def get_metrics_help_modal() -> str:
    """Return the metrics-help modal as one HTML snippet: CSS, then markup, then JS."""
    return "".join((_MODAL_STYLE, _MODAL_MARKUP, _MODAL_SCRIPT))


def encode_image_to_base64(image_path):
    """Read the file at ``image_path`` and return its contents as a base64 string."""
    with open(image_path, "rb") as handle:
        return base64.b64encode(handle.read()).decode("utf-8")


def generate_feature_importance(*args, **kwargs):
    """Placeholder kept for interface compatibility; all arguments are ignored."""
    return "<p><em>Feature importance visualizations are not supported for this MultiModal workflow.</em></p>"
"""Feature importance visualization utilities."""

import pandas as pd

# Static notice rendered in place of a real feature-importance section.
_NOT_SUPPORTED_HTML = (
    "<p><em>Feature importance visualization is not supported for the current "
    "MultiModal workflow.</em></p>"
)


def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str:
    """Feature importance is not currently available for the MultiModal workflow.

    All arguments are accepted for interface compatibility and ignored.
    """
    return _NOT_SUPPORTED_HTML
from collections import OrderedDict
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
    log_loss,
    matthews_corrcoef,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    precision_score,
    r2_score,
    recall_score,
    roc_auc_score,
)


# -------------------- Transparent Metrics (task-aware) -------------------- #

def _safe_y_proba_to_array(y_proba) -> Optional[np.ndarray]:
    """Convert predictor.predict_proba output (array/DataFrame/dict) to np.ndarray or None."""
    if y_proba is None:
        return None
    if isinstance(y_proba, pd.DataFrame):
        return y_proba.values
    if isinstance(y_proba, (list, tuple)):
        return np.asarray(y_proba)
    if isinstance(y_proba, np.ndarray):
        return y_proba
    if isinstance(y_proba, dict):
        # Order columns by sorted class key so they line up with np.sort of the labels.
        try:
            return np.vstack([np.asarray(v) for _, v in sorted(y_proba.items())]).T
        except Exception:
            return None
    return None


def _specificity_from_cm(cm: np.ndarray) -> float:
    """Specificity (TNR) for a binary confusion matrix; NaN if not 2x2 or undefined."""
    if cm.shape != (2, 2):
        return np.nan
    # sklearn's binary confusion matrix ravels as (tn, fp, fn, tp).
    tn, fp, fn, tp = cm.ravel()
    denom = (tn + fp)
    return float(tn / denom) if denom > 0 else np.nan


def _positive_class_scores(y_proba: np.ndarray, pos_label, predictor) -> np.ndarray:
    """Return the probability column corresponding to ``pos_label``.

    Falls back to the last column when the predictor does not expose
    ``class_labels`` or the label cannot be located; with sorted class labels
    the last column is the lexicographically larger class, which matches the
    pos_label convention used throughout this module.
    """
    if y_proba.ndim == 1:
        return y_proba
    pos_col_idx = -1
    try:
        if hasattr(predictor, "class_labels") and predictor.class_labels:
            pos_col_idx = list(predictor.class_labels).index(pos_label)
    except Exception:
        pos_col_idx = -1
    return y_proba[:, pos_col_idx]


def _compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> "OrderedDict[str, float]":
    """Standard regression metrics: MSE, RMSE, MAE, MAPE (%), R2, MedianAE."""
    mse = mean_squared_error(y_true, y_pred)
    rmse = float(np.sqrt(mse))
    mae = mean_absolute_error(y_true, y_pred)
    # Avoid division by zero using clip
    mape = float(np.mean(np.abs((y_true - y_pred) / np.clip(np.abs(y_true), 1e-12, None))) * 100.0)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    metrics = OrderedDict()
    metrics["MSE"] = mse
    metrics["RMSE"] = rmse
    metrics["MAE"] = mae
    metrics["MAPE_%"] = mape
    metrics["R2"] = r2
    metrics["MedianAE"] = medae
    return metrics


def _compute_binary_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray],
    predictor
) -> "OrderedDict[str, float]":
    """Thresholded and probabilistic binary metrics.

    The lexicographically larger class is treated as positive. Probabilistic
    metrics (ROC-AUC, PR-AUC, LogLoss) are NaN when ``y_proba`` is None.
    """
    metrics = OrderedDict()
    classes_sorted = np.sort(pd.unique(y_true))
    # Choose the lexicographically larger class as "positive"
    pos_label = classes_sorted[-1]

    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Precision"] = precision_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["Recall_(Sensitivity/TPR)"] = recall_score(y_true, y_pred, pos_label=pos_label, zero_division=0)
    metrics["F1-Score"] = f1_score(y_true, y_pred, pos_label=pos_label, zero_division=0)

    try:
        cm = confusion_matrix(y_true, y_pred, labels=classes_sorted)
        metrics["Specificity_(TNR)"] = _specificity_from_cm(cm)
    except Exception:
        metrics["Specificity_(TNR)"] = np.nan

    # Probabilistic metrics
    if y_proba is not None:
        pos_scores = _positive_class_scores(y_proba, pos_label, predictor)
        try:
            metrics["ROC-AUC"] = roc_auc_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["ROC-AUC"] = np.nan
        try:
            metrics["PR-AUC"] = average_precision_score(y_true == pos_label, pos_scores)
        except Exception:
            metrics["PR-AUC"] = np.nan
        try:
            # log_loss needs one column per class; expand 1-D positive scores.
            if y_proba.ndim == 1:
                y_proba_ll = np.column_stack([1 - pos_scores, pos_scores])
            else:
                y_proba_ll = y_proba
            metrics["LogLoss"] = log_loss(y_true, y_proba_ll, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
    else:
        metrics["ROC-AUC"] = np.nan
        metrics["PR-AUC"] = np.nan
        metrics["LogLoss"] = np.nan

    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    return metrics


def _compute_multiclass_metrics(
    y_true: pd.Series,
    y_pred: pd.Series,
    y_proba: Optional[np.ndarray]
) -> "OrderedDict[str, float]":
    """Macro/weighted classification metrics plus OVR ROC-AUC / PR-AUC."""
    metrics = OrderedDict()
    metrics["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Macro Precision"] = precision_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro Recall"] = recall_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Macro F1"] = f1_score(y_true, y_pred, average="macro", zero_division=0)
    metrics["Weighted Precision"] = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted Recall"] = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    metrics["Weighted F1"] = f1_score(y_true, y_pred, average="weighted", zero_division=0)

    try:
        metrics["Cohen_Kappa"] = cohen_kappa_score(y_true, y_pred)
    except Exception:
        metrics["Cohen_Kappa"] = np.nan
    try:
        metrics["MCC"] = matthews_corrcoef(y_true, y_pred)
    except Exception:
        metrics["MCC"] = np.nan

    # Probabilistic metrics
    classes_sorted = np.sort(pd.unique(y_true))
    if y_proba is not None and y_proba.ndim == 2:
        try:
            metrics["LogLoss"] = log_loss(y_true, y_proba, labels=classes_sorted)
        except Exception:
            metrics["LogLoss"] = np.nan
        # Macro ROC-AUC / PR-AUC via OVR
        try:
            class_to_index = {c: i for i, c in enumerate(classes_sorted)}
            y_true_idx = np.vectorize(class_to_index.get)(y_true)
            metrics["ROC-AUC_macro"] = roc_auc_score(y_true_idx, y_proba, multi_class="ovr", average="macro")
        except Exception:
            metrics["ROC-AUC_macro"] = np.nan
        try:
            # One-hot indicator matrix for average_precision_score.
            Y_true_ind = np.zeros_like(y_proba)
            idx_map = {c: i for i, c in enumerate(classes_sorted)}
            Y_true_ind[np.arange(y_proba.shape[0]), np.vectorize(idx_map.get)(y_true)] = 1
            metrics["PR-AUC_macro"] = average_precision_score(Y_true_ind, y_proba, average="macro")
        except Exception:
            metrics["PR-AUC_macro"] = np.nan
    else:
        metrics["LogLoss"] = np.nan
        metrics["ROC-AUC_macro"] = np.nan
        metrics["PR-AUC_macro"] = np.nan

    return metrics


def aggregate_metrics(list_of_dicts):
    """Aggregate a list of per-fold metrics dicts into (mean, std) per split.

    Non-numeric values cannot be averaged: for those the last observed value is
    kept as the "mean" and the std is None.
    """
    agg_mean = {}
    agg_std = {}
    for split in ("Train", "Validation", "Test", "Test (external)"):
        keys = set()
        for m in list_of_dicts:
            if isinstance(m, dict) and split in m:
                keys.update(m[split].keys())
        if not keys:
            continue
        agg_mean[split] = {}
        agg_std[split] = {}
        for k in keys:
            vals = [m[split][k] for m in list_of_dicts if split in m and k in m[split]]
            numeric_vals = []
            for v in vals:
                try:
                    numeric_vals.append(float(v))
                except Exception:
                    pass
            if numeric_vals:
                agg_mean[split][k] = float(np.mean(numeric_vals))
                agg_std[split][k] = float(np.std(numeric_vals, ddof=0))
            else:
                agg_mean[split][k] = vals[-1] if vals else None
                agg_std[split][k] = None
    return agg_mean, agg_std


def compute_metrics_for_split(
    predictor,
    df: pd.DataFrame,
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,
) -> "OrderedDict[str, float]":
    """Compute transparency metrics for one split (Train/Val/Test) based on task type.

    When ``problem_type`` is "binary" and both ``threshold`` and probabilities
    are available, labels are re-derived by thresholding the positive-class
    score; otherwise the predictor's default labels are used.
    """
    # Prepare inputs
    features = df.drop(columns=[target_col], errors="ignore")
    y_true_series = df[target_col].reset_index(drop=True)

    # Probabilities (if available)
    y_proba = None
    try:
        y_proba_raw = predictor.predict_proba(features)
        y_proba = _safe_y_proba_to_array(y_proba_raw)
    except Exception:
        y_proba = None

    # Labels (optionally thresholded for binary)
    if problem_type == "binary" and (threshold is not None) and (y_proba is not None):
        classes_sorted = np.sort(pd.unique(y_true_series))
        pos_label = classes_sorted[-1]
        neg_label = classes_sorted[0]
        pos_scores = _positive_class_scores(y_proba, pos_label, predictor)
        y_pred_series = pd.Series(np.where(pos_scores >= float(threshold), pos_label, neg_label)).reset_index(drop=True)
    else:
        # Fall back to model's default label prediction (argmax / 0.5 equivalent)
        y_pred_series = pd.Series(predictor.predict(features)).reset_index(drop=True)

    if problem_type == "regression":
        y_true_arr = np.asarray(y_true_series, dtype=float)
        y_pred_arr = np.asarray(y_pred_series, dtype=float)
        return _compute_regression_metrics(y_true_arr, y_pred_arr)

    if problem_type == "binary":
        return _compute_binary_metrics(y_true_series, y_pred_series, y_proba, predictor)

    # multiclass
    return _compute_multiclass_metrics(y_true_series, y_pred_series, y_proba)


def evaluate_all_transparency(
    predictor,
    train_df: Optional[pd.DataFrame],
    val_df: Optional[pd.DataFrame],
    test_df: Optional[pd.DataFrame],
    target_col: str,
    problem_type: str,
    threshold: Optional[float] = None,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, float]]]:
    """
    Evaluate Train/Val/Test with the transparent metrics suite.

    Returns:
        - metrics_table: DataFrame with index=Metric, columns subset of [Train, Validation, Test]
        - raw_dict: nested dict {split -> {metric -> value}}
    """
    split_results: Dict[str, Dict[str, float]] = {}
    splits = []

    # IMPORTANT: do NOT apply threshold to Train/Val
    if train_df is not None and len(train_df):
        split_results["Train"] = compute_metrics_for_split(predictor, train_df, target_col, problem_type, threshold=None)
        splits.append("Train")
    if val_df is not None and len(val_df):
        split_results["Validation"] = compute_metrics_for_split(predictor, val_df, target_col, problem_type, threshold=None)
        splits.append("Validation")
    if test_df is not None and len(test_df):
        split_results["Test"] = compute_metrics_for_split(predictor, test_df, target_col, problem_type, threshold=threshold)
        splits.append("Test")

    # Preserve order from the first split; include any extras from others
    order_source = split_results[splits[0]] if splits else {}
    all_metrics = list(order_source.keys())
    for s in splits[1:]:
        for m in split_results[s].keys():
            if m not in all_metrics:
                all_metrics.append(m)

    metrics_table = pd.DataFrame(index=all_metrics, columns=splits, dtype=float)
    for s in splits:
        for m, v in split_results[s].items():
            metrics_table.loc[m, s] = v

    return metrics_table, split_results
#!/usr/bin/env python
"""
Main entrypoint for AutoGluon multimodal training wrapper.
"""

import argparse
import logging
import os
import sys
from typing import List, Optional

import pandas as pd
from metrics_logic import aggregate_metrics
from plot_logic import infer_problem_type
from report_utils import write_outputs
from sklearn.model_selection import KFold, StratifiedKFold
from split_logic import split_dataset
from test_pipeline import run_autogluon_test_experiment
from training_pipeline import autogluon_hyperparameters, handle_missing_images, run_autogluon_experiment
# ------------------------------------------------------------------
# Local imports (your split utilities)
# ------------------------------------------------------------------
from utils import (
    absolute_path_expander,
    enable_deterministic_mode,
    enable_tensor_cores_if_available,
    ensure_local_tmp,
    load_file,
    prepare_image_search_dirs,
    set_seeds,
    str2bool,
)

# ------------------------------------------------------------------
# Logger setup
# ------------------------------------------------------------------
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Argument parsing
# ------------------------------------------------------------------
def parse_args(argv=None):
    """Parse and validate CLI arguments for the training wrapper.

    Returns the parsed Namespace; unknown tokens are logged and ignored so
    Galaxy can pass extra plumbing flags without breaking the run.
    """
    parser = argparse.ArgumentParser(description="Train & report an AutoGluon model")

    parser.add_argument("--input_csv_train", dest="train_dataset", required=True)
    parser.add_argument("--input_csv_test", dest="test_dataset", default=None)
    parser.add_argument("--target_column", required=True)
    parser.add_argument("--output_json", default="results.json")
    parser.add_argument("--output_html", default="report.html")
    parser.add_argument("--output_config", default=None)
    parser.add_argument("--images_zip", nargs="*", default=None,
                        help="One or more ZIP files that contain image assets")
    parser.add_argument("--missing_image_strategy", default="false",
                        help="true/false: remove rows with missing images or use placeholder")
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--time_limit", type=int, default=None)
    parser.add_argument("--deterministic", action="store_true", default=False,
                        help="Enable deterministic algorithms to reduce run-to-run variance")
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--cross_validation", type=str, default="false")
    parser.add_argument("--num_folds", type=int, default=5)
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--learning_rate", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--backbone_image", type=str, default="swin_base_patch4_window7_224")
    parser.add_argument("--backbone_text", type=str, default="microsoft/deberta-v3-base")
    parser.add_argument("--validation_size", type=float, default=0.2)
    parser.add_argument("--split_probabilities", type=float, nargs=3,
                        default=[0.7, 0.1, 0.2], metavar=("train", "val", "test"))
    parser.add_argument("--preset", choices=["medium_quality", "high_quality", "best_quality"],
                        default="medium_quality")
    parser.add_argument("--eval_metric", default="roc_auc")
    parser.add_argument("--hyperparameters", default=None)

    args, unknown = parser.parse_known_args(argv)
    if unknown:
        logger.warning("Ignoring unknown CLI tokens: %s", unknown)

    # -------------------------- Validation --------------------------
    if not (0.0 <= args.validation_size <= 1.0):
        parser.error("--validation_size must be in [0, 1]")
    if len(args.split_probabilities) != 3 or abs(sum(args.split_probabilities) - 1.0) > 1e-6:
        parser.error("--split_probabilities must be three numbers summing to 1.0")
    # Use str2bool here so validation agrees with how main() decides whether
    # CV actually runs (a plain string compare against "true" would let
    # truthy spellings like "1" bypass the num_folds check).
    if str2bool(args.cross_validation) and (args.num_folds < 2):
        parser.error("--num_folds must be >= 2 when --cross_validation is true")

    return args


def run_cross_validation(
    args,
    df_full: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    image_cols: List[str],
    ag_config: dict,
):
    """Cross-validation loop returning aggregated metrics and last predictor.

    Stratifies folds when the target looks categorical (object dtype or
    <= 20 unique values); each fold trains a fresh predictor and evaluates it,
    and the per-fold metric dicts are aggregated into mean/std at the end.
    """
    df_full = df_full.drop(columns=["split"], errors="ignore")
    y = df_full[args.target_column]
    try:
        use_stratified = y.dtype == object or y.nunique() <= 20
    except Exception:
        use_stratified = False

    kf = StratifiedKFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed)) if use_stratified else KFold(n_splits=int(args.num_folds), shuffle=True, random_state=int(args.random_seed))

    raw_folds = []
    ag_folds = []
    folds_info = []
    last_predictor = None
    last_data_ctx = None

    for fold_idx, (train_idx, val_idx) in enumerate(kf.split(df_full, y if use_stratified else None), start=1):
        logger.info(f"CV fold {fold_idx}/{args.num_folds}")
        df_tr = df_full.iloc[train_idx].copy()
        df_va = df_full.iloc[val_idx].copy()

        # Rebuild the single-dataframe 'split' convention the trainer expects.
        df_tr["split"] = "train"
        df_va["split"] = "val"
        fold_dataset = pd.concat([df_tr, df_va], ignore_index=True)

        predictor_fold, data_ctx = run_autogluon_experiment(
            train_dataset=fold_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        last_predictor = predictor_fold
        last_data_ctx = data_ctx
        problem_type = infer_problem_type(predictor_fold, df_tr, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor_fold,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )

        raw_metrics_fold = eval_results.get("raw_metrics", {})
        ag_by_split_fold = eval_results.get("ag_eval", {})
        raw_folds.append(raw_metrics_fold)
        ag_folds.append(ag_by_split_fold)
        folds_info.append(
            {
                "fold": int(fold_idx),
                "predictor_path": getattr(predictor_fold, "path", None),
                "raw_metrics": raw_metrics_fold,
                "ag_eval": ag_by_split_fold,
            }
        )

    raw_metrics_mean, raw_metrics_std = aggregate_metrics(raw_folds)
    ag_by_split_mean, ag_by_split_std = aggregate_metrics(ag_folds)
    return (
        last_predictor,
        raw_metrics_mean,
        ag_by_split_mean,
        raw_folds,
        ag_folds,
        raw_metrics_std,
        ag_by_split_std,
        folds_info,
        last_data_ctx,
    )


# ------------------------------------------------------------------
# Main execution
# ------------------------------------------------------------------
def main():
    """End-to-end driver: load data, preprocess, train (optionally with CV),
    evaluate, and write the JSON/HTML outputs."""
    args = parse_args()

    # ------------------------------------------------------------------
    # Debug output
    # ------------------------------------------------------------------
    logger.info("=== AutoGluon Training Wrapper Started ===")
    logger.info(f"Working directory: {os.getcwd()}")
    logger.info(f"Command line: {' '.join(sys.argv)}")
    logger.info(f"Parsed args: {vars(args)}")

    # ------------------------------------------------------------------
    # Reproducibility & performance
    # ------------------------------------------------------------------
    set_seeds(args.random_seed)
    if args.deterministic:
        enable_deterministic_mode(args.random_seed)
        logger.info("Deterministic mode enabled (seed=%s)", args.random_seed)
    ensure_local_tmp()
    enable_tensor_cores_if_available()

    # ------------------------------------------------------------------
    # Load datasets
    # ------------------------------------------------------------------
    train_dataset = load_file(args.train_dataset)
    test_dataset = load_file(args.test_dataset) if args.test_dataset else None

    logger.info(f"Train dataset loaded: {len(train_dataset)} rows")
    if test_dataset is not None:
        logger.info(f"Test dataset loaded: {len(test_dataset)} rows")

    # ------------------------------------------------------------------
    # Resolve target column by name; if Galaxy passed a numeric index,
    # translate it to the corresponding header so downstream checks pass.
    # Galaxy's data_column widget is 1-based.
    # ------------------------------------------------------------------
    if args.target_column not in train_dataset.columns and str(args.target_column).isdigit():
        idx = int(args.target_column) - 1
        if 0 <= idx < len(train_dataset.columns):
            resolved = train_dataset.columns[idx]
            logger.info(f"Target column '{args.target_column}' not found; using column #{idx + 1} header '{resolved}' instead.")
            args.target_column = resolved
        else:
            logger.error(f"Numeric target index '{args.target_column}' is out of range for dataset with {len(train_dataset.columns)} columns.")
            sys.exit(1)

    # ------------------------------------------------------------------
    # Image handling (ZIP extraction + absolute path expansion)
    # ------------------------------------------------------------------
    extracted_imgs_path = prepare_image_search_dirs(args)

    image_cols = absolute_path_expander(train_dataset, extracted_imgs_path, None)
    if test_dataset is not None:
        absolute_path_expander(test_dataset, extracted_imgs_path, image_cols)

    # ------------------------------------------------------------------
    # Handle missing images
    # ------------------------------------------------------------------
    train_dataset = handle_missing_images(
        train_dataset,
        image_columns=image_cols,
        strategy=args.missing_image_strategy,
    )
    if test_dataset is not None:
        test_dataset = handle_missing_images(
            test_dataset,
            image_columns=image_cols,
            strategy=args.missing_image_strategy,
        )

    logger.info(f"After cleanup → train: {len(train_dataset)}, test: {len(test_dataset) if test_dataset is not None else 0}")

    # ------------------------------------------------------------------
    # Dataset splitting logic (adds 'split' column to train_dataset)
    # ------------------------------------------------------------------
    split_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        target_column=args.target_column,
        split_probabilities=args.split_probabilities,
        validation_size=args.validation_size,
        random_seed=args.random_seed,
    )

    logger.info("Preprocessing complete — ready for AutoGluon training!")
    logger.info(f"Final split counts:\n{train_dataset['split'].value_counts().sort_index()}")

    # Verify target/image/text columns exist
    if args.target_column not in train_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in training data.")
        sys.exit(1)
    if test_dataset is not None and args.target_column not in test_dataset.columns:
        logger.error(f"Target column '{args.target_column}' not found in test data.")
        sys.exit(1)

    # Threshold is only meaningful for binary classification; ignore otherwise.
    threshold_for_run = args.threshold
    unique_labels = None
    target_looks_binary = False
    try:
        unique_labels = train_dataset[args.target_column].nunique(dropna=True)
        target_looks_binary = unique_labels == 2
    except Exception:
        logger.warning("Could not inspect target column '%s' for threshold validation; proceeding without binary check.", args.target_column)

    if threshold_for_run is not None:
        if target_looks_binary:
            threshold_for_run = float(threshold_for_run)
            logger.info("Applying custom decision threshold %.4f for binary evaluation.", threshold_for_run)
        else:
            logger.warning(
                "Threshold %.3f provided but target '%s' does not appear binary (unique labels=%s); ignoring threshold.",
                threshold_for_run,
                args.target_column,
                unique_labels if unique_labels is not None else "unknown",
            )
            threshold_for_run = None
    args.threshold = threshold_for_run
    # Image columns are auto-inferred; image_cols already resolved to absolute paths.

    # ------------------------------------------------------------------
    # Build AutoGluon configuration from CLI knobs
    # ------------------------------------------------------------------
    ag_config = autogluon_hyperparameters(
        threshold=args.threshold,
        time_limit=args.time_limit,
        random_seed=args.random_seed,
        epochs=args.epochs,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        backbone_image=args.backbone_image,
        backbone_text=args.backbone_text,
        preset=args.preset,
        eval_metric=args.eval_metric,
        hyperparameters=args.hyperparameters,
    )
    logger.info(f"AutoGluon config prepared: fit={ag_config.get('fit')}, hyperparameters keys={list(ag_config.get('hyperparameters', {}).keys())}")

    cv_enabled = str2bool(args.cross_validation)
    if cv_enabled:
        (
            predictor,
            raw_metrics,
            ag_by_split,
            raw_folds,
            ag_folds,
            raw_metrics_std,
            ag_by_split_std,
            folds_info,
            data_ctx,
        ) = run_cross_validation(
            args=args,
            df_full=train_dataset,
            test_dataset=test_dataset,
            image_cols=image_cols,
            ag_config=ag_config,
        )
        if predictor is None:
            logger.error("All CV folds failed. Exiting.")
            sys.exit(1)
        eval_results = {
            "raw_metrics": raw_metrics,
            "ag_eval": ag_by_split,
            "fit_summary": None,
        }
    else:
        predictor, data_ctx = run_autogluon_experiment(
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            target_column=args.target_column,
            image_columns=image_cols,
            ag_config=ag_config,
        )
        logger.info("AutoGluon training finished. Model path: %s", getattr(predictor, "path", None))

        # Evaluate predictor on Train/Val/Test splits
        problem_type = infer_problem_type(predictor, train_dataset, args.target_column)
        eval_results = run_autogluon_test_experiment(
            predictor=predictor,
            data_ctx=data_ctx,
            target_column=args.target_column,
            eval_metric=args.eval_metric,
            ag_config=ag_config,
            problem_type=problem_type,
        )
        # No per-fold data outside CV mode.
        raw_folds = ag_folds = raw_metrics_std = ag_by_split_std = None

    logger.info("Transparent metrics by split: %s", eval_results["raw_metrics"])
    logger.info("AutoGluon evaluate() by split: %s", eval_results["ag_eval"])

    if "problem_type" in eval_results:
        problem_type_final = eval_results["problem_type"]
    else:
        problem_type_final = infer_problem_type(predictor, train_dataset, args.target_column)

    write_outputs(
        args=args,
        predictor=predictor,
        problem_type=problem_type_final,
        eval_results=eval_results,
        data_ctx=data_ctx,
        raw_folds=raw_folds,
        ag_folds=ag_folds,
        raw_metrics_std=raw_metrics_std,
        ag_by_split_std=ag_by_split_std,
    )


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        datefmt="%H:%M:%S"
    )
    # Quiet noisy image parsing logs (e.g., PIL.PngImagePlugin debug streams)
    logging.getLogger("PIL").setLevel(logging.WARNING)
    logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
    main()
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/multimodal_learner.xml Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,316 @@ +<tool id="multimodal_learner" name="Multimodal Learner" version="0.1.0" profile="22.01"> + <description>Train and evaluate an AutoGluon Multimodal model (tabular + image + text)</description> + + <requirements> + <container type='docker'>quay.io/goeckslab/multimodal-learner:1.4.0</container> + </requirements> + + <required_files> + <include path="multimodal_learner.py"/> + <include path="utils.py"/> + <include path="split_logic.py"/> + <include path="training_pipeline.py"/> + <include path="test_pipeline.py"/> + <include path="metrics_logic.py"/> + <include path="plot_logic.py"/> + <include path="report_utils.py"/> + <include path="feature_help_modal.py"/> + </required_files> + + <stdio> + <exit_code range="137" level="fatal_oom" description="Out of Memory"/> + <exit_code range="1:" level="fatal" description="Tool failed — see Tool Standard Error"/> + </stdio> + + <command detect_errors="exit_code"><![CDATA[ +#import re + +#set $image_zip_paths = [] +#if $use_images_conditional.use_images == "yes" + #for $zip_file in $use_images_conditional.images_zip_repeat + #set $image_zip_paths = $image_zip_paths + [$zip_file.images_zip] + #end for +#end if +#if len($image_zip_paths) > 0 + #set $images_zip_cli = " ".join(["'%s'" % z for z in $image_zip_paths]) +#else + #set $images_zip_cli = None +#end if + +set -e; +ln -sf '$input_csv' 'train_input.csv'; +#if $test_dataset_conditional.has_test_dataset == "yes" +ln -sf '$test_dataset_conditional.input_test' 'test_input.csv'; +#end if + +python '$__tool_directory__/multimodal_learner.py' + --input_csv_train 'train_input.csv' + #if $test_dataset_conditional.has_test_dataset == "yes" + --input_csv_test 'test_input.csv' + #end if + --target_column '$target_column' + + #if $use_images_conditional.use_images == "yes" + #if $images_zip_cli + --images_zip $images_zip_cli + #end if + --missing_image_strategy 
'$use_images_conditional.missing_image_strategy' + #if $use_images_conditional.backbone_image + --backbone_image '$use_images_conditional.backbone_image' + #end if + #end if + + #if $backbone_text not in ("", None) + --backbone_text '$backbone_text' + #end if + + --preset '$preset' + --eval_metric '$eval_metric' + + --random_seed '$random_seed' + #if $time_limit + --time_limit $time_limit + #end if + #if $deterministic == "true" + --deterministic + #end if + + #if $customize_defaults_conditional.customize_defaults == "yes" + #if $customize_defaults_conditional.validation_size not in ("", None) + --validation_size $customize_defaults_conditional.validation_size + #end if + #if $customize_defaults_conditional.split_probabilities and str($customize_defaults_conditional.split_probabilities).strip() + --split_probabilities #echo " ".join([str(float(x)) for x in str($customize_defaults_conditional.split_probabilities).replace(",", " ").split() if x.strip()]) # + #end if + #if $customize_defaults_conditional.cross_validation == "true" + --cross_validation true + --num_folds $customize_defaults_conditional.num_folds + #end if + #if $customize_defaults_conditional.epochs + --epochs $customize_defaults_conditional.epochs + #end if + #if $customize_defaults_conditional.learning_rate + --learning_rate $customize_defaults_conditional.learning_rate + #end if + #if $customize_defaults_conditional.batch_size + --batch_size $customize_defaults_conditional.batch_size + #end if + #if $customize_defaults_conditional.threshold + --threshold $customize_defaults_conditional.threshold + #end if + #if $customize_defaults_conditional.hyperparameters + --hyperparameters '$customize_defaults_conditional.hyperparameters' + #end if + #end if + + --output_json '$output_json' + --output_html '$output_html' + --output_config '$output_config' +]]></command> + + <inputs> + <param name="input_csv" type="data" format="csv,tsv" label="Training dataset (CSV/TSV)" help="Must contain the target column and 
optional image paths"/> + <param name="target_column" type="data_column" data_ref="input_csv" numerical="false" use_header_names="true" label="Target / Label column"/> + + <conditional name="test_dataset_conditional"> + <param name="has_test_dataset" type="boolean" truevalue="yes" falsevalue="no" checked="false" label="Provide separate test dataset?"/> + <when value="yes"> + <param name="input_test" type="data" format="csv,tsv" optional="true" label="Test dataset (CSV/TSV)"/> + </when> + <when value="no"/> + </conditional> + + <param name="backbone_text" type="select" label="Text backbone" optional="true"> + <option value="microsoft/deberta-v3-base" selected="true">microsoft/deberta-v3-base</option> + <option value="microsoft/deberta-v3-small">microsoft/deberta-v3-small</option> + <option value="google/electra-base-discriminator">google/electra-base-discriminator</option> + <option value="google/electra-small-discriminator">google/electra-small-discriminator</option> + <option value="roberta-base">roberta-base</option> + <option value="bert-base-uncased">bert-base-uncased</option> + <option value="distilroberta-base">distilroberta-base</option> + <option value="albert-base-v2">albert-base-v2</option> + </param> + + <conditional name="use_images_conditional"> + <param name="use_images" type="boolean" truevalue="yes" falsevalue="no" checked="false" label="Use image modality?"/> + <when value="yes"> + <repeat name="images_zip_repeat" title="Image archive(s)" min="1"> + <param name="images_zip" type="data" format="zip" label="ZIP file containing images"/> + </repeat> + <param name="backbone_image" type="select" label="Image backbone" optional="true"> + <option value='swin_base_patch4_window7_224' selected='true'>swin_base_patch4_window7_224</option> + <option value='swin_large_patch4_window12_384.in22k_ft_in1k'>swin_large_patch4_window12_384.in22k_ft_in1k</option> + <option value='swin_small_patch4_window7_224'>swin_small_patch4_window7_224</option> + <option 
value='swin_tiny_patch4_window7_224'>swin_tiny_patch4_window7_224</option> + <option value='caformer_b36.in21k_ft_in1k'>caformer_b36.in21k_ft_in1k</option> + <option value='caformer_m36.in21k_ft_in1k'>caformer_m36.in21k_ft_in1k</option> + <option value='caformer_s36.in21k_ft_in1k'>caformer_s36.in21k_ft_in1k</option> + <option value='caformer_s18.in1k'>caformer_s18.in1k</option> + <option value='caformer_b36.sail_in22k_ft_in1k'>caformer_b36.sail_in22k_ft_in1k</option> + <option value='caformer_m36.sail_in22k_ft_in1k'>caformer_m36.sail_in22k_ft_in1k</option> + <option value='caformer_s36.sail_in22k_ft_in1k'>caformer_s36.sail_in22k_ft_in1k</option> + <option value='vit_base_patch16_224'>vit_base_patch16_224</option> + <option value='vit_large_patch14_224'>vit_large_patch14_224</option> + <option value='convnext_base'>convnext_base</option> + <option value='eva02_base_patch14_448.mim_in22k_ft_in22k_in1k'>eva02_base_patch14_448.mim_in22k_ft_in22k_in1k</option> + <option value='resnet50'>resnet50</option> + </param> + <param name="missing_image_strategy" type="boolean" truevalue="true" falsevalue="false" checked="false" + label="Drop rows with missing images?" 
help="True = drop, False = replace with placeholder (default)"/> + </when> + <when value="no"/> + </conditional> + + <param name="preset" type="select" label="Quality preset"> + <option value="medium_quality" selected="true">Medium quality (fast)</option> + <option value="high_quality">High quality</option> + <option value="best_quality">Best quality (slowest)</option> + </param> + + <param name="eval_metric" type="select" label="Primary evaluation metric"> + <option value="auto" selected="true">Auto (let AutoGluon choose)</option> + <option value="roc_auc">ROC AUC</option> + <option value="accuracy">Accuracy</option> + <option value="balanced_accuracy">Balanced Accuracy</option> + <option value="f1">F1</option> + <option value="f1_macro">F1 Macro</option> + <option value="f1_micro">F1 Micro</option> + <option value="f1_weighted">F1 Weighted</option> + <option value="precision">Precision</option> + <option value="precision_macro">Precision Macro</option> + <option value="precision_micro">Precision Micro</option> + <option value="precision_weighted">Precision Weighted</option> + <option value="recall">Recall</option> + <option value="recall_macro">Recall Macro</option> + <option value="recall_micro">Recall Micro</option> + <option value="recall_weighted">Recall Weighted</option> + <option value="average_precision">Average Precision</option> + <option value="roc_auc_ovo_macro">ROC AUC OVO Macro</option> + <option value="roc_auc_ovo_weighted">ROC AUC OVO Weighted</option> + <option value="roc_auc_ovr_macro">ROC AUC OVR Macro</option> + <option value="roc_auc_ovr_weighted">ROC AUC OVR Weighted</option> + <option value="log_loss">Log Loss</option> + <option value="mse">MSE</option> + <option value="rmse">RMSE</option> + <option value="mae">MAE</option> + <option value="msle">MSLE</option> + <option value="r2">R2</option> + </param> + + <param name="random_seed" type="integer" value="42" label="Random seed"/> + + <param name="time_limit" type="integer" optional="true" 
min="60" label="Time limit (seconds)" help="Total training time budget. Recommended: 3600+ for real runs"/> + <param name="deterministic" type="boolean" truevalue="true" falsevalue="false" checked="false" + label="Enable deterministic mode" help="Use deterministic algorithms and CuDNN settings to reduce run-to-run variance (may slow training)"/> + + <conditional name="customize_defaults_conditional"> + <param name="customize_defaults" type="boolean" truevalue="yes" falsevalue="no" checked="false" label="Advanced: customize training settings?"/> + <when value="yes"> + <param name="validation_size" type="float" value="0.2" label="Validation fraction (when test set provided)"/> + <param name="split_probabilities" type="text" value="0.7 0.1 0.2" label="Train / Val / Test split (space-separated) when no test set"/> + <param name="cross_validation" type="boolean" truevalue="true" falsevalue="false" checked="false" label="Enable k-fold cross-validation"/> + <param name="num_folds" type="integer" value="5" label="Number of CV folds"/> + <param name="epochs" type="integer" optional="true" label="Max epochs"/> + <param name="learning_rate" type="float" optional="true" label="Learning rate"/> + <param name="batch_size" type="integer" optional="true" label="Batch size"/> + <param name="threshold" type="float" optional="true" min="0" max="1" label="Binary classification threshold"/> + <param name="hyperparameters" type="text" optional="true" label="Extra AutoGluon hyperparameters (JSON or YAML string)"/> + </when> + <when value="no"/> + </conditional> + </inputs> + + <outputs> + <data name="output_html" format="html" label="Multimodal Learner analysis report on data ${input_csv.name}"/> + <data name="output_config" format="yaml" label="Multimodal Learner training config on data ${input_csv.name}"/> + <data name="output_json" format="json" label="Multimodal Learner metric results on data ${input_csv.name}"/> + </outputs> + + <tests> + <!-- Basic run with images + external test 
set --> + <test expect_num_outputs="3"> + <param name="input_csv" value="train.csv"/> + <param name="target_column" value="7"/> + <param name="test_dataset_conditional|has_test_dataset" value="yes"/> + <param name="test_dataset_conditional|input_test" value="test.csv"/> + <param name="use_images_conditional|use_images" value="yes"/> + <param name="use_images_conditional|images_zip_repeat_0|images_zip" value="images.zip"/> + <param name="use_images_conditional|backbone_image" value="resnet50"/> + <param name="backbone_text" value="google/electra-base-discriminator"/> + <output name="output_html"> + <assert_contents> + <has_text text="Model Performance Summary"/> + </assert_contents> + </output> + </test> + + <!-- Custom threshold --> + <test expect_num_outputs="3"> + <param name="input_csv" value="train.csv"/> + <param name="target_column" value="7"/> + <param name="test_dataset_conditional|has_test_dataset" value="yes"/> + <param name="test_dataset_conditional|input_test" value="test.csv"/> + <param name="use_images_conditional|use_images" value="yes"/> + <param name="use_images_conditional|images_zip_repeat_0|images_zip" value="images.zip"/> + <param name="customize_defaults_conditional|customize_defaults" value="yes"/> + <param name="customize_defaults_conditional|threshold" value="0.4"/> + <output name="output_json"> + <assert_contents> + <has_text text=""threshold": 0.4"/> + </assert_contents> + </output> + </test> + + <!-- No external test set; internal split --> + <test expect_num_outputs="3"> + <param name="input_csv" value="train.csv"/> + <param name="target_column" value="7"/> + <param name="test_dataset_conditional|has_test_dataset" value="no"/> + <param name="use_images_conditional|use_images" value="yes"/> + <param name="use_images_conditional|images_zip_repeat_0|images_zip" value="images.zip"/> + <output name="output_json"> + <assert_contents> + <has_text text=""val""/> + </assert_contents> + </output> + </test> + + <!-- Text/tabular only (ignore 
images) --> + <test expect_num_outputs="3"> + <param name="input_csv" value="train.csv"/> + <param name="target_column" value="7"/> + <param name="test_dataset_conditional|has_test_dataset" value="yes"/> + <param name="test_dataset_conditional|input_test" value="test.csv"/> + <param name="use_images_conditional|use_images" value="no"/> + <output name="output_html"> + <assert_contents> + <has_text text="Train and Validation Performance Summary"/> + </assert_contents> + </output> + </test> + </tests> + + <help><![CDATA[ +**AutoGluon Multimodal Learner** + +Trains a powerful multimodal model combining tabular features, images, and text using AutoGluon-Multimodal. + +- Handles missing images intelligently +- Supports cross-validation +- Produces detailed HTML reports and transparent metrics +- Fully reproducible + +Ideal for medical imaging + clinical data, product images + descriptions, etc. +]]></help> + + <citations> + <citation type="bibtex"> +@article{AutoGluon2023, + author = {Erickson, Nick and Mueller, Jonas and Wang, Yizhou and others}, + title = {AutoGluon-Tabular: Robust and Accurate AutoML for Structured Data}, + journal = {arXiv preprint arXiv:2003.06505}, + year = {2023} +} + </citation> + </citations> +</tool>
def plot_with_table_style_title(fig, title: str) -> str:
    """
    Render a Plotly figure preceded by a report-style <h2> heading.

    The <h2> matches the green table section headers elsewhere in the
    report, so the figure's own Plotly title is suppressed.

    Parameters
    ----------
    fig : plotly figure (anything exposing update_layout/to_html)
    title : heading text; HTML-escaped before embedding.

    Returns
    -------
    str : heading plus the figure HTML wrapped in a centering <div>.
    """
    # The <h2> heading replaces Plotly's built-in title.
    fig.update_layout(title=None)

    # Serialize the figure alone; PlotlyJS is loaded once globally.
    figure_html = fig.to_html(full_html=False, include_plotlyjs=False)

    heading = html.escape(title)
    return f'<h2>{heading}</h2>\n<div class="plotly-center">{figure_html}</div>'
+ """ + if not path: + return + ext = os.path.splitext(path)[1].lower() + if ext == ".html": + fig.write_html(path, include_plotlyjs="cdn", full_html=True) + else: + # Requires kaleido: pip install -U kaleido + fig.write_image(path) + + +def _save_matplotlib(path: Optional[str]) -> None: + """Save current Matplotlib figure if `path` is provided, else show().""" + if path: + plt.savefig(path, bbox_inches="tight") + plt.close() + else: + plt.show() + +# ========================= +# Classification Plots +# ========================= + + +def generate_confusion_matrix_plot( + y_true, + y_pred, + title: str = "Confusion Matrix", +) -> go.Figure: + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + + # Class order (works for strings or numbers) + labels = pd.Index(np.unique(np.concatenate([y_true, y_pred])), dtype=object).tolist() + cm = confusion_matrix(y_true, y_pred, labels=labels) + max_val = cm.max() if cm.size else 0 + + # Use categorical axes by passing string labels for x/y + cats = [str(label) for label in labels] + total = int(cm.sum()) + + fig = go.Figure( + data=go.Heatmap( + z=cm, + x=cats, # categorical x + y=cats, # categorical y + colorscale="Blues", + showscale=True, + colorbar=dict(title="Count"), + xgap=2, + ygap=2, + hovertemplate="True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>", + zmin=0 + ) + ) + + # Add annotations with count and percentage (all white text, matching sample_output.html) + annotations = [] + for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + val = int(cm[i, j]) + pct = (val / total * 100) if total > 0 else 0 + text_color = "white" if max_val and val > (max_val / 2) else "black" + # Count annotation (bold, bottom) + annotations.append( + dict( + x=cats[j], + y=cats[i], + text=f"<b>{val}</b>", + showarrow=False, + font=dict(color=text_color, size=14), + xanchor="center", + yanchor="bottom", + yshift=2 + ) + ) + # Percentage annotation (top) + annotations.append( + dict( + x=cats[j], + y=cats[i], + 
text=f"{pct:.1f}%", + showarrow=False, + font=dict(color=text_color, size=13), + xanchor="center", + yanchor="top", + yshift=-2 + ) + ) + + fig.update_layout( + title=None, + xaxis_title="Predicted label", + yaxis_title="True label", + xaxis=dict(type="category"), + yaxis=dict(type="category", autorange="reversed"), # typical CM orientation + margin=dict(l=80, r=20, t=40, b=80), + template="plotly_white", + plot_bgcolor="white", + paper_bgcolor="white", + annotations=annotations + ) + return fig + + +def generate_roc_curve_plot( + y_true_bin: np.ndarray, + y_score: np.ndarray, + title: str = "ROC Curve", + marker_threshold: float | None = None, +) -> go.Figure: + y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1) + y_score = np.asarray(y_score).astype(float).reshape(-1) + + fpr, tpr, thr = roc_curve(y_true_bin, y_score) + roc_auc = auc(fpr, tpr) + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=fpr, y=tpr, mode="lines", + name=f"ROC (AUC={roc_auc:.3f})", + line=dict(width=3) + )) + + # 45° chance line (no legend to keep it clean) + fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode="lines", + line=dict(dash="dash", width=2, color="#888"), showlegend=False)) + + # Optional marker at the user threshold + if marker_threshold is not None and len(thr): + # roc_curve returns thresholds of same length as fpr/tpr; includes inf at idx 0 + finite = np.isfinite(thr) + if np.any(finite): + idx_local = int(np.argmin(np.abs(thr[finite] - float(marker_threshold)))) + idx = int(np.nonzero(finite)[0][idx_local]) # map back to original indices + x_m, y_m = float(fpr[idx]), float(tpr[idx]) + + fig.add_trace( + go.Scatter( + x=[x_m], y=[y_m], + mode="markers", + name=f"@ {float(marker_threshold):.2f}", + marker=dict(size=12, color="red", symbol="x") + ) + ) + fig.add_annotation( + x=0.02, y=0.98, xref="paper", yref="paper", + text=f"threshold = {float(marker_threshold):.2f}", + showarrow=False, + font=dict(color="black", size=12), + align="left" + ) + + fig.update_layout( 
def generate_pr_curve_plot(
    y_true_bin: np.ndarray,
    y_score: np.ndarray,
    title: str = "Precision–Recall Curve",
    marker_threshold: float | None = None,
) -> go.Figure:
    """
    Binary precision–recall curve (Plotly), optionally marking the PR
    point closest to a user-chosen decision threshold.

    Parameters
    ----------
    y_true_bin : binary ground-truth labels (coerced to int, flattened).
    y_score : positive-class scores/probabilities (coerced to float).
    title : unused — layout title is forced to None so the report's
        section header supplies the heading.
    marker_threshold : if given, a red "x" marker is placed at the PR
        point for the threshold nearest this value.

    Returns
    -------
    go.Figure
    """
    y_true_bin = np.asarray(y_true_bin).astype(int).reshape(-1)
    y_score = np.asarray(y_score).astype(float).reshape(-1)

    precision, recall, thr = precision_recall_curve(y_true_bin, y_score)
    pr_auc = auc(recall, precision)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=recall, y=precision, mode="lines",
        name=f"PR (AUC={pr_auc:.3f})",
        line=dict(width=3)
    ))

    # Optional marker at the user threshold
    if marker_threshold is not None and len(thr):
        # sklearn aligns precision[i]/recall[i] with thresholds[i]
        # (the extra final element is the (recall=0, precision=1) endpoint),
        # so the point for thr[j] is (recall[j], precision[j]).
        # Using j + 1 here would mark the point of the *next* threshold.
        j = int(np.argmin(np.abs(thr - float(marker_threshold))))
        j = int(np.clip(j, 0, len(thr) - 1))
        x_m, y_m = float(recall[j]), float(precision[j])

        fig.add_trace(
            go.Scatter(
                x=[x_m], y=[y_m],
                mode="markers",
                name=f"@ {float(marker_threshold):.2f}",
                marker=dict(size=12, color="red", symbol="x")
            )
        )
        fig.add_annotation(
            x=0.02, y=0.98, xref="paper", yref="paper",
            text=f"threshold = {float(marker_threshold):.2f}",
            showarrow=False,
            font=dict(color="black", size=12),
            align="left"
        )

    fig.update_layout(
        title=None,
        xaxis_title="Recall",
        yaxis_title="Precision",
        template="plotly_white",
        legend=dict(x=1, y=0, xanchor="right"),
        margin=dict(l=60, r=20, t=60, b=60),
    )
    return fig
+ """ + prob_true, prob_pred = calibration_curve(y_true_bin, y_prob, n_bins=n_bins, strategy="uniform") + fig = go.Figure() + fig.add_trace(go.Scatter( + x=prob_pred, y=prob_true, mode="lines+markers", name="Model", + line=dict(color="#1f77b4", width=3), marker=dict(size=7, color="#1f77b4") + )) + fig.add_trace( + go.Scatter( + x=[0, 1], y=[0, 1], + mode="lines", + line=dict(dash="dash", color="#808080", width=2), + name="Perfect" + ) + ) + fig.update_layout( + title=None, + xaxis_title="Predicted Probability", + yaxis_title="Observed Probability", + yaxis=dict(range=[0, 1]), + xaxis=dict(range=[0, 1]), + template="plotly_white", + margin=dict(l=60, r=40, t=50, b=50), + ) + _save_plotly(fig, path) + return fig + + +def generate_threshold_plot( + y_true_bin: np.ndarray, + y_prob: np.ndarray, + title: str = "Threshold Plot", + user_threshold: float | None = None, +) -> go.Figure: + y_true = np.asarray(y_true_bin, dtype=int).ravel() + p = np.asarray(y_prob, dtype=float).ravel() + p = np.nan_to_num(p, nan=0.0) + p = np.clip(p, 0.0, 1.0) + + def _compute_metrics(thresholds: np.ndarray): + """Vectorized-ish helper to compute precision/recall/F1/queue rate arrays.""" + prec, rec, f1, qrate = [], [], [], [] + for t in thresholds: + yhat = (p >= t).astype(int) + tp = int(((yhat == 1) & (y_true == 1)).sum()) + fp = int(((yhat == 1) & (y_true == 0)).sum()) + fn = int(((yhat == 0) & (y_true == 1)).sum()) + + pr = tp / (tp + fp) if (tp + fp) else np.nan # undefined when no predicted positives + rc = tp / (tp + fn) if (tp + fn) else 0.0 + f = (2 * pr * rc) / (pr + rc) if (pr + rc) and not np.isnan(pr) else 0.0 + q = float(yhat.mean()) + + prec.append(pr) + rec.append(rc) + f1.append(f) + qrate.append(q) + return np.asarray(prec, dtype=float), np.asarray(rec, dtype=float), np.asarray(f1, dtype=float), np.asarray(qrate, dtype=float) + + # Use uniform threshold grid for plotting (0 to 1 in steps of 0.01) + th = np.linspace(0.0, 1.0, 101) + prec, rec, f1_arr, qrate = 
_compute_metrics(th) + + # Compute F1*-optimal threshold using actual score distribution (more precise than grid) + cand_th = np.unique(np.concatenate(([0.0, 1.0], p))) + # cap to a reasonable size by sampling if extremely large + if cand_th.size > 2000: + cand_th = np.linspace(0.0, 1.0, 2001) + _, _, f1_cand, _ = _compute_metrics(cand_th) + + if np.all(np.isnan(f1_cand)): + t_star = 0.5 # fallback when no valid F1 can be computed + else: + f1_max = np.nanmax(f1_cand) + best_idxs = np.where(np.isclose(f1_cand, f1_max, equal_nan=False))[0] + # pick the middle of the best candidates to avoid biasing toward 0 + best_idx = int(best_idxs[len(best_idxs) // 2]) + t_star = float(cand_th[best_idx]) + + # Replace NaNs for plotting (set to 0 where precision is undefined) + prec_plot = np.nan_to_num(prec, nan=0.0) + + fig = go.Figure() + + # Precision (blue line) + fig.add_trace(go.Scatter( + x=th, y=prec_plot, mode="lines", name="Precision", + line=dict(width=3, color="#1f77b4"), + hovertemplate="Threshold=%{x:.3f}<br>Precision=%{y:.3f}<extra></extra>" + )) + + # Recall (orange line) + fig.add_trace(go.Scatter( + x=th, y=rec, mode="lines", name="Recall", + line=dict(width=3, color="#ff7f0e"), + hovertemplate="Threshold=%{x:.3f}<br>Recall=%{y:.3f}<extra></extra>" + )) + + # F1 (green line) + fig.add_trace(go.Scatter( + x=th, y=f1_arr, mode="lines", name="F1", + line=dict(width=3, color="#2ca02c"), + hovertemplate="Threshold=%{x:.3f}<br>F1=%{y:.3f}<extra></extra>" + )) + + # Queue Rate (grey dashed line) + fig.add_trace(go.Scatter( + x=th, y=qrate, mode="lines", name="Queue Rate", + line=dict(width=2, color="#808080", dash="dash"), + hovertemplate="Threshold=%{x:.3f}<br>Queue Rate=%{y:.3f}<extra></extra>" + )) + + # F1*-optimal threshold marker (dashed vertical line) + fig.add_vline( + x=t_star, + line_width=2, + line_dash="dash", + line_color="black", + annotation_text=f"t* = {t_star:.2f}", + annotation_position="top" + ) + + # User threshold (solid red line) if provided + if 
user_threshold is not None: + fig.add_vline( + x=float(user_threshold), + line_width=2, + line_color="red", + annotation_text=f"threshold = {float(user_threshold):.2f}", + annotation_position="top" + ) + + fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict( + title="Discrimination Threshold", + range=[0, 1], + gridcolor="#e0e0e0", + showgrid=True, + zeroline=False + ), + yaxis=dict( + title="Score", + range=[0, 1], + gridcolor="#e0e0e0", + showgrid=True, + zeroline=False + ), + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="right", + x=1.0 + ), + margin=dict(l=60, r=20, t=40, b=50), + plot_bgcolor="white", + paper_bgcolor="white", + ) + return fig + + +def generate_per_class_metrics_plot( + y_true: Sequence, + y_pred: Sequence, + metrics: Sequence[str] = ("precision", "recall", "f1_score"), + title: str = "Classification Report", + path: Optional[str] = None, +) -> go.Figure: + """ + Per-class metrics heatmap (Plotly), similar to sklearn's classification report. + Rows = classes, columns = metrics; cell text shows the value (0–1). 
+ """ + # Map display names -> sklearn keys + key_map = {"f1_score": "f1-score", "precision": "precision", "recall": "recall"} + report = classification_report( + y_true, y_pred, output_dict=True, zero_division=0 + ) + + # Order classes sensibly (numeric if possible, else lexical) + def _sort_key(x): + try: + return (0, float(x)) + except Exception: + return (1, str(x)) + + # Use all classes seen in y_true or y_pred (so rows don't jump around) + uniq = sorted(set(list(y_true) + list(y_pred)), key=_sort_key) + classes = [str(c) for c in uniq] + + # Build Z matrix (rows=classes, cols=metrics) + used_metrics = [key_map.get(m, m) for m in metrics] + z = [] + for c in classes: + row = report.get(c, {}) + z.append([float(row.get(m, 0.0) or 0.0) for m in used_metrics]) + z = np.array(z, dtype=float) + + # Pretty cell labels + z_text = [[f"{v:.2f}" for v in r] for r in z] + + fig = go.Figure( + data=go.Heatmap( + z=z, + x=list(metrics), # keep display names ("precision", "recall", "f1_score") + y=classes, # classes as strings + colorscale="Reds", + zmin=0.0, + zmax=1.0, + colorbar=dict(title="Value"), + text=z_text, + texttemplate="%{text}", + hovertemplate="Class %{y}<br>%{x}: %{z:.2f}<extra></extra>", + ) + ) + fig.update_layout( + title=None, + xaxis_title="", + yaxis_title="Class", + template="plotly_white", + margin=dict(l=60, r=60, t=70, b=40), + ) + + _save_plotly(fig, path) + return fig + + +def generate_multiclass_roc_curve_plot( + y_true: Sequence, + y_prob: np.ndarray, + classes: Sequence, + title: str = "Multiclass ROC Curve", + path: Optional[str] = None, +) -> go.Figure: + """ + One-vs-rest ROC curves for multiclass (Plotly). + Handles binary passed as 2-column probs as well. 
+ """ + y_true = np.asarray(y_true) + y_prob = np.asarray(y_prob) + + # Normalize to shape (n_samples, n_classes) + if y_prob.ndim == 1 or y_prob.shape[1] == 1: + y_prob = np.hstack([1 - y_prob.reshape(-1, 1), y_prob.reshape(-1, 1)]) + + y_true_bin = label_binarize(y_true, classes=classes) + if y_true_bin.shape[1] == 1 and y_prob.shape[1] == 2: + y_true_bin = np.hstack([1 - y_true_bin, y_true_bin]) + + if y_prob.shape[1] != y_true_bin.shape[1]: + raise ValueError( + f"Shape mismatch: y_prob has {y_prob.shape[1]} columns but y_true_bin has {y_true_bin.shape[1]}." + ) + + fig = go.Figure() + for i, cls in enumerate(classes): + fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_prob[:, i]) + auc_val = roc_auc_score(y_true_bin[:, i], y_prob[:, i]) + fig.add_trace(go.Scatter(x=fpr, y=tpr, mode="lines", name=f"{cls} (AUC {auc_val:.2f})")) + + fig.add_trace( + go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash"), showlegend=False) + ) + fig.update_layout( + title=None, + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + + +def generate_multiclass_pr_curve_plot( + y_true: Sequence, + y_prob: np.ndarray, + classes: Optional[Sequence] = None, + title: str = "Precision–Recall Curve", + path: Optional[str] = None, +) -> go.Figure: + """ + Multiclass PR curves (Plotly). If classes is None or len==2, shows binary PR. 
+ """ + y_true = np.asarray(y_true) + y_prob = np.asarray(y_prob) + fig = go.Figure() + + if not classes or len(classes) == 2: + precision, recall, _ = precision_recall_curve(y_true, y_prob[:, 1]) + ap = average_precision_score(y_true, y_prob[:, 1]) + fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"AP = {ap:.2f}")) + else: + for i, cls in enumerate(classes): + y_true_bin = (y_true == cls).astype(int) + y_prob_cls = y_prob[:, i] + precision, recall, _ = precision_recall_curve(y_true_bin, y_prob_cls) + ap = average_precision_score(y_true_bin, y_prob_cls) + fig.add_trace(go.Scatter(x=recall, y=precision, mode="lines", name=f"{cls} (AP {ap:.2f})")) + + fig.update_layout( + title=None, + xaxis_title="Recall", + yaxis_title="Precision", + yaxis=dict(range=[0, 1]), + xaxis=dict(range=[0, 1]), + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + + +def generate_metric_comparison_bar( + metrics_scores: Mapping[str, Sequence[float]], + phases: Sequence[str] = ("train", "val", "test"), + title: str = "Metric Comparison Across Phases", + path: Optional[str] = None, +) -> go.Figure: + """ + Grouped bar chart comparing metrics across phases (Plotly). + metrics_scores: {metric_name: [train, val, test]} + """ + df = pd.DataFrame(metrics_scores, index=phases).T.reset_index().rename(columns={"index": "Metric"}) + df_m = df.melt(id_vars="Metric", var_name="Phase", value_name="Score") + fig = px.bar(df_m, x="Metric", y="Score", color="Phase", barmode="group", title=None) + ymax = max(1.0, df_m["Score"].max() * 1.05) + fig.update_yaxes(range=[0, ymax]) + fig.update_layout(template="plotly_white") + _save_plotly(fig, path) + return fig + +# ========================= +# Regression Plots +# ========================= + + +def generate_scatter_plot( + y_true: Sequence[float], + y_pred: Sequence[float], + title: str = "Predicted vs Actual", + path: Optional[str] = None, +) -> go.Figure: + """ + Predicted vs. Actual scatter with y=x reference (Plotly). 
+ """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + vmin = float(min(np.min(y_true), np.min(y_pred))) + vmax = float(max(np.max(y_true), np.max(y_pred))) + + fig = px.scatter(x=y_true, y=y_pred, opacity=0.6, labels={"x": "Actual", "y": "Predicted"}, title=None) + fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), name="Ideal")) + fig.update_layout(template="plotly_white") + _save_plotly(fig, path) + return fig + + +def generate_residual_plot( + y_true: Sequence[float], + y_pred: Sequence[float], + title: str = "Residual Plot", + path: Optional[str] = None, +) -> go.Figure: + """ + Residuals vs Predicted (Plotly). + """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + residuals = y_true - y_pred + + fig = px.scatter(x=y_pred, y=residuals, opacity=0.6, + labels={"x": "Predicted", "y": "Residual (Actual - Predicted)"}, + title=None) + fig.add_hline(y=0, line_dash="dash") + fig.update_layout(template="plotly_white") + _save_plotly(fig, path) + return fig + + +def generate_residual_histogram( + y_true: Sequence[float], + y_pred: Sequence[float], + bins: int = 30, + title: str = "Residual Histogram", + path: Optional[str] = None, +) -> go.Figure: + """ + Residuals histogram (Plotly). + """ + residuals = np.asarray(y_true) - np.asarray(y_pred) + fig = px.histogram(x=residuals, nbins=bins, labels={"x": "Residual"}, title=None) + fig.update_layout(yaxis_title="Frequency", template="plotly_white") + _save_plotly(fig, path) + return fig + + +def generate_regression_calibration_plot( + y_true: Sequence[float], + y_pred: Sequence[float], + num_bins: int = 10, + title: str = "Regression Calibration Plot", + path: Optional[str] = None, +) -> go.Figure: + """ + Binned Actual vs Predicted means (Plotly). 
+ """ + y_true = np.asarray(y_true) + y_pred = np.asarray(y_pred) + + order = np.argsort(y_pred) + y_true_sorted = y_true[order] + y_pred_sorted = y_pred[order] + + bins = np.array_split(np.arange(len(y_pred_sorted)), num_bins) + bin_means_pred = [float(np.mean(y_pred_sorted[idx])) for idx in bins if len(idx)] + bin_means_true = [float(np.mean(y_true_sorted[idx])) for idx in bins if len(idx)] + + vmin = float(min(np.min(y_pred), np.min(y_true))) + vmax = float(max(np.max(y_pred), np.max(y_true))) + + fig = go.Figure() + fig.add_trace(go.Scatter(x=bin_means_pred, y=bin_means_true, mode="lines+markers", + name="Binned Actual vs Predicted")) + fig.add_trace(go.Scatter(x=[vmin, vmax], y=[vmin, vmax], mode="lines", line=dict(dash="dash"), + name="Ideal")) + fig.update_layout( + title=None, + xaxis_title="Mean Predicted per bin", + yaxis_title="Mean Actual per bin", + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + +# ========================= +# Confidence / Diagnostics +# ========================= + + +def plot_error_vs_confidence( + y_true: Union[Sequence[int], np.ndarray], + y_proba: Union[Sequence[float], np.ndarray], + n_bins: int = 10, + title: str = "Error vs Confidence", + path: Optional[str] = None, +) -> go.Figure: + """ + Error rate vs confidence (binary), confidence=max(p, 1-p). Plotly. 
+ """ + y_true = np.asarray(y_true) + y_proba = np.asarray(y_proba).reshape(-1) + y_pred = (y_proba >= 0.5).astype(int) + confidence = np.maximum(y_proba, 1 - y_proba) + error = (y_pred != y_true).astype(int) + + bins = np.linspace(0.0, 1.0, n_bins + 1) + idx = np.digitize(confidence, bins, right=True) + + centers, err_rates = [], [] + for i in range(1, len(bins)): + mask = (idx == i) + if mask.any(): + centers.append(float(confidence[mask].mean())) + err_rates.append(float(error[mask].mean())) + + fig = go.Figure() + fig.add_trace(go.Scatter(x=centers, y=err_rates, mode="lines+markers", name="Error rate")) + fig.update_layout( + title=None, + xaxis_title="Confidence (max predicted probability)", + yaxis_title="Error Rate", + yaxis=dict(range=[0, 1]), + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + + +def plot_confidence_histogram( + y_proba: np.ndarray, + bins: int = 20, + title: str = "Confidence Histogram", + path: Optional[str] = None, +) -> go.Figure: + """ + Histogram of max predicted probabilities (Plotly). + Works for binary (n_samples,) or (n_samples,2) and multiclass (n_samples,C). 
+ """ + y_proba = np.asarray(y_proba) + if y_proba.ndim == 1: + confidences = np.maximum(y_proba, 1 - y_proba) + else: + confidences = np.max(y_proba, axis=1) + + fig = px.histogram( + x=confidences, + nbins=bins, + range_x=(0, 1), + histnorm="percent", + labels={"x": "Confidence (max predicted probability)", "y": "Percent of samples (%)"}, + title=None, + ) + if fig.data: + fig.update_traces(hovertemplate="Conf=%{x:.2f}<br>%{y:.2f}%<extra></extra>") + fig.update_layout(yaxis_title="Percent of samples (%)", template="plotly_white") + _save_plotly(fig, path) + return fig + +# ========================= +# Learning Curve +# ========================= + + +def generate_learning_curve_from_predictions( + y_true, + y_pred=None, + y_proba=None, + classes=None, + metric: str = "accuracy", + train_fracs: np.ndarray = np.linspace(0.1, 1.0, 10), + n_repeats: int = 5, + seed: int = 42, + title: str = "Learning Curve", + path: str | None = None, + return_stats: bool = False, +) -> Union[go.Figure, tuple[list[int], list[float], list[float]]]: + rng = np.random.default_rng(seed) + y_true = np.asarray(y_true) + N = len(y_true) + + if metric == "accuracy" and y_pred is None: + raise ValueError("accuracy curve requires y_pred") + if metric == "log_loss" and y_proba is None: + raise ValueError("log_loss curve requires y_proba") + + if y_proba is not None: + y_proba = np.asarray(y_proba) + if y_pred is not None: + y_pred = np.asarray(y_pred) + + sizes = (np.clip((train_fracs * N).astype(int), 1, N)).tolist() + means, stds = [], [] + for n in sizes: + vals = [] + for _ in range(n_repeats): + idx = rng.choice(N, size=n, replace=False) + if metric == "accuracy": + vals.append(float((y_true[idx] == y_pred[idx]).mean())) + else: + if y_proba.ndim == 1: + p = y_proba[idx] + pp = np.column_stack([1 - p, p]) + else: + pp = y_proba[idx] + vals.append(float(log_loss(y_true[idx], pp, labels=None if classes is None else classes))) + means.append(np.mean(vals)) + stds.append(np.std(vals)) + + if 
return_stats: + return sizes, means, stds + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=sizes, y=means, mode="lines+markers", name="Train", + line=dict(width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=stds, visible=True) + )) + fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="epoch" if metric == "log_loss" else "samples", gridcolor="#eee"), + yaxis=dict(title=("loss" if metric == "log_loss" else "accuracy"), gridcolor="#eee"), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=50, r=20, t=60, b=50), + ) + if path: + _save_plotly(fig, path) + return fig + + +def build_train_html_and_plots( + predictor, + problem_type: str, + df_train: pd.DataFrame, + label_column: str, + tmpdir: str, + df_val: Optional[pd.DataFrame] = None, + seed: int = 42, + perf_table_html: str | None = None, + threshold: Optional[float] = None, + section_tile: str = "Training Diagnostics", +) -> str: + y_true = df_train[label_column].values + y_true_val = df_val[label_column].values if df_val is not None else None + # predictions on TRAIN + pred_labels, pred_proba = None, None + try: + pred_labels = predictor.predict(df_train) + except Exception: + pass + try: + proba_raw = predictor.predict_proba(df_train) + pred_proba = proba_raw.to_numpy() if isinstance(proba_raw, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw) + except Exception: + pred_proba = None + + # predictions on VAL (if provided) + pred_labels_val, pred_proba_val = None, None + if df_val is not None: + try: + pred_labels_val = predictor.predict(df_val) + except Exception: + pred_labels_val = None + try: + proba_raw_val = predictor.predict_proba(df_val) + pred_proba_val = proba_raw_val.to_numpy() if isinstance(proba_raw_val, (pd.Series, pd.DataFrame)) else np.asarray(proba_raw_val) + except Exception: + pred_proba_val = None + + pos_scores_train: Optional[np.ndarray] = None + pos_scores_val: Optional[np.ndarray] 
= None + if problem_type == "binary": + if pred_proba is not None: + pos_scores_train = ( + pred_proba.reshape(-1) + if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) + else pred_proba[:, -1] + ) + if pred_proba_val is not None: + pos_scores_val = ( + pred_proba_val.reshape(-1) + if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) + else pred_proba_val[:, -1] + ) + + # Collect plots then append in desired order + perf_card = f"<div class='card'>{perf_table_html}</div>" if perf_table_html else None + acc_plot = loss_plot = None + cm_train = pc_train = cm_val = pc_val = None + threshold_val_plot = None + roc_combined = pr_combined = cal_combined = None + mc_roc_val = None + conf_train = conf_val = None + bar_train = bar_val = None + + # 1) Learning Curve — Accuracy + if problem_type in ("binary", "multiclass"): + acc_fig = go.Figure() + added_acc = False + if pred_labels is not None: + train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( + y_true=y_true, + y_pred=np.asarray(pred_labels), + metric="accuracy", + title="Learning Curves — Label Accuracy", + seed=seed, + return_stats=True, + ) + acc_fig.add_trace(go.Scatter( + x=train_sizes, y=train_means, mode="lines+markers", name="Train", + line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=train_stds, visible=True), + )) + added_acc = True + if pred_labels_val is not None and y_true_val is not None: + val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( + y_true=y_true_val, + y_pred=np.asarray(pred_labels_val), + metric="accuracy", + title="Learning Curves — Label Accuracy", + seed=seed, + return_stats=True, + ) + acc_fig.add_trace(go.Scatter( + x=val_sizes, y=val_means, mode="lines+markers", name="Validation", + line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=val_stds, visible=True), + 
)) + added_acc = True + if added_acc: + acc_fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="samples", gridcolor="#eee"), + yaxis=dict(title="accuracy", gridcolor="#eee"), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=50, r=20, t=60, b=50), + ) + acc_plot = plot_with_table_style_title(acc_fig, "Learning Curves — Label Accuracy") + + # 2) Learning Curve — Loss + if problem_type in ("binary", "multiclass"): + classes = np.unique(y_true) + loss_fig = go.Figure() + added_loss = False + if pred_proba is not None: + pp = pred_proba.reshape(-1) if pred_proba.ndim == 1 or (pred_proba.ndim == 2 and pred_proba.shape[1] == 1) else pred_proba + train_sizes, train_means, train_stds = generate_learning_curve_from_predictions( + y_true=y_true, + y_proba=pp, + classes=classes, + metric="log_loss", + title="Learning Curves — Label Loss", + seed=seed, + return_stats=True, + ) + loss_fig.add_trace(go.Scatter( + x=train_sizes, y=train_means, mode="lines+markers", name="Train", + line=dict(color="#1f77b4", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=train_stds, visible=True), + )) + added_loss = True + if pred_proba_val is not None and y_true_val is not None: + pp_val = pred_proba_val.reshape(-1) if pred_proba_val.ndim == 1 or (pred_proba_val.ndim == 2 and pred_proba_val.shape[1] == 1) else pred_proba_val + val_sizes, val_means, val_stds = generate_learning_curve_from_predictions( + y_true=y_true_val, + y_proba=pp_val, + classes=classes, + metric="log_loss", + title="Learning Curves — Label Loss", + seed=seed, + return_stats=True, + ) + loss_fig.add_trace(go.Scatter( + x=val_sizes, y=val_means, mode="lines+markers", name="Validation", + line=dict(color="#ff7f0e", width=3, shape="spline"), marker=dict(size=7), + error_y=dict(type="data", array=val_stds, visible=True), + )) + added_loss = True + if added_loss: + loss_fig.update_layout( + title=None, + 
template="plotly_white", + xaxis=dict(title="epoch", gridcolor="#eee"), + yaxis=dict(title="loss", gridcolor="#eee"), + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=50, r=20, t=60, b=50), + ) + loss_plot = plot_with_table_style_title(loss_fig, "Learning Curves — Label Loss") + + # Confusion matrices & per-class metrics + cm_train = pc_train = cm_val = pc_val = None + + # Probability diagnostics (binary) + if problem_type == "binary": + # Combined Calibration (Train/Val) + cal_fig = go.Figure() + added_cal = False + if pos_scores_train is not None: + y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) + prob_true, prob_pred = calibration_curve(y_bin_train, pos_scores_train, n_bins=10, strategy="uniform") + cal_fig.add_trace(go.Scatter( + x=prob_pred, y=prob_true, mode="lines+markers", + name="Train", + line=dict(color="#1f77b4", width=3), + marker=dict(size=7, color="#1f77b4"), + )) + added_cal = True + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + prob_true_v, prob_pred_v = calibration_curve(y_bin_val, pos_scores_val, n_bins=10, strategy="uniform") + cal_fig.add_trace(go.Scatter( + x=prob_pred_v, y=prob_true_v, mode="lines+markers", + name="Validation", + line=dict(color="#ff7f0e", width=3), + marker=dict(size=7, color="#ff7f0e"), + )) + added_cal = True + if added_cal: + cal_fig.add_trace(go.Scatter( + x=[0, 1], y=[0, 1], + mode="lines", + line=dict(dash="dash", color="#808080", width=2), + name="Perfect", + showlegend=True, + )) + cal_fig.update_layout( + title=None, + xaxis_title="Predicted Probability", + yaxis_title="Observed Probability", + xaxis=dict(range=[0, 1]), + yaxis=dict(range=[0, 1]), + template="plotly_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=60, r=40, t=50, b=50), + ) + cal_combined = plot_with_table_style_title(cal_fig, "Calibration 
Curve (Train vs Validation)") + + # Combined ROC (Train/Val) + roc_fig = go.Figure() + added_roc = False + if pos_scores_train is not None: + y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) + fpr_tr, tpr_tr, thr_tr = roc_curve(y_bin_train, pos_scores_train) + roc_fig.add_trace(go.Scatter( + x=fpr_tr, y=tpr_tr, mode="lines", + name="Train", + line=dict(color="#1f77b4", width=3), + )) + if threshold is not None and np.isfinite(thr_tr).any(): + finite = np.isfinite(thr_tr) + idx_local = int(np.argmin(np.abs(thr_tr[finite] - float(threshold)))) + idx = int(np.nonzero(finite)[0][idx_local]) + roc_fig.add_trace(go.Scatter( + x=[fpr_tr[idx]], y=[tpr_tr[idx]], + mode="markers", + name="Train @ threshold", + marker=dict(size=12, color="#1f77b4", symbol="x") + )) + added_roc = True + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + fpr_v, tpr_v, thr_v = roc_curve(y_bin_val, pos_scores_val) + roc_fig.add_trace(go.Scatter( + x=fpr_v, y=tpr_v, mode="lines", + name="Validation", + line=dict(color="#ff7f0e", width=3), + )) + if threshold is not None and np.isfinite(thr_v).any(): + finite = np.isfinite(thr_v) + idx_local = int(np.argmin(np.abs(thr_v[finite] - float(threshold)))) + idx = int(np.nonzero(finite)[0][idx_local]) + roc_fig.add_trace(go.Scatter( + x=[fpr_v[idx]], y=[tpr_v[idx]], + mode="markers", + name="Val @ threshold", + marker=dict(size=12, color="#ff7f0e", symbol="x") + )) + added_roc = True + if added_roc: + roc_fig.add_trace(go.Scatter( + x=[0, 1], y=[0, 1], mode="lines", + line=dict(dash="dash", width=2, color="#808080"), + showlegend=False + )) + roc_fig.update_layout( + title=None, + xaxis_title="False Positive Rate", + yaxis_title="True Positive Rate", + template="plotly_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=60, r=20, t=60, b=60), + ) + roc_combined = plot_with_table_style_title(roc_fig, "ROC 
Curve (Train vs Validation)") + + # Combined PR (Train/Val) + pr_fig = go.Figure() + added_pr = False + if pos_scores_train is not None: + y_bin_train = (y_true == np.max(np.unique(y_true))).astype(int) + prec_tr, rec_tr, thr_tr = precision_recall_curve(y_bin_train, pos_scores_train) + pr_auc_tr = auc(rec_tr, prec_tr) + pr_fig.add_trace(go.Scatter( + x=rec_tr, y=prec_tr, mode="lines", + name=f"Train (AUC={pr_auc_tr:.3f})", + line=dict(color="#1f77b4", width=3), + )) + if threshold is not None and len(thr_tr): + j = int(np.argmin(np.abs(thr_tr - float(threshold)))) + j = int(np.clip(j, 0, len(thr_tr) - 1)) + pr_fig.add_trace(go.Scatter( + x=[rec_tr[j + 1]], y=[prec_tr[j + 1]], + mode="markers", + name="Train @ threshold", + marker=dict(size=12, color="#1f77b4", symbol="x") + )) + added_pr = True + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == np.max(np.unique(y_true_val))).astype(int) + prec_v, rec_v, thr_v = precision_recall_curve(y_bin_val, pos_scores_val) + pr_auc_v = auc(rec_v, prec_v) + pr_fig.add_trace(go.Scatter( + x=rec_v, y=prec_v, mode="lines", + name=f"Validation (AUC={pr_auc_v:.3f})", + line=dict(color="#ff7f0e", width=3), + )) + if threshold is not None and len(thr_v): + j = int(np.argmin(np.abs(thr_v - float(threshold)))) + j = int(np.clip(j, 0, len(thr_v) - 1)) + pr_fig.add_trace(go.Scatter( + x=[rec_v[j + 1]], y=[prec_v[j + 1]], + mode="markers", + name="Val @ threshold", + marker=dict(size=12, color="#ff7f0e", symbol="x") + )) + added_pr = True + if added_pr: + pr_fig.update_layout( + title=None, + xaxis_title="Recall", + yaxis_title="Precision", + template="plotly_white", + legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0), + margin=dict(l=60, r=20, t=60, b=60), + ) + pr_combined = plot_with_table_style_title(pr_fig, "Precision–Recall Curve (Train vs Validation)") + + if pos_scores_val is not None and y_true_val is not None: + y_bin_val = (y_true_val == 
np.max(np.unique(y_true_val))).astype(int) + fig_thr_val = generate_threshold_plot(y_true_bin=y_bin_val, y_prob=pos_scores_val, title="Threshold Plot (Validation)", + user_threshold=threshold) + threshold_val_plot = plot_with_table_style_title(fig_thr_val, "Threshold Plot (Validation)") + + # Multiclass OVR ROC (validation) + if problem_type == "multiclass" and pred_proba_val is not None and pred_proba_val.ndim >= 2 and y_true_val is not None: + classes_val = np.unique(y_true_val) + fig_mc_roc_val = generate_multiclass_roc_curve_plot(y_true_val, pred_proba_val, classes_val, title="One-vs-Rest ROC (Validation)") + mc_roc_val = plot_with_table_style_title(fig_mc_roc_val, "One-vs-Rest ROC (Validation)") + + # Prediction Confidence Histogram (train/val) + conf_train = conf_val = None + + # Per-class accuracy bars + if problem_type in ("binary", "multiclass") and pred_labels is not None: + classes_for_bar = pd.Index(np.unique(y_true), dtype=object).tolist() + acc_vals = [] + for c in classes_for_bar: + mask = y_true == c + acc_vals.append(float((np.asarray(pred_labels)[mask] == c).mean()) if mask.any() else 0.0) + bar_fig = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar], y=acc_vals, marker_color="#1f77b4")) + bar_fig.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="Label", gridcolor="#eee"), + yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), + margin=dict(l=50, r=20, t=60, b=50), + ) + bar_train = plot_with_table_style_title(bar_fig, "Per-Class Training Accuracy") + if problem_type in ("binary", "multiclass") and pred_labels_val is not None and y_true_val is not None: + classes_for_bar_val = pd.Index(np.unique(y_true_val), dtype=object).tolist() + acc_vals_val = [] + for c in classes_for_bar_val: + mask = y_true_val == c + acc_vals_val.append(float((np.asarray(pred_labels_val)[mask] == c).mean()) if mask.any() else 0.0) + bar_fig_val = go.Figure(data=go.Bar(x=[str(c) for c in classes_for_bar_val], y=acc_vals_val, 
marker_color="#ff7f0e")) + bar_fig_val.update_layout( + title=None, + template="plotly_white", + xaxis=dict(title="Label", gridcolor="#eee"), + yaxis=dict(title="Accuracy", gridcolor="#eee", range=[0, 1]), + margin=dict(l=50, r=20, t=60, b=50), + ) + bar_val = plot_with_table_style_title(bar_fig_val, "Per-Class Validation Accuracy") + + # Assemble in requested order + pieces: list[str] = [] + if perf_card: + pieces.append(perf_card) + for block in (threshold_val_plot, roc_combined, pr_combined): + if block: + pieces.append(block) + # Remaining plots (keep existing order) + for block in (cal_combined, cm_train, pc_train, cm_val, pc_val, mc_roc_val, conf_train, conf_val, bar_train, bar_val): + if block: + pieces.append(block) + # Learning curves should appear last in the tab + for block in (acc_plot, loss_plot): + if block: + pieces.append(block) + + if not pieces: + return "<h2>Training Diagnostics</h2><p><em>No training diagnostics available for this run.</em></p>" + + return "<h2>Train and Validation Performance Summary</h2>" + "".join(pieces) + + +def generate_learning_curve( + estimator, + X, + y, + scoring: str = "r2", + cv_folds: int = 5, + n_jobs: int = -1, + train_sizes: np.ndarray = np.linspace(0.1, 1.0, 10), + title: str = "Learning Curve", + path: Optional[str] = None, +) -> go.Figure: + """ + Learning curve using sklearn.learning_curve, visualized with Plotly. 
+ """ + sizes, train_scores, test_scores = skl_learning_curve( + estimator, X, y, cv=cv_folds, scoring=scoring, n_jobs=n_jobs, train_sizes=train_sizes + ) + train_mean = train_scores.mean(axis=1) + train_std = train_scores.std(axis=1) + test_mean = test_scores.mean(axis=1) + test_std = test_scores.std(axis=1) + + fig = go.Figure() + fig.add_trace(go.Scatter( + x=sizes, y=train_mean, mode="lines+markers", name="Training score", + error_y=dict(type="data", array=train_std, visible=True) + )) + fig.add_trace(go.Scatter( + x=sizes, y=test_mean, mode="lines+markers", name="CV score", + error_y=dict(type="data", array=test_std, visible=True) + )) + fig.update_layout( + title=None, + xaxis_title="Training examples", + yaxis_title=scoring, + template="plotly_white", + ) + _save_plotly(fig, path) + return fig + +# ========================= +# SHAP (Matplotlib-based) +# ========================= + + +def generate_shap_summary_plot( + shap_values, features: pd.DataFrame, title: str = "SHAP Summary Plot", path: Optional[str] = None +) -> None: + """ + SHAP summary plot (Matplotlib). SHAP's interactive support with Plotly is limited; + keep matplotlib for clarity and stability. + """ + plt.figure(figsize=(10, 8)) + shap.summary_plot(shap_values, features, show=False) + plt.title(title) + _save_matplotlib(path) + + +def generate_shap_force_plot( + explainer, instance: pd.DataFrame, title: str = "SHAP Force Plot", path: Optional[str] = None +) -> None: + """ + SHAP force plot (Matplotlib). + """ + shap_values = explainer(instance) + plt.figure(figsize=(10, 4)) + shap.plots.force(shap_values[0], show=False) + plt.title(title) + _save_matplotlib(path) + + +def generate_shap_waterfall_plot( + explainer, instance: pd.DataFrame, title: str = "SHAP Waterfall Plot", path: Optional[str] = None +) -> None: + """ + SHAP waterfall plot (Matplotlib). 
+ """ + shap_values = explainer(instance) + plt.figure(figsize=(10, 6)) + shap.plots.waterfall(shap_values[0], show=False) + plt.title(title) + _save_matplotlib(path) + + +def infer_problem_type(predictor, df_train_full: pd.DataFrame, label_column: str) -> str: + """ + Return 'binary', 'multiclass', or 'regression'. + Prefer the predictor's own metadata when available; otherwise infer from label dtype/uniques. + """ + # AutoGluon predictors usually expose .problem_type; be defensive. + pt = getattr(predictor, "problem_type", None) + if isinstance(pt, str): + pt_l = pt.lower() + if "regression" in pt_l: + return "regression" + if "binary" in pt_l: + return "binary" + if "multiclass" in pt_l or "multiclass" in pt_l: + return "multiclass" + + y = df_train_full[label_column] + if pd.api.types.is_numeric_dtype(y) and y.nunique() > 10: + return "regression" + return "binary" if y.nunique() == 2 else "multiclass" + + +def _safe_floatify(d: Dict[str, Any]) -> Dict[str, float]: + """Make evaluate() outputs JSON/csv friendly floats.""" + out = {} + for k, v in d.items(): + try: + out[k] = float(v) + except Exception: + # keep only real-valued scalars + pass + return out + + +def evaluate_all( + predictor, + df_train: pd.DataFrame, + df_val: pd.DataFrame, + df_test: pd.DataFrame, + label_column: str, + problem_type: str, +) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]: + """ + Run predictor.evaluate on train/val/test and normalize the result dicts to floats. + MultiModalPredictor does not accept the `silent` kwarg, so call defensively. 
+ """ + def _evaluate(df): + try: + return predictor.evaluate(df, silent=True) + except TypeError: + return predictor.evaluate(df) + + train_scores = _safe_floatify(_evaluate(df_train)) + val_scores = _safe_floatify(_evaluate(df_val)) + test_scores = _safe_floatify(_evaluate(df_test)) + return train_scores, val_scores, test_scores + + +def build_summary_html( + predictor, + df_train: pd.DataFrame, + df_val: Optional[pd.DataFrame], + df_test: Optional[pd.DataFrame], + label_column: str, + extra_run_rows: Optional[list[tuple[str, str]]] = None, + class_balance_html: Optional[str] = None, + perf_table_html: Optional[str] = None, +) -> str: + sections = [] + + # Dataset Overview (first section in the tab) + if class_balance_html: + sections.append(f""" +<section class="section"> + <h2 class="section-title">Dataset Overview</h2> + <div class="card"> + {class_balance_html} + </div> +</section> +""".strip()) + + # Performance Summary + if perf_table_html: + sections.append(f""" +<section class="section"> + <h2 class="section-title">Model Performance Summary</h2> + <div class="card"> + {perf_table_html} + </div> +</section> +""".strip()) + + # Model Configuration + + # Remove Predictor type and Framework, and ensure Model Architecture is present + base_rows: list[tuple[str, str]] = [] + if extra_run_rows: + # Remove any rows with keys 'Predictor type' or 'Framework' + base_rows.extend([(k, v) for (k, v) in extra_run_rows if k not in ("Predictor type", "Framework")]) + + def _fmt(v): + if v is None or v == "": + return "—" + return _escape(str(v)) + + rows_html = "\n".join( + f"<tr><td>{_escape(str(k))}</td><td>{_fmt(v)}</td></tr>" + for k, v in base_rows + ) + + sections.append(f""" +<section class="section"> + <h2 class="section-title">Model Configuration</h2> + <div class="card"> + <table class="kv-table"> + <thead><tr><th>Key</th><th>Value</th></tr></thead> + <tbody> + {rows_html} + </tbody> + </table> + </div> +</section> +""".strip()) + + return 
"\n".join(sections).strip() + + +def build_feature_importance_html(predictor, df_train: pd.DataFrame, label_column: str) -> str: + """Build a visualization of feature importance.""" + try: + # Try to get feature importance from predictor + fi = None + if hasattr(predictor, "feature_importance") and callable(predictor.feature_importance): + try: + fi = predictor.feature_importance(df_train) + except Exception as e: + return f"<p>Could not compute feature importance: {e}</p>" + + if fi is None or (isinstance(fi, pd.DataFrame) and fi.empty): + return "<p>Feature importance not available for this model.</p>" + + # Format as a sortable table + rows = [] + if isinstance(fi, pd.DataFrame): + fi = fi.sort_values("importance", ascending=False) + for _, row in fi.iterrows(): + feat = row.index[0] if isinstance(row.index, pd.Index) else row["feature"] + imp = float(row["importance"]) + rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{imp:.4f}</td></tr>") + else: + # Handle other formats (dict, etc) + for feat, imp in sorted(fi.items(), key=lambda x: float(x[1]), reverse=True): + rows.append(f"<tr><td>{_escape(str(feat))}</td><td>{float(imp):.4f}</td></tr>") + + if not rows: + return "<p>No feature importance values available.</p>" + + table_html = f""" + <table class="performance-summary"> + <thead> + <tr> + <th class="sortable">Feature</th> + <th class="sortable">Importance</th> + </tr> + </thead> + <tbody> + {"".join(rows)} + </tbody> + </table> + """ + return table_html + + except Exception as e: + return f"<p>Error building feature importance visualization: {e}</p>" + + +def build_test_html_and_plots( + predictor, + problem_type: str, + df_test: pd.DataFrame, + label_column: str, + tmpdir: str, + threshold: Optional[float] = None, +) -> Tuple[str, List[str]]: + """ + Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs. 
def build_test_html_and_plots(
    predictor,
    problem_type: str,
    df_test: pd.DataFrame,
    label_column: str,
    tmpdir: str,
    threshold: Optional[float] = None,
) -> Tuple[str, List[str]]:
    """
    Create a test-summary section (with a placeholder for metric rows) and a list of Plotly HTML divs.

    Returns: (html_template_with_{}, list_of_plot_divs)

    Notes
    -----
    - ``predict`` / ``predict_proba`` are both attempted and failures are
      swallowed, so classification plots degrade gracefully when a predictor
      exposes only one of them.
    - For binary problems with an explicit ``threshold``, predicted labels are
      re-derived from the positive-class scores before the confusion matrix /
      per-class plots, so those plots reflect the threshold.
    - ``tmpdir`` is accepted for interface symmetry with sibling builders; it
      is not used in this function body.
    """
    plots: List[str] = []

    y_true = df_test[label_column].values
    classes = np.unique(y_true)

    # Try proba/labels where meaningful
    pred_labels = None
    pred_proba = None
    try:
        pred_labels = predictor.predict(df_test)
    except Exception:
        pass
    try:
        # MultiModalPredictor exposes predict_proba for classification problems.
        pred_proba = predictor.predict_proba(df_test)
    except Exception:
        pred_proba = None

    # Normalize proba output (Series / DataFrame / array-like) to an ndarray.
    proba_arr = None
    if pred_proba is not None:
        if isinstance(pred_proba, pd.Series):
            proba_arr = pred_proba.to_numpy().reshape(-1, 1)
        elif isinstance(pred_proba, pd.DataFrame):
            proba_arr = pred_proba.to_numpy()
        else:
            proba_arr = np.asarray(pred_proba)

    # Thresholded labels for binary.
    # NOTE(review): positive/negative classes are taken as classes.max()/min();
    # this assumes the "positive" label sorts last — confirm for string labels.
    if problem_type == "binary" and threshold is not None and proba_arr is not None:
        pos_label, neg_label = classes.max(), classes.min()
        # Single-column scores are flattened; multi-column uses the last column.
        pos_scores = proba_arr.reshape(-1) if (proba_arr.ndim == 1 or proba_arr.shape[1] == 1) else proba_arr[:, -1]
        pred_labels = np.where(pos_scores >= float(threshold), pos_label, neg_label)

    # Confusion matrix / per-class now reflect thresholded labels
    if problem_type in ("binary", "multiclass") and pred_labels is not None:
        cm_title = "Confusion Matrix"
        if threshold is not None and problem_type == "binary":
            # Trim trailing zeros/dot so e.g. 0.500 renders as "0.5".
            thr_str = f"{float(threshold):.3f}".rstrip("0").rstrip(".")
            cm_title = f"Confusion Matrix (Threshold = {thr_str})"
        fig_cm = generate_confusion_matrix_plot(y_true, pred_labels, title=cm_title)
        plots.append(plot_with_table_style_title(fig_cm, cm_title))

        fig_pc = generate_per_class_metrics_plot(y_true, pred_labels, title="Per-Class Metrics")
        plots.append(plot_with_table_style_title(fig_pc, "Per-Class Metrics"))

    # ROC/PR where possible — choose positive-class scores safely
    pos_label = classes.max()  # or set explicitly, e.g., 1 or "yes"

    # Re-derive proba_arr together with the positive-class column index:
    # DataFrame → look up the pos_label column (fallback: last column);
    # Series/1-col → single score column.
    if isinstance(pred_proba, pd.DataFrame):
        proba_arr = pred_proba.to_numpy()
        if pos_label in pred_proba.columns:
            pos_idx = list(pred_proba.columns).index(pos_label)
        else:
            pos_idx = -1  # fallback to last column
    elif isinstance(pred_proba, pd.Series):
        proba_arr = pred_proba.to_numpy().reshape(-1, 1)
        pos_idx = 0
    else:
        proba_arr = np.asarray(pred_proba) if pred_proba is not None else None
        pos_idx = -1 if (proba_arr is not None and proba_arr.ndim == 2 and proba_arr.shape[1] > 1) else 0

    if proba_arr is not None:
        # One-vs-rest binarization of the true labels for ROC/PR.
        y_bin = (y_true == pos_label).astype(int)
        pos_scores = (
            proba_arr.reshape(-1)
            if proba_arr.ndim == 1 or proba_arr.shape[1] == 1
            else proba_arr[:, pos_idx]
        )

        fig_roc = generate_roc_curve_plot(y_bin, pos_scores, title="ROC Curve", marker_threshold=threshold)
        plots.append(plot_with_table_style_title(fig_roc, f"ROC Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))

        fig_pr = generate_pr_curve_plot(y_bin, pos_scores, title="Precision–Recall Curve", marker_threshold=threshold)
        plots.append(plot_with_table_style_title(fig_pr, f"Precision–Recall Curve{'' if threshold is None else f' (marker at threshold={threshold:.2f})'}"))

        # Additional diagnostics aligned with ImageLearner style
        if problem_type == "binary":
            conf_fig = plot_confidence_histogram(pos_scores, bins=20, title="Prediction Confidence (Test)")
            plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Test)"))
        else:
            conf_fig = plot_confidence_histogram(proba_arr, bins=20, title="Prediction Confidence (Top-1, Test)")
            plots.append(plot_with_table_style_title(conf_fig, "Prediction Confidence (Top-1, Test)"))

    if problem_type == "multiclass" and proba_arr is not None and proba_arr.ndim >= 2:
        fig_mc_roc = generate_multiclass_roc_curve_plot(y_true, proba_arr, classes, title="One-vs-Rest ROC (Test)")
        plots.append(plot_with_table_style_title(fig_mc_roc, "One-vs-Rest ROC (Test)"))

    # Regression visuals
    if problem_type == "regression":
        if pred_labels is None:
            pred_labels = predictor.predict(df_test)
        fig_sc = generate_scatter_plot(y_true, pred_labels, title="Predicted vs Actual")
        plots.append(plot_with_table_style_title(fig_sc, "Predicted vs Actual"))

        fig_res = generate_residual_plot(y_true, pred_labels, title="Residual Plot")
        plots.append(plot_with_table_style_title(fig_res, "Residual Plot"))

        fig_hist = generate_residual_histogram(y_true, pred_labels, title="Residual Histogram")
        plots.append(plot_with_table_style_title(fig_hist, "Residual Histogram"))

        fig_cal = generate_regression_calibration_plot(y_true, pred_labels, title="Regression Calibration")
        plots.append(plot_with_table_style_title(fig_cal, "Regression Calibration"))

    # Small HTML template with placeholder for metric rows the caller fills in
    test_html_template = """
    <h2>Test Performance Summary</h2>
    <table class="performance-summary">
        <thead><tr><th>Metric</th><th>Test</th></tr></thead>
        <tbody>{}</tbody>
    </table>
    """
    return test_html_template, plots
def build_feature_html(
    predictor,
    df_train: pd.DataFrame,
    label_column: str,
    include_modalities: bool = True,
    include_class_balance: bool = True,
) -> str:
    """
    Assemble the feature-analysis sections of the report.

    Parameters
    ----------
    predictor : Any
        Trained predictor passed through to the section builders.
    df_train : pd.DataFrame
        Data used for feature importance / class-balance summaries.
    label_column : str
        Name of the target column.
    include_modalities : bool
        Render the "Modalities & Inputs" card.
    include_class_balance : bool
        Render the "Class Balance (Train Full)" card.

    Returns
    -------
    str
        Concatenated ``<section>`` HTML fragments.
    """
    sections = []

    # (Typical feature importance content…)
    fi_html = build_feature_importance_html(predictor, df_train, label_column)
    sections.append(f"<section class='section'><h2 class='section-title'>Feature Importance</h2><div class='card'>{fi_html}</div></section>")

    # Previously: Modalities & Inputs and/or Class Balance may have been here.
    # Only render them if flags are True.
    if include_modalities:
        from report_utils import build_modalities_html
        # BUGFIX: build_modalities_html requires an image_col argument; the
        # previous 3-argument call raised TypeError. The image column is not
        # known at this call site, so pass None (renders "Image column: None").
        modalities_html = build_modalities_html(predictor, df_train, label_column, None)
        sections.append(f"<section class='section'><h2 class='section-title'>Modalities & Inputs</h2><div class='card'>{modalities_html}</div></section>")

    if include_class_balance:
        from report_utils import build_class_balance_html
        cb_html = build_class_balance_html(df_train, label_column)
        sections.append(f"<section class='section'><h2 class='section-title'>Class Balance (Train Full)</h2><div class='card'>{cb_html}</div></section>")

    return "\n".join(sections)
def assemble_full_html_report(
    summary_html: str,
    train_html: str,
    test_html: str,
    plots: List[str],
    feature_html: str,
) -> str:
    """
    Assemble the final report page: template head, Plotly CDN loader,
    centering CSS, the metrics-help modal, the tabbed body, and closing tags.
    """
    # Plot divs (already wrapped with titles) are appended under the Test tab.
    test_tab_html = test_html + "".join(plots)

    tabbed_body = build_tabbed_html(summary_html, train_html, test_tab_html, feature_html, explainer_html=None)

    pieces = [get_html_template()]

    # Plots are rendered with include_plotlyjs=False, so Plotly JS must be
    # loaded once here from the CDN.
    pieces.append('\n<script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script>\n')

    # Optional centering tweaks for Plotly divs.
    pieces.append("""
<style>
    .plotly-center { display: flex; justify-content: center; }
    .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { margin: 0 auto !important; }
    .js-plotly-plot, .plotly-graph-div { margin-left: auto !important; margin-right: auto !important; }
</style>
""")

    # Help modal HTML/JS
    pieces.append(get_metrics_help_modal())
    pieces.append(tabbed_body)
    pieces.append(get_html_closing())
    return "".join(pieces)
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/report_utils.py Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,1116 @@ +import base64 +import html +import json +import logging +import os +import platform +import shutil +import sys +import tempfile +from datetime import datetime +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +import yaml +from utils import verify_outputs + +logger = logging.getLogger(__name__) + + +def _escape(s: Any) -> str: + return html.escape(str(s)) + + +def _write_predictor_path(predictor): + try: + pred_path = getattr(predictor, "path", None) + if pred_path: + with open("predictor_path.txt", "w") as pf: + pf.write(str(pred_path)) + logger.info("Wrote predictor path → predictor_path.txt") + return pred_path + except Exception: + logger.warning("Could not write predictor_path.txt") + return None + + +def _copy_config_if_available(pred_path: Optional[str], output_config: Optional[str]): + if not output_config: + return + try: + config_yaml_path = os.path.join(pred_path, "config.yaml") if pred_path else None + if config_yaml_path and os.path.isfile(config_yaml_path): + shutil.copy2(config_yaml_path, output_config) + logger.info(f"Wrote AutoGluon config → {output_config}") + else: + with open(output_config, "w") as cfg_out: + cfg_out.write("# config.yaml not found for this run\n") + logger.warning(f"AutoGluon config.yaml not found; created placeholder at {output_config}") + except Exception as e: + logger.error(f"Failed to write config output '{output_config}': {e}") + try: + with open(output_config, "w") as cfg_out: + cfg_out.write(f"# Failed to copy config.yaml: {e}\n") + except Exception: + pass + + +def _load_config_yaml(args, predictor) -> dict: + """ + Load config.yaml either from the predictor path or the exported output_config. 
+ """ + candidates = [] + pred_path = getattr(predictor, "path", None) + if pred_path: + cfg_path = os.path.join(pred_path, "config.yaml") + if os.path.isfile(cfg_path): + candidates.append(cfg_path) + if args.output_config and os.path.isfile(args.output_config): + candidates.append(args.output_config) + + for p in candidates: + try: + with open(p, "r") as f: + return yaml.safe_load(f) or {} + except Exception: + continue + return {} + + +def _summarize_config(cfg: dict, args) -> List[tuple[str, str]]: + """ + Build rows describing model components and key hyperparameters from a loaded config.yaml. + Falls back to CLI args when config values are missing. + """ + rows: List[tuple[str, str]] = [] + model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {} + names = model_cfg.get("names") or [] + if names: + rows.append(("Model components", ", ".join(names))) + + # Tabular backbone with data types + tabular_val = "—" + for k, v in model_cfg.items(): + if k in ("names", "hf_text", "timm_image"): + continue + if isinstance(v, dict) and "data_types" in v: + dtypes = v.get("data_types") or [] + if any(t in ("categorical", "numerical") for t in dtypes): + dt_str = ", ".join(dtypes) if dtypes else "" + tabular_val = f"{k} ({dt_str})" if dt_str else k + break + rows.append(("Tabular backbone", tabular_val)) + + image_val = model_cfg.get("timm_image", {}).get("checkpoint_name") or "—" + rows.append(("Image backbone", image_val)) + + text_val = model_cfg.get("hf_text", {}).get("checkpoint_name") or "—" + rows.append(("Text backbone", text_val)) + + fusion_val = "—" + for k in model_cfg.keys(): + if str(k).startswith("fusion"): + fusion_val = k + break + rows.append(("Fusion backbone", fusion_val)) + + # Optimizer block + optim_cfg = cfg.get("optim", {}) if isinstance(cfg, dict) else {} + optim_map = [ + ("optim_type", "Optimizer"), + ("lr", "Learning rate"), + ("weight_decay", "Weight decay"), + ("lr_decay", "LR decay"), + ("max_epochs", "Max epochs"), + ("max_steps", 
def write_outputs(
    args,
    predictor,
    problem_type: str,
    eval_results: dict,
    data_ctx: dict,
    raw_folds=None,
    ag_folds=None,
    raw_metrics_std=None,
    ag_by_split_std=None,
):
    """
    Write all tool outputs: the JSON metrics file, the tabbed HTML report,
    the predictor-path text file, and (optionally) the exported config.yaml.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide output_json, output_html, output_config, target_column,
        threshold, preset, eval_metric, random_seed.
    predictor : Any
        Trained predictor (MultiModalPredictor or tabular predictor).
    problem_type : str
        "binary", "multiclass", or "regression".
    eval_results : dict
        Expected keys: "raw_metrics", "ag_eval", "fit_summary".
    data_ctx : dict
        Expected keys: "train", "val", "test_internal", "test_external".
    raw_folds, ag_folds : optional
        Per-fold metric dicts; aggregated into mean/std when stds are absent.
    raw_metrics_std, ag_by_split_std : optional
        Pre-computed fold standard deviations.
    """
    # Imported lazily: plot_logic imports from report_utils, so a top-level
    # import here would be circular.
    from plot_logic import (
        build_summary_html,
        build_test_html_and_plots,
        build_feature_html,
        assemble_full_html_report,
        build_train_html_and_plots,
    )
    from autogluon.multimodal import MultiModalPredictor
    from metrics_logic import aggregate_metrics

    raw_metrics = eval_results.get("raw_metrics", {})
    ag_by_split = eval_results.get("ag_eval", {})
    fit_summary_obj = eval_results.get("fit_summary")

    df_train = data_ctx.get("train")
    df_val = data_ctx.get("val")
    df_test_internal = data_ctx.get("test_internal")
    df_test_external = data_ctx.get("test_external")
    # External test set takes precedence over the internal holdout.
    df_test = df_test_external if df_test_external is not None else df_test_internal
    df_train_full = df_train if df_val is None else pd.concat([df_train, df_val], ignore_index=True)

    # Aggregate folds if provided without stds
    if raw_folds and raw_metrics_std is None:
        raw_metrics, raw_metrics_std = aggregate_metrics(raw_folds)
    if ag_folds and ag_by_split_std is None:
        ag_by_split, ag_by_split_std = aggregate_metrics(ag_folds)

    # Inject AG eval into raw metrics for visibility
    def _inject_ag(src: dict, dst: dict):
        # Copy AutoGluon's own eval scores next to the raw metrics under an
        # "AG_" prefix; non-numeric values are kept as-is.
        for k, v in (src or {}).items():
            try:
                dst[f"AG_{k}"] = float(v)
            except Exception:
                dst[f"AG_{k}"] = v
    if "Train" in raw_metrics and "Train" in ag_by_split:
        _inject_ag(ag_by_split["Train"], raw_metrics["Train"])
    if "Validation" in raw_metrics and "Validation" in ag_by_split:
        _inject_ag(ag_by_split["Validation"], raw_metrics["Validation"])
    if "Test" in raw_metrics and "Test" in ag_by_split:
        _inject_ag(ag_by_split["Test"], raw_metrics["Test"])

    # JSON
    with open(args.output_json, "w") as f:
        json.dump(
            {
                "train": raw_metrics.get("Train", {}),
                "val": raw_metrics.get("Validation", {}),
                "test": raw_metrics.get("Test", {}),
                "test_external": raw_metrics.get("Test (external)", {}),
                "ag_eval": ag_by_split,
                "ag_eval_std": ag_by_split_std,
                "fit_summary": fit_summary_obj,
                "problem_type": problem_type,
                "predictor_path": getattr(predictor, "path", None),
                "threshold": args.threshold,
                "threshold_test": args.threshold,
                "preset": args.preset,
                "eval_metric": args.eval_metric,
                "folds": {
                    "raw_folds": raw_folds,
                    "ag_folds": ag_folds,
                    "summary_mean": raw_metrics if raw_folds else None,
                    "summary_std": raw_metrics_std,
                    "ag_summary_mean": ag_by_split,
                    "ag_summary_std": ag_by_split_std,
                },
            },
            f,
            indent=2,
            default=str,  # non-serializable objects (e.g. fit summary) fall back to str()
        )
    logger.info(f"Wrote full JSON → {args.output_json}")

    # HTML report assembly
    label_col = args.target_column

    class_balance_block_html = build_class_balance_html(
        df_train=df_train,
        label_col=label_col,
        df_val=df_val,
        df_test=df_test,
    )
    summary_perf_table_html = build_model_performance_summary_table(
        train_scores=raw_metrics.get("Train", {}),
        val_scores=raw_metrics.get("Validation", {}),
        test_scores=raw_metrics.get("Test", {}),
        include_test=True,
        title=None,
        show_title=False,
    )

    cfg_yaml = _load_config_yaml(args, predictor)
    config_rows = _summarize_config(cfg_yaml, args)
    threshold_rows = []
    if problem_type == "binary" and args.threshold is not None:
        threshold_rows.append(("Decision threshold (Test)", f"{float(args.threshold):.3f}"))
    extra_run_rows = [
        ("Target column", label_col),
        ("Model evaluation metric", args.eval_metric or "AutoGluon default"),
        ("Experiment quality", args.preset or "AutoGluon default"),
    ] + threshold_rows + config_rows

    summary_html = build_summary_html(
        predictor=predictor,
        df_train=df_train_full,
        df_val=df_val,
        df_test=df_test,
        label_column=label_col,
        extra_run_rows=extra_run_rows,
        class_balance_html=class_balance_block_html,
        perf_table_html=summary_perf_table_html,
    )

    # Train tab hides the Test column on purpose.
    train_tab_perf_html = build_model_performance_summary_table(
        train_scores=raw_metrics.get("Train", {}),
        val_scores=raw_metrics.get("Validation", {}),
        test_scores=raw_metrics.get("Test", {}),
        include_test=False,
        title=None,
        show_title=False,
    )

    train_html = build_train_html_and_plots(
        predictor=predictor,
        problem_type=problem_type,
        df_train=df_train,
        df_val=df_val,
        label_column=label_col,
        tmpdir=tempfile.mkdtemp(),
        seed=int(args.random_seed),
        perf_table_html=train_tab_perf_html,
        threshold=args.threshold,
    )

    test_html_template, plots = build_test_html_and_plots(
        predictor,
        problem_type,
        df_test,
        label_col,
        tempfile.mkdtemp(),
        threshold=args.threshold,
    )

    def _fmt_val(v):
        # Integers as-is; floats with 6 decimals; everything else stringified.
        if isinstance(v, (int, np.integer)):
            return f"{int(v)}"
        if isinstance(v, (float, np.floating)):
            return f"{v:.6f}"
        return str(v)

    test_scores = raw_metrics.get("Test", {})
    # Drop AutoGluon-injected ROC AUC line from the Test Performance Summary
    filtered_test_scores = {k: v for k, v in test_scores.items() if k != "AG_roc_auc"}
    # NOTE(review): the '(TNR)'/'(Sensitivity/TPR)' replaces are no-ops as
    # written (same text on both sides) — presumably meant to restore the
    # parentheses labels after the '_'→' ' replace; confirm intent.
    metric_rows = "".join(
        f"<tr><td>{k.replace('_',' ').replace('(TNR)','(TNR)').replace('(Sensitivity/TPR)', '(Sensitivity/TPR)')}</td>"
        f"<td>{_fmt_val(v)}</td></tr>"
        for k, v in filtered_test_scores.items()
    )
    test_html_filled = test_html_template.format(metric_rows)

    # Transparency sections: leaderboard / ignored-features only make sense
    # for non-multimodal (tabular) predictors.
    is_multimodal = isinstance(predictor, MultiModalPredictor)
    leaderboard_html = "" if is_multimodal else build_leaderboard_html(predictor)
    inputs_html = ""
    ignored_features_html = "" if is_multimodal else build_ignored_features_html(predictor, df_train_full)
    presets_hparams_html = build_presets_hparams_html(predictor)
    notices: List[str] = []
    if args.threshold is not None and problem_type == "binary":
        notices.append(f"Using decision threshold = {float(args.threshold):.3f} on Test.")
    warnings_html = build_warnings_html([], notices)
    repro_html = build_reproducibility_html(args, {}, getattr(predictor, "path", None))

    transparency_blocks = "\n".join(
        [
            leaderboard_html,
            inputs_html,
            ignored_features_html,
            presets_hparams_html,
            warnings_html,
            repro_html,
        ]
    )

    # NOTE(review): build_feature_html (plot_logic) takes
    # (predictor, df, label, include_modalities=True, include_class_balance=True);
    # the 4th/5th positional args here (tmpdir, seed) bind to those boolean
    # flags — verify this call site.
    try:
        feature_text = build_feature_html(predictor, df_test, label_col, tempfile.mkdtemp(), args.random_seed) if df_test is not None else ""
    except Exception:
        feature_text = "<p>Feature analysis unavailable for this model.</p>"

    full_html = assemble_full_html_report(
        summary_html,
        train_html,
        test_html_filled,
        plots,
        feature_text + transparency_blocks,
    )
    with open(args.output_html, "w") as f:
        f.write(full_html)
    logger.info(f"Wrote HTML report → {args.output_html}")

    pred_path = _write_predictor_path(predictor)
    _copy_config_if_available(pred_path, args.output_config)

    # Fail loudly if any declared output is missing/empty.
    outputs_to_check = [
        (args.output_json, "JSON results"),
        (args.output_html, "HTML report"),
    ]
    if args.output_config:
        outputs_to_check.append((args.output_config, "AutoGluon config"))
    verify_outputs(outputs_to_check)
def get_html_template() -> str:
    """
    Return the opening HTML: <head> (with CSS/JS) and an opened <body> + .container.

    Includes:
      - Base styling for layout and tables
      - Sortable table headers with 3-state arrows (none ⇅, asc ↑, desc ↓)
      - A scroll helper class (.scroll-rows-30) that approximates ~30 visible rows
      - A guarded script so initializing runs only once even if injected twice

    The caller must eventually append get_html_closing() to balance the
    opened <div class="container">, <body>, and <html> tags.
    """
    return """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Galaxy-Ludwig Report</title>
    <style>
      body {
          font-family: Arial, sans-serif;
          margin: 0;
          padding: 20px;
          background-color: #f4f4f4;
      }
      .container {
          max-width: 1200px;
          margin: auto;
          background: white;
          padding: 20px;
          box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
          overflow-x: auto;
      }
      h1 {
          text-align: center;
          color: #333;
      }
      h2 {
          border-bottom: 2px solid #4CAF50;
          color: #4CAF50;
          padding-bottom: 5px;
          margin-top: 28px;
      }

      /* baseline table setup */
      table {
          border-collapse: collapse;
          margin: 20px 0;
          width: 100%;
          table-layout: fixed;
          background: #fff;
      }
      table, th, td {
          border: 1px solid #ddd;
      }
      th, td {
          padding: 10px;
          text-align: center;
          vertical-align: middle;
          word-break: break-word;
          white-space: normal;
          overflow-wrap: anywhere;
      }
      th {
          background-color: #4CAF50;
          color: white;
      }

      .plot {
          text-align: center;
          margin: 20px 0;
      }
      .plot img {
          max-width: 100%;
          height: auto;
          border: 1px solid #ddd;
      }

      /* -------------------
         sortable columns (3-state: none ⇅, asc ↑, desc ↓)
         ------------------- */
      table.performance-summary th.sortable {
          cursor: pointer;
          position: relative;
          user-select: none;
      }
      /* default icon space */
      table.performance-summary th.sortable::after {
          content: '⇅';
          position: absolute;
          right: 12px;
          top: 50%;
          transform: translateY(-50%);
          font-size: 0.8em;
          color: #eaf5ea;            /* light on green */
          text-shadow: 0 0 1px rgba(0,0,0,0.15);
      }
      /* three states override the default */
      table.performance-summary th.sortable.sorted-none::after { content: '⇅'; color: #eaf5ea; }
      table.performance-summary th.sortable.sorted-asc::after  { content: '↑'; color: #ffffff; }
      table.performance-summary th.sortable.sorted-desc::after { content: '↓'; color: #ffffff; }

      /* show ~30 rows with a scrollbar (tweak if you want) */
      .scroll-rows-30 {
          max-height: 900px;      /* ~30 rows depending on row height */
          overflow-y: auto;       /* vertical scrollbar (“sidebar”) */
          overflow-x: auto;
      }

      /* Tabs + Help button (used by build_tabbed_html) */
      .tabs {
          display: flex;
          align-items: center;
          border-bottom: 2px solid #ccc;
          margin-bottom: 1rem;
          gap: 6px;
          flex-wrap: wrap;
      }
      .tab {
          padding: 10px 20px;
          cursor: pointer;
          border: 1px solid #ccc;
          border-bottom: none;
          background: #f9f9f9;
          margin-right: 5px;
          border-top-left-radius: 8px;
          border-top-right-radius: 8px;
      }
      .tab.active {
          background: white;
          font-weight: bold;
      }
      .help-btn {
          margin-left: auto;
          padding: 6px 12px;
          font-size: 0.9rem;
          border: 1px solid #4CAF50;
          border-radius: 4px;
          background: #4CAF50;
          color: white;
          cursor: pointer;
      }
      .tab-content {
          display: none;
          padding: 20px;
          border: 1px solid #ccc;
          border-top: none;
          background: #fff;
      }
      .tab-content.active {
          display: block;
      }

      /* Modal (used by get_metrics_help_modal) */
      .modal {
          display: none;
          position: fixed;
          z-index: 9999;
          left: 0; top: 0;
          width: 100%; height: 100%;
          overflow: auto;
          background-color: rgba(0,0,0,0.4);
      }
      .modal-content {
          background-color: #fefefe;
          margin: 8% auto;
          padding: 20px;
          border: 1px solid #888;
          width: 90%;
          max-width: 900px;
          border-radius: 8px;
      }
      .modal .close {
          color: #777;
          float: right;
          font-size: 28px;
          font-weight: bold;
          line-height: 1;
          margin-left: 8px;
      }
      .modal .close:hover,
      .modal .close:focus {
          color: black;
          text-decoration: none;
          cursor: pointer;
      }
      .metrics-guide h3 { margin-top: 20px; }
      .metrics-guide p { margin: 6px 0; }
      .metrics-guide ul { margin: 10px 0; padding-left: 20px; }
    </style>

    <script>
      // Guard to avoid double-initialization if this block is included twice
      (function(){
        if (window.__perfSummarySortInit) return;
        window.__perfSummarySortInit = true;

        function initPerfSummarySorting() {
          // Record original order for "back to original"
          document.querySelectorAll('table.performance-summary tbody').forEach(tbody => {
            Array.from(tbody.rows).forEach((row, i) => { row.dataset.originalOrder = i; });
          });

          const getText = td => (td?.innerText || '').trim();
          const cmp = (idx, asc) => (a, b) => {
            const v1 = getText(a.children[idx]);
            const v2 = getText(b.children[idx]);
            const n1 = parseFloat(v1), n2 = parseFloat(v2);
            if (!isNaN(n1) && !isNaN(n2)) return asc ? n1 - n2 : n2 - n1;   // numeric
            return asc ? v1.localeCompare(v2) : v2.localeCompare(v1);       // lexical
          };

          document.querySelectorAll('table.performance-summary th.sortable').forEach(th => {
            // initialize to “none”
            th.classList.remove('sorted-asc','sorted-desc');
            th.classList.add('sorted-none');

            th.addEventListener('click', () => {
              const table = th.closest('table');
              const headerRow = th.parentNode;
              const allTh = headerRow.querySelectorAll('th.sortable');
              const tbody = table.querySelector('tbody');

              // Determine current state BEFORE clearing
              const isAsc = th.classList.contains('sorted-asc');
              const isDesc = th.classList.contains('sorted-desc');

              // Reset all headers in this row
              allTh.forEach(x => x.classList.remove('sorted-asc','sorted-desc','sorted-none'));

              // Compute next state
              let next;
              if (!isAsc && !isDesc) {
                next = 'asc';
              } else if (isAsc) {
                next = 'desc';
              } else {
                next = 'none';
              }
              th.classList.add('sorted-' + next);

              // Sort rows according to the chosen state
              const rows = Array.from(tbody.rows);
              if (next === 'none') {
                rows.sort((a, b) => (a.dataset.originalOrder - b.dataset.originalOrder));
              } else {
                const idx = Array.from(headerRow.children).indexOf(th);
                rows.sort(cmp(idx, next === 'asc'));
              }
              rows.forEach(r => tbody.appendChild(r));
            });
          });
        }

        // Run after DOM is ready
        if (document.readyState === 'loading') {
          document.addEventListener('DOMContentLoaded', initPerfSummarySorting);
        } else {
          initPerfSummarySorting();
        }
      })();
    </script>
</head>
<body>
    <div class="container">
"""
def get_html_closing():
    """Emit the closing tags for .container, <body>, and <html>."""
    return """
    </div>
</body>
</html>
"""


def build_tabbed_html(
    summary_html: str,
    train_html: str,
    test_html: str,
    feature_html: str,
    explainer_html: Optional[str] = None,
) -> str:
    """
    Render the tab headers, the per-tab content panes, and the JS tab switcher.

    NOTE(review): ``feature_html`` is accepted but never rendered — no
    Features tab/pane is emitted. Confirm whether that omission is intentional.
    """
    header_parts = ['<div class="tabs">']
    header_parts.append('<div class="tab active" onclick="showTab(\'summary\')">Model Metric Summary and Config</div>')
    header_parts.append('<div class="tab" onclick="showTab(\'train\')">Train and Validation Summary</div>')
    header_parts.append('<div class="tab" onclick="showTab(\'test\')">Test Summary</div>')
    if explainer_html:
        header_parts.append('<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>')
    header_parts.append('<button id="openMetricsHelp" class="help-btn">Help</button>')
    header_parts.append('</div>')
    tabs_section = "\n".join(header_parts)

    pane_parts = [
        f'<div id="summary" class="tab-content active">{summary_html}</div>',
        f'<div id="train" class="tab-content">{train_html}</div>',
        f'<div id="test" class="tab-content">{test_html}</div>',
    ]
    if explainer_html:
        pane_parts.append(f'<div id="explainer" class="tab-content">{explainer_html}</div>')
    content_section = "\n".join(pane_parts)

    switcher_js = """
<script>
function showTab(id) {
  document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
  document.querySelectorAll('.tab').forEach(el => el.classList.remove('active'));
  document.getElementById(id).classList.add('active');
  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
}
</script>
"""
    return tabs_section + "\n" + content_section + "\n" + switcher_js
def encode_image_to_base64(image_path: str) -> str:
    """
    Return the base64 text of an image file for inline <img> embedding,
    or "" when the file cannot be read (failure is logged, never raised).
    """
    try:
        with open(image_path, "rb") as fh:
            payload = fh.read()
        return base64.b64encode(payload).decode("utf-8")
    except Exception as e:
        logger.error(f"Failed to encode image '{image_path}': {e}")
        return ""


def get_model_architecture(predictor: Any) -> str:
    """
    Return a human-friendly description of the final model architecture from
    the MultiModalPredictor configuration (e.g. "timm_image=resnet50,
    hf_text=bert-base-uncased"). Falls back to the predictor class name when
    no backbone information can be found.
    """
    def _checkpoint_of(node):
        # Probe common identifier fields on either objects or dict-like nodes.
        found = None
        for field in ("checkpoint_name", "name", "model_name"):
            try:
                found = getattr(node, field)
            except Exception:
                if isinstance(node, dict):
                    found = node.get(field, found)
            if found:
                break
        return found

    descriptions = []
    # MultiModalPredictor path: read backbones from config if available.
    for cfg_attr in ("_config", "config"):
        candidate = getattr(predictor, cfg_attr, None)
        try:
            model_node = getattr(candidate, "model", None)
            if model_node:
                # OmegaConf-like mapping of component name → sub-config.
                for comp_name, comp in dict(model_node).items():
                    ck = _checkpoint_of(comp)
                    if ck:
                        descriptions.append(f"{comp_name}={ck}")
        except Exception:
            continue

    # Fallback: just the predictor's type name.
    return ", ".join(descriptions) if descriptions else type(predictor).__name__
def collect_run_context(args, predictor, problem_type: str,
                        df_train: pd.DataFrame, df_val: pd.DataFrame, df_test: pd.DataFrame,
                        warnings_list: List[str],
                        notes_list: List[str]) -> Dict[str, Any]:
    """
    Build a run/system context dictionary for the transparency section.

    System info is collected best-effort (psutil optional); package versions
    are probed individually and silently skipped when unavailable.
    """
    # Memory stats via psutil when available; both None otherwise.
    try:
        import psutil  # optional
        vm = psutil.virtual_memory()
        total_gb = vm.total / (1024 ** 3)
        avail_gb = vm.available / (1024 ** 3)
    except Exception:
        total_gb = avail_gb = None

    ctx: Dict[str, Any] = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "python_version": platform.python_version(),
        "platform": {
            "system": platform.system(),
            "release": platform.release(),
            "version": platform.version(),
            "machine": platform.machine(),
        },
        "cpu_count": os.cpu_count(),
        "memory_total_gb": total_gb,
        "memory_available_gb": avail_gb,
        "packages": {},
        "problem_type": problem_type,
        "label_column": args.label_column,
        "time_limit_sec": args.time_limit,
        "random_seed": args.random_seed,
        "splits": {
            "train_rows": int(len(df_train)),
            "val_rows": int(len(df_val)),
            "test_rows": int(len(df_test)),
            "n_features_raw": int(len(df_train.columns) - 1),  # minus label
        },
        "warnings": warnings_list,
        "notes": notes_list,
    }

    # Package versions (safe best-effort; absent packages are simply omitted).
    for pkg_key, module_name in (
        ("autogluon", "autogluon"),
        ("torch", "torch"),
        ("scikit_learn", "sklearn"),
        ("numpy", "numpy"),
        ("pandas", "pandas"),
    ):
        try:
            module = __import__(module_name)
            ctx["packages"][pkg_key] = getattr(module, "__version__", "unknown")
        except Exception:
            pass

    return ctx
def build_class_balance_html(
    df_train: Optional[pd.DataFrame],
    label_col: str,
    df_val: Optional[pd.DataFrame] = None,
    df_test: Optional[pd.DataFrame] = None,
) -> str:
    """
    Render label counts for each available split (Train/Validation/Test).

    Splits that were not supplied render as an em dash; numeric labels are
    sorted by value, others appear in value_counts (frequency) order.
    """
    def _tally(frame: Optional[pd.DataFrame]) -> pd.Series:
        # Missing frame or missing label column → empty tally.
        if frame is None or label_col not in frame:
            return pd.Series(dtype=int)
        col = frame[label_col]
        if col.dtype.kind in "ifu":
            return pd.Series(col).value_counts(dropna=False).sort_index()
        return pd.Series(col.astype(str)).value_counts(dropna=False)

    tallies = {
        "train": _tally(df_train),
        "val": _tally(df_val),
        "test": _tally(df_test),
    }

    # Union of labels, preserving first-seen order across splits.
    ordered_labels: list[Any] = []
    for tally in tallies.values():
        for lab in tally.index:
            if lab not in ordered_labels:
                ordered_labels.append(lab)

    present = {
        "train": df_train is not None,
        "val": df_val is not None,
        "test": df_test is not None,
    }

    def _cell(split: str, lab: Any) -> str:
        if not present[split]:
            return "—"
        return str(int(tallies[split].get(lab, 0)))

    body_rows = [
        f"<tr><td>{_escape(lab)}</td>"
        f"<td>{_cell('train', lab)}</td>"
        f"<td>{_cell('val', lab)}</td>"
        f"<td>{_cell('test', lab)}</td></tr>"
        for lab in ordered_labels
    ]

    if not body_rows:
        return "<p>No label distribution available.</p>"

    return f"""
    <h3>Label Counts by Split</h3>
    <table class="table">
    <thead><tr><th>Label</th><th>Train</th><th>Validation</th><th>Test</th></tr></thead>
    <tbody>
    {''.join(body_rows)}
    </tbody>
    </table>
    """


def build_leaderboard_html(predictor) -> str:
    """Render the validation leaderboard as an HTML table (best-effort)."""
    try:
        board = predictor.leaderboard(silent=True)
        # keep common helpful columns if present
        preferred = ["model", "score_val", "eval_metric", "pred_time_val", "fit_time",
                     "pred_time_val_marginal", "fit_time_marginal", "stack_level", "can_infer", "fit_order"]
        keep = [c for c in preferred if c in board.columns]
        if not keep:
            keep = list(board.columns)
        return "<h3>Model Leaderboard (Validation)</h3>" + board[keep].to_html(index=False)
    except Exception as e:
        return f"<h3>Model Leaderboard</h3><p>Unavailable: {_escape(e)}</p>"
def build_ignored_features_html(predictor, df_any: pd.DataFrame) -> str:
    """
    List dataframe columns the trained predictor does not consume.

    Returns "" when the feature set cannot be determined (MultiModalPredictor
    does not always expose .features()) or when nothing was ignored.
    """
    try:
        consumed = set(predictor.features())
    except Exception:
        # If we can't determine, don't emit a misleading section
        return ""
    label_name = getattr(predictor, "label", None)
    candidates = [c for c in df_any.columns if c != label_name]
    unused = [c for c in candidates if c not in consumed]
    if not unused:
        return ""
    bullet_items = "".join(f"<li>{html.escape(c)}</li>" for c in unused)
    return f"""
    <h3>Ignored / Unused Features</h3>
    <p>The following columns were not used by the trained predictor at inference time:</p>
    <ul>{bullet_items}</ul>
    """


def build_presets_hparams_html(predictor) -> str:
    """
    Dump whatever config/fit-arg attributes the predictor exposes, rendered as
    escaped JSON inside a collapsible <details> block.
    """
    snapshot = {}
    for attr_name in ("_config", "config", "_fit_args"):
        if not hasattr(predictor, attr_name):
            continue
        try:
            # make it JSON-ish
            snapshot[attr_name] = str(getattr(predictor, attr_name))
        except Exception:
            continue
    if snapshot:
        body = f"<pre>{html.escape(json.dumps(snapshot, indent=2))}</pre>"
    else:
        body = "<i>Unavailable</i>"
    return f"<h3>Training Presets & Hyperparameters</h3><details open><summary>Show hyperparameters</summary>{body}</details>"


def build_warnings_html(warnings_list: List[str], notes_list: List[str]) -> str:
    """Render optional Warnings / Notices lists; "" when both are empty."""
    if not (warnings_list or notes_list):
        return ""
    w_html = "".join(f"<li>{_escape(w)}</li>" for w in warnings_list)
    n_html = "".join(f"<li>{_escape(n)}</li>" for n in notes_list)
    return f"""
    <h3>Warnings & Notices</h3>
    {'<h4>Warnings</h4><ul>'+w_html+'</ul>' if warnings_list else ''}
    {'<h4>Notices</h4><ul>'+n_html+'</ul>' if notes_list else ''}
    """
+</pre>""" + pkg_rows = "".join(f"<tr><td>{_escape(k)}</td><td>{_escape(v)}</td></tr>" for k, v in (ctx.get("packages") or {}).items()) + sys_table = f""" + <table class="table"> + <tbody> + <tr><th>Timestamp</th><td>{_escape(ctx.get('timestamp'))}</td></tr> + <tr><th>Python</th><td>{_escape(ctx.get('python_version'))}</td></tr> + <tr><th>Platform</th><td>{_escape(ctx.get('platform'))}</td></tr> + <tr><th>CPU Count</th><td>{_escape(ctx.get('cpu_count'))}</td></tr> + <tr><th>Memory (GB)</th><td>Total: {_escape(ctx.get('memory_total_gb'))} | Avail: {_escape(ctx.get('memory_available_gb'))}</td></tr> + <tr><th>Seed</th><td>{_escape(ctx.get('random_seed'))}</td></tr> + <tr><th>Time Limit (s)</th><td>{_escape(ctx.get('time_limit_sec'))}</td></tr> + </tbody> + </table> + """ + pkgs_table = f""" + <h4>Package Versions</h4> + <table class="table"> + <thead><tr><th>Package</th><th>Version</th></tr></thead> + <tbody>{pkg_rows}</tbody> + </table> + """ + return f""" + <h3>Reproducibility</h3> + <h4>Command</h4> + <pre>{cmd}</pre> + {sys_table} + {pkgs_table} + <h4>Load Trained Model</h4> + {load_snippet or '<i>Model path not available</i>'} + """ + + +def build_modalities_html(predictor, df_any: pd.DataFrame, label_col: str, image_col: Optional[str]) -> str: + """Summarize which inputs/modalities are used for MultiModalPredictor.""" + cols = [c for c in df_any.columns] + # exclude label from feature list + feat_cols = [c for c in cols if c != label_col] + # identify image vs tabular columns from args / presence + img_present = (image_col in df_any.columns) if image_col else False + tab_cols = [c for c in feat_cols if c != image_col] + + # brief lists (avoid dumping all, unless small) + def list_or_count(arr, max_show=20): + if len(arr) <= max_show: + items = "".join(f"<li>{html.escape(str(x))}</li>" for x in arr) + return f"<ul>{items}</ul>" + return f"<p>{len(arr)} columns</p>" + + img_block = f"<p><b>Image column:</b> {html.escape(image_col)}</p>" if img_present else 
"<p><b>Image column:</b> None</p>" + tab_block = f"<div><b>Structured columns:</b> {len(tab_cols)}{list_or_count(tab_cols, max_show=15)}</div>" + + return f""" + <h3>Modalities & Inputs</h3> + <p>This run used <b>MultiModalPredictor</b> (images + structured features).</p> + <p><b>Label column:</b> {html.escape(label_col)}</p> + {img_block} + {tab_block} + """ + + +def build_model_performance_summary_table( + train_scores: dict, + val_scores: dict, + test_scores: dict | None = None, + include_test: bool = True, + title: str | None = 'Model Performance Summary', + show_title: bool = True, +) -> str: + """ + Returns an HTML table for metrics, optionally hiding the Test column. + Keys across score dicts are unioned; missing values render as '—'. + """ + def fmt(v): + if v is None: + return '—' + if isinstance(v, (int, float)): + return f'{v:.4f}' + return str(v) + + # Collect union of metric keys across splits + metrics = set(train_scores.keys()) | set(val_scores.keys()) | (set(test_scores.keys()) if (include_test and test_scores) else set()) + + # Remove AG_roc_auc entirely as requested + metrics.discard('AG_roc_auc') + + # Helper: normalize metric keys for matching preferred names + def _norm(k: str) -> str: + return ''.join(ch for ch in str(k).lower() if ch.isalnum()) + + # Preferred metrics to appear at the end in this specific order (display names): + preferred_display = ['Accuracy', 'ROC-AUC', 'Precision', 'Recall', 'F1-Score', 'PR-AUC', 'Specificity', 'MCC', 'LogLoss'] + # Mapping of normalized key -> display label + norm_to_display = { + 'accuracy': 'Accuracy', + 'acc': 'Accuracy', + 'rocauc': 'ROC-AUC', + 'roc_auc': 'ROC-AUC', + 'rocaucscore': 'ROC-AUC', + 'precision': 'Precision', + 'prec': 'Precision', + 'recall': 'Recall', + 'recallsensitivitytpr': 'Recall', + 'f1': 'F1-Score', + 'f1score': 'F1-Score', + 'pr_auc': 'PR-AUC', + 'prauc': 'PR-AUC', + 'averageprecision': 'PR-AUC', + 'specificity': 'Specificity', + 'tnr': 'Specificity', + 'mcc': 'MCC', + 
'logloss': 'LogLoss', + 'crossentropy': 'LogLoss', + } + + # Build ordered list: all non-preferred metrics sorted alphabetically, then preferred metrics in the requested order if present + preferred_norms = [_norm(x) for x in preferred_display] + all_metrics = list(metrics) + # Partition + preferred_present = [] + others = [] + for m in sorted(all_metrics): + nm = _norm(m) + if nm in preferred_norms or any( + p in nm for p in ["rocauc", "prauc", "f1", "mcc", "logloss", "accuracy", "precision", "recall", "specificity"] + ): + # Defer preferred-like metrics to the end (we will place them in canonical order) + preferred_present.append(m) + else: + others.append(m) + + # Now assemble final metric order: others (alpha), then preferred in exact requested order if they exist in metrics + final_metrics = [] + final_metrics.extend(others) + for disp in preferred_display: + # find any original key matching this display (by normalized mapping) + target_norm = _norm(disp) + found = None + for m in preferred_present: + if _norm(m) == target_norm or norm_to_display.get(_norm(m)) == disp or _norm(m).replace(' ', '') == target_norm: + found = m + break + # also allow substring matches (e.g., 'roc_auc' vs 'rocauc') + if target_norm in _norm(m): + found = m + break + if found: + final_metrics.append(found) + + metrics = final_metrics + + # Make all headers sortable by adding the 'sortable' class; the JS in utils.py hooks table.performance-summary + header_cells = [ + '<th class="sortable">Metric</th>', + '<th class="sortable">Train</th>', + '<th class="sortable">Validation</th>' + ] + if include_test and test_scores: + header_cells.append('<th class="sortable">Test</th>') + + rows_html = [] + for m in metrics: + # Display label mapping: clean up common verbose names + disp = m + nm = _norm(m) + if nm in norm_to_display: + disp = norm_to_display[nm] + else: + # generic cleanup: replace underscores with space and remove parenthetical qualifiers + disp = str(m).replace('_', ' ') + disp 
= disp.replace('(Sensitivity/TPR)', '') + disp = disp.replace('(TNR)', '') + disp = disp.strip() + + cells = [ + f'<td>{_escape(disp)}</td>', + f'<td>{fmt(train_scores.get(m))}</td>', + f'<td>{fmt(val_scores.get(m))}</td>', + ] + if include_test and test_scores: + cells.append(f'<td>{fmt(test_scores.get(m))}</td>') + + rows_html.append('<tr>' + ''.join(cells) + '</tr>') + + title_html = f'<h3 style="margin-top:0">{title}</h3>' if (show_title and title) else '' + + table_html = f""" + {title_html} + <table class="performance-summary"> + <thead><tr>{''.join(header_cells)}</tr></thead> + <tbody>{''.join(rows_html)}</tbody> + </table> + """ + return table_html
"""split_logic.py — in-place train/val/test split assignment.

Writes a 'split' column onto the training DataFrame, stratifying on the
target whenever the class counts allow it.
"""
import logging
from typing import List, Optional

import pandas as pd
from sklearn.model_selection import train_test_split

logger = logging.getLogger(__name__)
SPLIT_COL = "split"


def _can_stratify(y: pd.Series) -> bool:
    """True when stratified splitting is possible: >=2 classes, each with >=2 rows."""
    return y.nunique() >= 2 and (y.value_counts() >= 2).all()


def split_dataset(
    train_dataset: pd.DataFrame,
    test_dataset: Optional[pd.DataFrame],
    target_column: str,
    split_probabilities: List[float],
    validation_size: float,
    random_seed: int = 42,
) -> None:
    """Assign each row of *train_dataset* to 'train'/'val'/'test' in place.

    Mutates *train_dataset*: drops rows whose target is NaN, then writes the
    SPLIT_COL column. Returns None.

    Behavior:
      * A pre-existing, valid 'split' column is honored as-is
        ('validation' is normalized to 'val').
      * If *test_dataset* is supplied externally, only a train/val split of
        fraction *validation_size* is created.
      * Otherwise a 3-way split is drawn from *split_probabilities*
        (train, val, test fractions; must sum to 1.0).

    Raises:
        ValueError: if the target column is missing, every target is NaN,
            or split_probabilities does not sum to 1.0.
    """
    if target_column not in train_dataset.columns:
        raise ValueError(f"Target column '{target_column}' not found")

    # Drop NaN labels early
    before = len(train_dataset)
    train_dataset.dropna(subset=[target_column], inplace=True)
    if len(train_dataset) == 0:
        raise ValueError("No rows remain after dropping NaN targets")
    if before != len(train_dataset):
        logger.warning(f"Dropped {before - len(train_dataset)} rows with NaN target")
    y = train_dataset[target_column]

    # Respect existing valid split column
    if SPLIT_COL in train_dataset.columns:
        unique = set(train_dataset[SPLIT_COL].dropna().unique())
        valid = {"train", "val", "validation", "test"}
        # BUGFIX: an all-NaN split column yields an empty set, and the empty
        # set is a subset of anything — require at least one non-NaN value
        # before honoring a pre-existing split. (Also dropped the redundant
        # `| {"validation"}`; "validation" is already in `valid`.)
        if unique and unique.issubset(valid):
            train_dataset[SPLIT_COL] = train_dataset[SPLIT_COL].replace("validation", "val")
            logger.info(f"Using pre-existing 'split' column: {sorted(unique)}")
            return

    train_dataset[SPLIT_COL] = "train"

    if test_dataset is not None:
        # External test set supplied: only carve out a validation fraction.
        stratify = y if _can_stratify(y) else None
        train_idx, val_idx = train_test_split(
            train_dataset.index, test_size=validation_size,
            random_state=random_seed, stratify=stratify
        )
        train_dataset.loc[val_idx, SPLIT_COL] = "val"
        logger.info(f"External test set → created val split ({validation_size:.0%})")

    else:
        p_train, p_val, p_test = split_probabilities
        if abs(p_train + p_val + p_test - 1.0) > 1e-6:
            raise ValueError("split_probabilities must sum to 1.0")

        # First peel off the test fraction, then split the remainder into
        # train/val using the validation share relative to train+val.
        stratify = y if _can_stratify(y) else None
        tv_idx, test_idx = train_test_split(
            train_dataset.index, test_size=p_test,
            random_state=random_seed, stratify=stratify
        )
        rel_val = p_val / (p_train + p_val) if (p_train + p_val) > 0 else 0
        # Re-check stratifiability on the reduced train+val subset.
        strat_tv = y.loc[tv_idx] if _can_stratify(y.loc[tv_idx]) else None
        train_idx, val_idx = train_test_split(
            tv_idx, test_size=rel_val,
            random_state=random_seed, stratify=strat_tv
        )

        train_dataset.loc[val_idx, SPLIT_COL] = "val"
        train_dataset.loc[test_idx, SPLIT_COL] = "test"
        logger.info(f"3-way split → train:{len(train_idx)}, val:{len(val_idx)}, test:{len(test_idx)}")

    logger.info(f"Final split distribution:\n{train_dataset[SPLIT_COL].value_counts().sort_index()}")
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/sample_output.html Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,990 @@ +<!DOCTYPE html> +<html> +<head> + <meta charset="UTF-8"> + <title>Galaxy-Ludwig Report</title> + <style> + body { + font-family: Arial, sans-serif; + margin: 0; + padding: 20px; + background-color: #f4f4f4; + } + .container { + max-width: 1200px; + margin: auto; + background: white; + padding: 20px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + overflow-x: auto; + } + h1 { + text-align: center; + color: #333; + } + h2 { + border-bottom: 2px solid #4CAF50; + color: #4CAF50; + padding-bottom: 5px; + margin-top: 28px; + } + h3 { + margin-top: 20px; + color: #333; + } + .section { + margin-bottom: 30px; + } + .section-title { + border-bottom: 2px solid #4CAF50; + color: #4CAF50; + padding-bottom: 5px; + margin-top: 28px; + } + .card { + background: #fff; + padding: 15px; + border: 1px solid #ddd; + border-radius: 4px; + margin: 10px 0; + } + + /* baseline table setup */ + table { + border-collapse: collapse; + margin: 20px 0; + width: 100%; + table-layout: fixed; + background: #fff; + } + table, th, td { + border: 1px solid #ddd; + } + th, td { + padding: 10px; + text-align: center; + vertical-align: middle; + word-break: break-word; + white-space: normal; + overflow-wrap: anywhere; + } + th { + background-color: #4CAF50; + color: white; + } + .metric-table { + width: 100%; + } + .kv-table { + width: 100%; + } + .kv-table th:first-child, + .kv-table td:first-child { + text-align: left; + width: 30%; + } + + .plot { + text-align: center; + margin: 20px 0; + } + .plot img { + max-width: 100%; + height: auto; + border: 1px solid #ddd; + } + + /* ------------------- + sortable columns (3-state: none ⇅, asc ↑, desc ↓) + ------------------- */ + table.performance-summary th.sortable { + cursor: pointer; + position: relative; + user-select: none; + } + /* default icon space */ + table.performance-summary th.sortable::after { + content: '⇅'; + 
position: absolute; + right: 12px; + top: 50%; + transform: translateY(-50%); + font-size: 0.8em; + color: #eaf5ea; /* light on green */ + text-shadow: 0 0 1px rgba(0,0,0,0.15); + } + /* three states override the default */ + table.performance-summary th.sortable.sorted-none::after { content: '⇅'; color: #eaf5ea; } + table.performance-summary th.sortable.sorted-asc::after { content: '↑'; color: #ffffff; } + table.performance-summary th.sortable.sorted-desc::after { content: '↓'; color: #ffffff; } + + /* show ~30 rows with a scrollbar (tweak if you want) */ + .scroll-rows-30 { + max-height: 900px; /* ~30 rows depending on row height */ + overflow-y: auto; /* vertical scrollbar ("sidebar") */ + overflow-x: auto; + } + + /* Tabs + Help button (used by build_tabbed_html) */ + .tabs { + display: flex; + align-items: center; + border-bottom: 2px solid #ccc; + margin-bottom: 1rem; + gap: 6px; + flex-wrap: wrap; + } + .tab { + padding: 10px 20px; + cursor: pointer; + border: 1px solid #ccc; + border-bottom: none; + background: #f9f9f9; + margin-right: 5px; + border-top-left-radius: 8px; + border-top-right-radius: 8px; + } + .tab.active { + background: white; + font-weight: bold; + } + .help-btn { + margin-left: auto; + padding: 6px 12px; + font-size: 0.9rem; + border: 1px solid #4CAF50; + border-radius: 4px; + background: #4CAF50; + color: white; + cursor: pointer; + } + .tab-content { + display: none; + padding: 20px; + border: 1px solid #ccc; + border-top: none; + background: #fff; + } + .tab-content.active { + display: block; + } + .plotly-center { + display: flex; + justify-content: center; + } + .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { + margin: 0 auto !important; + } + .js-plotly-plot, .plotly-graph-div { + margin-left: auto !important; + margin-right: auto !important; + } + + /* Modal (used by get_metrics_help_modal) */ + .modal { + display: none; + position: fixed; + z-index: 9999; + left: 0; top: 0; + width: 100%; height: 100%; + 
overflow: auto; + background-color: rgba(0,0,0,0.4); + } + .modal-content { + background-color: #fefefe; + margin: 8% auto; + padding: 20px; + border: 1px solid #888; + width: 90%; + max-width: 900px; + border-radius: 8px; + } + .modal .close { + color: #777; + float: right; + font-size: 28px; + font-weight: bold; + line-height: 1; + margin-left: 8px; + } + .modal .close:hover, + .modal .close:focus { + color: black; + text-decoration: none; + cursor: pointer; + } + .metrics-guide h3 { margin-top: 20px; } + .metrics-guide p { margin: 6px 0; } + .metrics-guide ul { margin: 10px 0; padding-left: 20px; } + </style> + + <script> + // Guard to avoid double-initialization if this block is included twice + (function(){ + if (window.__perfSummarySortInit) return; + window.__perfSummarySortInit = true; + + function initPerfSummarySorting() { + // Record original order for "back to original" + document.querySelectorAll('table.performance-summary tbody').forEach(tbody => { + Array.from(tbody.rows).forEach((row, i) => { row.dataset.originalOrder = i; }); + }); + + const getText = td => (td?.innerText || '').trim(); + const cmp = (idx, asc) => (a, b) => { + const v1 = getText(a.children[idx]); + const v2 = getText(b.children[idx]); + const n1 = parseFloat(v1), n2 = parseFloat(v2); + if (!isNaN(n1) && !isNaN(n2)) return asc ? n1 - n2 : n2 - n1; // numeric + return asc ? 
v1.localeCompare(v2) : v2.localeCompare(v1); // lexical + }; + + document.querySelectorAll('table.performance-summary th.sortable').forEach(th => { + // initialize to "none" + th.classList.remove('sorted-asc','sorted-desc'); + th.classList.add('sorted-none'); + + th.addEventListener('click', () => { + const table = th.closest('table'); + const headerRow = th.parentNode; + const allTh = headerRow.querySelectorAll('th.sortable'); + const tbody = table.querySelector('tbody'); + + // Determine current state BEFORE clearing + const isAsc = th.classList.contains('sorted-asc'); + const isDesc = th.classList.contains('sorted-desc'); + + // Reset all headers in this row + allTh.forEach(x => x.classList.remove('sorted-asc','sorted-desc','sorted-none')); + + // Compute next state + let next; + if (!isAsc && !isDesc) { + next = 'asc'; + } else if (isAsc) { + next = 'desc'; + } else { + next = 'none'; + } + th.classList.add('sorted-' + next); + + // Sort rows according to the chosen state + const rows = Array.from(tbody.rows); + if (next === 'none') { + rows.sort((a, b) => (a.dataset.originalOrder - b.dataset.originalOrder)); + } else { + const idx = Array.from(headerRow.children).indexOf(th); + rows.sort(cmp(idx, next === 'asc')); + } + rows.forEach(r => tbody.appendChild(r)); + }); + }); + } + + // Run after DOM is ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', initPerfSummarySorting); + } else { + initPerfSummarySorting(); + } + })(); + </script> +</head> +<body> + <div class="container"> + <script src="https://cdn.plot.ly/plotly-2.30.0.min.js"></script> + <style> + .plotly-center { + display: flex; + justify-content: center; + margin: 20px 0; + width: 100%; + } + .plotly-center .plotly-graph-div, .plotly-center .js-plotly-plot { + margin: 0 auto !important; + } + .js-plotly-plot, .plotly-graph-div { + margin-left: auto !important; + margin-right: auto !important; + } + #learning-curve-accuracy, #learning-curve-loss, 
#threshold-plot, + #confusion-matrix, #per-class-metrics, #roc-curve, #pr-curve { + min-height: 400px; + } + </style> + + <div class="tabs"> + <div class="tab active" onclick="showTab('summary')">Model Summary & Config</div> + <div class="tab" onclick="showTab('train')">Train/Validation Summary</div> + <div class="tab" onclick="showTab('test')">Test Summary</div> + <div class="tab" onclick="showTab('feature')">Feature Importance</div> + <button id="openMetricsHelp" class="help-btn">Help</button> + </div> + + <!-- Summary Tab --> + <div id="summary" class="tab-content active"> + <section class="section"> + <h2 class="section-title">Model Performance Summary</h2> + <div class="card"> + <table class="metric-table"> + <thead><tr><th>Metric</th><th>Train</th><th>Validation</th><th>Test</th></tr></thead> + <tbody> + <tr><td>accuracy</td><td>0.9234</td><td>0.8912</td><td>0.8856</td></tr> + <tr><td>f1</td><td>0.9201</td><td>0.8876</td><td>0.8823</td></tr> + <tr><td>precision</td><td>0.9156</td><td>0.8845</td><td>0.8798</td></tr> + <tr><td>recall</td><td>0.9245</td><td>0.8907</td><td>0.8849</td></tr> + <tr><td>roc_auc</td><td>0.9789</td><td>0.9543</td><td>0.9512</td></tr> + <tr><td>log_loss</td><td>0.2134</td><td>0.2876</td><td>0.3012</td></tr> + </tbody> + </table> + </div> + </section> + + <section class="section"> + <h2 class="section-title">Run Configuration</h2> + <div class="card"> + <table class="kv-table"> + <thead><tr><th>Key</th><th>Value</th></tr></thead> + <tbody> + <tr><td>Predictor type</td><td>MultiModalPredictor</td></tr> + <tr><td>Framework</td><td>AutoGluon Multimodal</td></tr> + <tr><td>Model architecture</td><td>timm_image=resnet50, hf_text=bert-base-uncased</td></tr> + <tr><td>Modalities & Inputs</td><td>Images + Tabular</td></tr> + <tr><td>Label column</td><td>target</td></tr> + <tr><td>Image columns</td><td>image_path</td></tr> + <tr><td>Tabular columns</td><td>15</td></tr> + <tr><td>Presets</td><td>medium_quality</td></tr> + <tr><td>Eval 
metric</td><td>accuracy</td></tr> + <tr><td>Decision threshold calibration</td><td>enabled</td></tr> + <tr><td>Decision threshold (Test only)</td><td>0.500</td></tr> + <tr><td>Seed</td><td>42</td></tr> + <tr><td>time limit(s)</td><td>3600</td></tr> + </tbody> + </table> + </div> + </section> + + <section class="section"> + <h2 class="section-title">Class Balance (Train Full)</h2> + <div class="card"> + <h3>Class Balance (Train Full)</h3> + <table class="table"> + <thead><tr><th>Class</th><th>Count</th><th>Percent</th></tr></thead> + <tbody> + <tr><td>0</td><td>1245</td><td>45.23%</td></tr> + <tr><td>1</td><td>1508</td><td>54.77%</td></tr> + </tbody> + </table> + </div> + </section> + </div> + + <!-- Train Tab --> + <div id="train" class="tab-content"> + <h2>Train/Validation Performance Summary</h2> + + <div class="card"> + <table class="metric-table"> + <thead><tr><th>Metric</th><th>Train</th><th>Validation</th></tr></thead> + <tbody> + <tr><td>accuracy</td><td>0.9234</td><td>0.8912</td></tr> + <tr><td>f1</td><td>0.9201</td><td>0.8876</td></tr> + <tr><td>precision</td><td>0.9156</td><td>0.8845</td></tr> + <tr><td>recall</td><td>0.9245</td><td>0.8907</td></tr> + <tr><td>roc_auc</td><td>0.9789</td><td>0.9543</td></tr> + <tr><td>log_loss</td><td>0.2134</td><td>0.2876</td></tr> + </tbody> + </table> + </div> + + <h2>Learning Curves — Label Accuracy</h2> + <div class="plotly-center"> + <div id="learning-curve-accuracy" style="width:900px;height:500px;"></div> + </div> + + <h2>Learning Curves — Label Loss</h2> + <div class="plotly-center"> + <div id="learning-curve-loss" style="width:900px;height:500px;"></div> + </div> + </div> + + <!-- Test Tab --> + <div id="test" class="tab-content"> + <h2>Test Performance Summary</h2> + <table class="performance-summary"> + <thead> + <tr> + <th class="sortable">Metric</th> + <th class="sortable">Test</th> + </tr> + </thead> + <tbody> + <tr><td>accuracy</td><td>0.8856</td></tr> + <tr><td>f1</td><td>0.8823</td></tr> + 
<tr><td>precision</td><td>0.8798</td></tr> + <tr><td>recall</td><td>0.8849</td></tr> + <tr><td>roc_auc</td><td>0.9512</td></tr> + <tr><td>log_loss</td><td>0.3012</td></tr> + <tr><td>specificity (TNR)</td><td>0.8765</td></tr> + <tr><td>sensitivity (Sensitivity/TPR)</td><td>0.8923</td></tr> + </tbody> + </table> + + <h2>Confusion Matrix</h2> + <div class="plotly-center"> + <div id="confusion-matrix" style="width:700px;height:600px;"></div> + </div> + + <h2>Per-Class Metrics</h2> + <div class="plotly-center"> + <div id="per-class-metrics" style="width:900px;height:500px;"></div> + </div> + + <h2>ROC Curve</h2> + <div class="plotly-center"> + <div id="roc-curve" style="width:800px;height:600px;"></div> + </div> + + <h2>Precision–Recall Curve</h2> + <div class="plotly-center"> + <div id="pr-curve" style="width:800px;height:600px;"></div> + </div> + + <h2>Threshold Plot</h2> + <div class="plotly-center"> + <div id="threshold-plot" style="width:900px;height:500px;"></div> + </div> + </div> + + <!-- Feature Importance Tab --> + <div id="feature" class="tab-content"> + <section class="section"> + <h2 class="section-title">Feature Importance</h2> + <div class="card"> + <p>Permutation importance is not supported for MultiModalPredictor in this tool. 
For tabular-only runs, this section shows permutation importance.</p> + </div> + </section> + + <section class="section"> + <h2 class="section-title">Modalities & Inputs</h2> + <div class="card"> + <h3>Modalities & Inputs</h3> + <p>This run used <b>MultiModalPredictor</b> (images + tabular).</p> + <p><b>Label column:</b> target</p> + <p><b>Image column:</b> image_path</p> + <div><b>Tabular columns:</b> 15 + <ul> + <li>feature_1</li> + <li>feature_2</li> + <li>feature_3</li> + <li>feature_4</li> + <li>feature_5</li> + <li>feature_6</li> + <li>feature_7</li> + <li>feature_8</li> + <li>feature_9</li> + <li>feature_10</li> + <li>feature_11</li> + <li>feature_12</li> + <li>feature_13</li> + <li>feature_14</li> + <li>feature_15</li> + </ul> + </div> + </div> + </section> + </div> + + <script> + function showTab(id) { + document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active')); + document.querySelectorAll('.tab').forEach(el => el.classList.remove('active')); + document.getElementById(id).classList.add('active'); + document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active'); + } + + // Fixed random data for reproducibility + const epochs = Array.from({length: 31}, (_, i) => i); + const accuracy_train = [0.552, 0.568, 0.581, 0.595, 0.612, 0.628, 0.645, 0.662, 0.678, 0.692, 0.708, 0.722, 0.735, 0.748, 0.761, 0.773, 0.784, 0.795, 0.805, 0.814, 0.823, 0.831, 0.839, 0.846, 0.853, 0.859, 0.865, 0.870, 0.875, 0.880, 0.884]; + const accuracy_val = [0.752, 0.768, 0.782, 0.795, 0.807, 0.818, 0.828, 0.837, 0.845, 0.852, 0.859, 0.865, 0.871, 0.876, 0.881, 0.885, 0.889, 0.893, 0.896, 0.899, 0.902, 0.905, 0.907, 0.909, 0.911, 0.913, 0.915, 0.917, 0.918, 0.920, 0.921]; + const loss_train = [1.352, 1.285, 1.221, 1.158, 1.098, 1.041, 0.987, 0.936, 0.888, 0.842, 0.799, 0.758, 0.720, 0.684, 0.650, 0.618, 0.588, 0.560, 0.534, 0.510, 0.487, 0.466, 0.447, 0.429, 0.412, 0.397, 0.383, 0.370, 0.358, 0.347, 0.337]; + const loss_val = [0.802, 0.765, 0.730, 
0.697, 0.666, 0.637, 0.610, 0.585, 0.561, 0.539, 0.518, 0.499, 0.481, 0.464, 0.448, 0.433, 0.419, 0.406, 0.394, 0.383, 0.372, 0.362, 0.353, 0.345, 0.337, 0.330, 0.323, 0.317, 0.311, 0.306, 0.301]; + + // Wait for Plotly to load before creating plots + function createPlots() { + if (typeof Plotly === 'undefined') { + console.log('Waiting for Plotly...'); + setTimeout(createPlots, 100); + return; + } + + console.log('Plotly loaded, creating plots...'); + + // Learning Curves — Label Accuracy (with both training and validation) + try { + Plotly.newPlot('learning-curve-accuracy', [ + { + x: epochs, + y: accuracy_train, + type: 'scatter', + mode: 'lines+markers', + name: 'Training', + line: { width: 3, color: '#1f77b4', shape: 'spline' }, + marker: { size: 8, color: '#1f77b4' } + }, + { + x: epochs, + y: accuracy_val, + type: 'scatter', + mode: 'lines+markers', + name: 'Validation', + line: { width: 3, color: '#ff7f0e', shape: 'spline' }, + marker: { size: 8, color: '#ff7f0e' } + } + ], { + template: 'plotly_white', + xaxis: { + title: 'Epoch', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false + }, + yaxis: { + title: 'Accuracy', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0.5, 1.0] + }, + legend: { + orientation: 'h', + yanchor: 'bottom', + y: 1.02, + xanchor: 'right', + x: 1.0 + }, + margin: { l: 60, r: 20, t: 40, b: 50 }, + plot_bgcolor: 'white', + paper_bgcolor: 'white' + }); + console.log('Accuracy plot created'); + } catch(e) { + console.error('Error creating accuracy plot:', e); + } + + // Learning Curves — Label Loss (with both training and validation) + try { + Plotly.newPlot('learning-curve-loss', [ + { + x: epochs, + y: loss_train, + type: 'scatter', + mode: 'lines+markers', + name: 'Training Loss', + line: { width: 3, color: '#1f77b4', shape: 'spline' }, + marker: { size: 8, color: '#1f77b4' } + }, + { + x: epochs, + y: loss_val, + type: 'scatter', + mode: 'lines+markers', + name: 'Validation Loss', + line: { width: 3, 
color: '#ff7f0e', shape: 'spline' }, + marker: { size: 8, color: '#ff7f0e' } + } + ], { + template: 'plotly_white', + xaxis: { + title: 'Epoch', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false + }, + yaxis: { + title: 'Loss', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0.2, 1.4] + }, + legend: { + orientation: 'h', + yanchor: 'bottom', + y: 1.02, + xanchor: 'right', + x: 1.0 + }, + margin: { l: 60, r: 20, t: 40, b: 50 }, + plot_bgcolor: 'white', + paper_bgcolor: 'white' + }); + console.log('Loss plot created'); + } catch(e) { + console.error('Error creating loss plot:', e); + } + + // Threshold Plot (with fixed data matching the expected visualization) + // Precision: starts low (~0.42), increases to peak (~0.95) around threshold 0.5, then drops sharply to 0 + // Recall: starts at 1, stays at 1 until ~0.1, then decreases to 0 around 0.9 + // F1: starts around 0.6, peaks around 0.5 at ~0.95, then drops to 0 + const thresholds = Array.from({length: 101}, (_, i) => i / 100); + + // Generate precision values: low at start, peak around 0.5, drop to 0 + const precision_vals = thresholds.map(t => { + if (t <= 0.5) { + return 0.42 + (t / 0.5) * 0.53; // 0.42 to 0.95 + } else if (t <= 0.9) { + return 0.95 - ((t - 0.5) / 0.4) * 0.95; // 0.95 to 0 + } else { + return 0; + } + }); + + // Generate recall values: starts at 1, stays at 1 until ~0.1, then decreases to 0 around 0.9 + const recall_vals = thresholds.map(t => { + if (t <= 0.1) { + return 1.0; + } else if (t <= 0.9) { + return 1.0 - ((t - 0.1) / 0.8) * 1.0; // 1.0 to 0 + } else { + return 0; + } + }); + + // Calculate F1 from precision and recall + const f1_vals = thresholds.map((t, i) => { + const p = precision_vals[i]; + const r = recall_vals[i]; + if (p + r === 0) return 0; + return 2 * (p * r) / (p + r); + }); + + // Queue Rate: decreases linearly from 1 to 0 + const queue_rate = thresholds.map(t => 1 - t); + + try { + Plotly.newPlot('threshold-plot', [ + { + x: thresholds, + 
y: precision_vals, + type: 'scatter', + mode: 'lines', + name: 'Precision', + line: { width: 3, color: '#1f77b4' } + }, + { + x: thresholds, + y: recall_vals, + type: 'scatter', + mode: 'lines', + name: 'Recall', + line: { width: 3, color: '#ff7f0e' } + }, + { + x: thresholds, + y: f1_vals, + type: 'scatter', + mode: 'lines', + name: 'F1', + line: { width: 3, color: '#2ca02c' } + }, + { + x: thresholds, + y: queue_rate, + type: 'scatter', + mode: 'lines', + name: 'Queue Rate', + line: { width: 2, color: '#808080', dash: 'dash' } + } + ], { + template: 'plotly_white', + xaxis: { + title: 'Discrimination Threshold', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false + }, + yaxis: { + title: 'Score', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0, 1] + }, + legend: { + orientation: 'h', + yanchor: 'bottom', + y: 1.02, + xanchor: 'right', + x: 1.0 + }, + margin: { l: 60, r: 20, t: 40, b: 50 }, + plot_bgcolor: 'white', + paper_bgcolor: 'white', + shapes: [{ + type: 'line', + xref: 'x', + yref: 'y', + x0: 0.51, + y0: 0, + x1: 0.51, + y1: 1, + line: { + color: 'black', + width: 2, + dash: 'dash' + } + }], + annotations: [{ + x: 0.51, + y: 0.85, + text: 't* = 0.51', + showarrow: false, + font: { size: 12, color: 'black' } + }] + }); + console.log('Threshold plot created'); + } catch(e) { + console.error('Error creating threshold plot:', e); + } + + // Confusion Matrix (matching imagelearner style with Blues colorscale) + const cm_data = [[542, 78], [65, 515]]; + const total = cm_data.flat().reduce((a, b) => a + b, 0); + const labels = ['0', '1']; + + // Build annotations array with all white text + const annotations = []; + cm_data.forEach((row, i) => { + row.forEach((val, j) => { + const pct = ((val / total) * 100).toFixed(1); + // All text is white + const textColor = 'white'; + // Count annotation (bold, bottom) + annotations.push({ + x: labels[j], + y: labels[i], + text: '<b>' + val + '</b>', + showarrow: false, + font: { color: 
textColor, size: 14 }, + xanchor: 'center', + yanchor: 'bottom', + yshift: 2 + }); + // Percentage annotation (top) + annotations.push({ + x: labels[j], + y: labels[i], + text: pct + '%', + showarrow: false, + font: { color: textColor, size: 13 }, + xanchor: 'center', + yanchor: 'top', + yshift: -2 + }); + }); + }); + + try { + Plotly.newPlot('confusion-matrix', [{ + z: cm_data, + x: labels, + y: labels, + type: 'heatmap', + colorscale: 'Blues', + showscale: true, + colorbar: { title: 'Count' }, + xgap: 2, + ygap: 2, + hovertemplate: 'True=%{y}<br>Pred=%{x}<br>Count=%{z}<extra></extra>', + zmin: 0 + }], { + xaxis: { title: 'Predicted label', type: 'category' }, + yaxis: { title: 'True label', type: 'category', autorange: 'reversed' }, + margin: { l: 80, r: 20, t: 40, b: 80 }, + template: 'plotly_white', + plot_bgcolor: 'white', + paper_bgcolor: 'white', + annotations: annotations + }); + console.log('Confusion matrix created'); + } catch(e) { + console.error('Error creating confusion matrix:', e); + } + + // Per-Class Metrics + const classes = ['Class 0', 'Class 1']; + const precision_per_class = [0.8929, 0.8685]; + const recall_per_class = [0.8742, 0.8879]; + const f1_per_class = [0.8835, 0.8781]; + + try { + Plotly.newPlot('per-class-metrics', [ + { + x: classes, + y: precision_per_class, + type: 'bar', + name: 'Precision', + marker: { color: '#4CAF50' } + }, + { + x: classes, + y: recall_per_class, + type: 'bar', + name: 'Recall', + marker: { color: '#2196F3' } + }, + { + x: classes, + y: f1_per_class, + type: 'bar', + name: 'F1', + marker: { color: '#FF9800' } + } + ], { + template: 'plotly_white', + xaxis: { + title: 'Class', + gridcolor: '#e0e0e0', + showgrid: true + }, + yaxis: { + title: 'Score', + gridcolor: '#e0e0e0', + showgrid: true, + range: [0, 1] + }, + barmode: 'group', + legend: { + orientation: 'h', + yanchor: 'bottom', + y: 1.02, + xanchor: 'right', + x: 1.0 + }, + margin: { l: 60, r: 20, t: 40, b: 50 }, + plot_bgcolor: 'white', + paper_bgcolor: 
'white' + }); + console.log('Per-class metrics created'); + } catch(e) { + console.error('Error creating per-class metrics:', e); + } + + // ROC Curve (with fixed data) + const fpr = [0,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19,0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29,0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39,0.40,0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49,0.50,0.51,0.52,0.53,0.54,0.55,0.56,0.57,0.58,0.59,0.60,0.61,0.62,0.63,0.64,0.65,0.66,0.67,0.68,0.69,0.70,0.71,0.72,0.73,0.74,0.75,0.76,0.77,0.78,0.79,0.80,0.81,0.82,0.83,0.84,0.85,0.86,0.87,0.88,0.89,0.90,0.91,0.92,0.93,0.94,0.95,0.96,0.97,0.98,0.99,1.00]; + const tpr = [0,0.12,0.23,0.33,0.42,0.50,0.57,0.63,0.69,0.74,0.78,0.82,0.85,0.88,0.90,0.92,0.94,0.95,0.96,0.97,0.98,0.985,0.99,0.992,0.994,0.996,0.997,0.998,0.999,0.9995,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0]; + const auc_score = 0.935; + + try { + Plotly.newPlot('roc-curve', [ + { + x: [0, 1], + y: [0, 1], + type: 'scatter', + mode: 'lines', + name: 'Random (AUC = 0.50)', + line: { dash: 'dash', color: 'gray', width: 2 } + }, + { + x: fpr, + y: tpr, + type: 'scatter', + mode: 'lines', + name: `ROC Curve (AUC = ${auc_score.toFixed(3)})`, + line: { width: 3, color: '#1f77b4' } + } + ], { + template: 'plotly_white', + xaxis: { + title: 'False Positive Rate', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0, 1] + }, + yaxis: { + title: 'True Positive Rate', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0, 1] + }, + legend: { + orientation: 'h', + yanchor: 'bottom', + y: 1.02, + xanchor: 'right', + x: 1.0 + }, + margin: { l: 60, r: 20, t: 40, b: 50 }, + plot_bgcolor: 'white', + paper_bgcolor: 'white' + 
}); + console.log('ROC curve created'); + } catch(e) { + console.error('Error creating ROC curve:', e); + } + + // Precision-Recall Curve (with fixed data) + const recall_pr = [0,0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.10,0.11,0.12,0.13,0.14,0.15,0.16,0.17,0.18,0.19,0.20,0.21,0.22,0.23,0.24,0.25,0.26,0.27,0.28,0.29,0.30,0.31,0.32,0.33,0.34,0.35,0.36,0.37,0.38,0.39,0.40,0.41,0.42,0.43,0.44,0.45,0.46,0.47,0.48,0.49,0.50,0.51,0.52,0.53,0.54,0.55,0.56,0.57,0.58,0.59,0.60,0.61,0.62,0.63,0.64,0.65,0.66,0.67,0.68,0.69,0.70,0.71,0.72,0.73,0.74,0.75,0.76,0.77,0.78,0.79,0.80,0.81,0.82,0.83,0.84,0.85,0.86,0.87,0.88,0.89,0.90,0.91,0.92,0.93,0.94,0.95,0.96,0.97,0.98,0.99,1.00]; + const precision_pr = [0.95,0.949,0.948,0.947,0.946,0.945,0.944,0.943,0.942,0.941,0.940,0.939,0.938,0.937,0.936,0.935,0.934,0.933,0.932,0.931,0.930,0.929,0.928,0.927,0.926,0.925,0.924,0.923,0.922,0.921,0.920,0.919,0.918,0.917,0.916,0.915,0.914,0.913,0.912,0.911,0.910,0.909,0.908,0.907,0.906,0.905,0.904,0.903,0.902,0.901,0.900,0.899,0.898,0.897,0.896,0.895,0.894,0.893,0.892,0.891,0.890,0.889,0.888,0.887,0.886,0.885,0.884,0.883,0.882,0.881,0.880,0.879,0.878,0.877,0.876,0.875,0.874,0.873,0.872,0.871,0.870,0.869,0.868,0.867,0.866,0.865,0.864,0.863,0.862,0.861,0.860,0.859,0.858,0.857,0.856,0.855,0.854,0.853,0.852,0.851,0.850]; + const ap_score = 0.9234; + + try { + Plotly.newPlot('pr-curve', [ + { + x: recall_pr, + y: precision_pr, + type: 'scatter', + mode: 'lines', + name: `Precision-Recall Curve (AP = ${ap_score.toFixed(4)})`, + line: { width: 3, color: '#1f77b4' } + } + ], { + template: 'plotly_white', + xaxis: { + title: 'Recall', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0, 1] + }, + yaxis: { + title: 'Precision', + gridcolor: '#e0e0e0', + showgrid: true, + zeroline: false, + range: [0.6, 1] + }, + legend: { + orientation: 'h', + yanchor: 'bottom', + y: 1.02, + xanchor: 'right', + x: 1.0 + }, + margin: { l: 60, r: 20, t: 40, b: 50 }, + plot_bgcolor: 'white', + 
paper_bgcolor: 'white' + }); + console.log('PR curve created'); + } catch(e) { + console.error('Error creating PR curve:', e); + } + + console.log('All plots created successfully!'); + } + + // Start creating plots when DOM is ready + if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', createPlots); + } else { + createPlots(); + } + </script> + </div> +</body> +</html> +
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/test.csv Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,51 @@ +feature_0,feature_1,feature_2,feature_3,feature_4,text,target,image_path +0.25739992534469336,-0.288440183,0.20367539201504153,2.504864685302447,2.0388993649376554,Hypopharynxkarzinom[C13.9 ] Larynxkarzinom[C32.9 ] Halslymphknotenmetastasen[C77.0 B],0,image_0.png +-0.908481433,-1.067713065,-1.8565207,-1.333461127,0.3848107911787176,Supraglottisches Karzinom[C32.1 L] Halslymphknotenmetastasen[C77.0 ],1,image_1.png +-0.378503106,-1.077600898,0.3781894561704642,-1.028272654,-0.700527555,Bösartige Neubildung der Gaumenbogentonsillen[C09.1 L],1,image_2.png +-0.534915599,-0.796773763,1.2361403151716697,1.3631696640866262,-1.497154707,"Karzinom des Sulcus glossopalatinus[C09.1 ] Neubildung unsicheren oder unbekannten Verhaltens: Lippe, Mundhöhle und Pharynx[D37.0 ]",1,image_3.png +0.8580733460716252,-1.486042582,-0.689971804,-1.142274593,-0.097944716,Karzinom Hypopharynx mehrere Teilbereiche überlappend[C13.8 R],0,image_4.png +-0.413009982,0.5141287656735564,-0.198354692,-1.372025465,-0.749649593,Bösartige Neubildung: Weicher Gaumen[C05.1 ] Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 L],1,image_5.png +0.49818858448728287,0.8517908644430702,-3.025739873,2.1153714355164954,-0.423670688,Glottiskarzinom[C32.0 R] Halslipom[D17.0 R],1,image_6.png +2.0101992475719856,0.9586734372734298,-1.245626223,-0.274274886,0.3456524588826128,Bösartige Neubildung: Glottis[C32.0 ],0,image_7.png +1.2628615445176543,-0.626484049,-2.084039472,-0.250145423,0.23501780256852894,Bösartige Neubildung Hypopharynx mehrere Teilbereiche überlappend[C13.8 ] Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 ],1,image_8.png +-0.439214857,0.3079310128755357,0.10412164629262519,-0.5088304,0.19390051775098527,Supraglottisches Karzinom[C32.1 ],0,image_9.png +-0.346437893,0.005205685,-0.150233494,0.5813407294084169,1.3922625492889247,,0,image_10.png 
+0.4553196595572467,0.6915319076425244,-0.105816503,-1.133100081,0.4079556930770453,Halslymphknotenmetastasen[C77.0 B],1,image_11.png +-1.668662707,0.4448621590773117,-0.436954139,1.7743892245873165,2.7134802556937045,Neubildung bösartig Gaumenbogen (vorderer) (hinterer)[C09.1 ],1,image_12.png +-0.862085495,0.090279532,1.019136316265232,0.7530463240026412,0.1613349250415128,"Bösartige Neubildung: Oropharynx, nicht näher bezeichnet[C10.9 ]",0,image_13.png +0.49291084812374153,-1.858342899,-1.232803226,-0.051800232,-0.40689619,Zungenrandkarzinom[C02.1 R] Halslymphknotenmetastasen[C77.0 B],1,image_14.png +-0.124313396,-0.166580044,-0.074124538,-0.551705599,-1.497093461,"Zungenrandkarzinom[C02.1 ] Neubildung unsicheren oder unbekannten Verhaltens: Lippe, Mundhöhle und Pharynx[D37.0 ]",0,image_15.png +1.9351362874003422,0.11087648013745463,0.5963891827515404,0.10825694005967285,-1.314371547,Bösartige Neubildung der Seitenwand des Oropharynx[C10.2 L],1,image_16.png +-0.618442655,-0.694772644,-0.179392562,1.0439696514243961,0.9068151425024749,Bösartige Neubildung: Glottis[C32.0 L],0,image_17.png +-1.046838986,-0.269175566,-0.090795068,2.6174426286362076,0.6489323286331116,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ],1,image_18.png +-0.889617589,-1.29922515,-0.825079629,1.508768374358993,-0.089698726,,0,image_19.png +0.014040537,-0.321105447,-0.334969461,1.012902487072667,-0.224378814,Bösartige Neubildung: Zungenrand[C02.1 ],1,image_20.png +-0.160829686,0.5058687421604646,0.7183508075095323,0.043045740547390625,-0.887604847,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ] Bösartige Neubildung: Hypopharynx, mehrere Teilbereiche überlappend[C13.8 ]",0,image_21.png +2.2303596488319384,2.0890595712039643,-2.019462185,1.2325046106386215,-0.16636758,Tonsillenkarzinom[C09.9 ] Halslymphknotenmetastasen[C77.0 ],1,image_22.png +-0.399115719,-1.012709255,-0.158029049,-0.346410841,0.1234889741068257,"Bösartige Neubildung: Oropharynx, mehrere 
Teilbereiche überlappend[C10.8 ]",0,image_23.png +0.054444563,-0.02397407,-0.584215128,0.3318833834396524,-2.017685901,Bösartige Neubildung der Glottis[C32.0 R],1,image_24.png +0.8841818151951191,-0.96146905,-0.070306498,1.1706959283854066,0.7251931071367159,Karzinom Hypopharynx mehrere Teilbereiche überlappend[C13.8 R],1,image_25.png +-0.107980561,-0.092566189,1.7731812300260754,0.6067736917185195,-2.02724386,cT1b-Glottis-CA[C32.0 ],1,image_26.png +0.5556069842518202,-0.223732081,-0.446629293,-0.664692993,0.33697833144899764,Neubildung bösartig Hypopharynx sonstige[C13.8 L] Neubildung bösartig Hypopharynx sonstige[C13.8 L],0,image_27.png +0.3949066359501769,0.832892162,1.4893745786091637,-1.532773874,-0.044698324,Bösartige Neubildung: Vallecula epiglottica[C10.0 ],0,image_28.png +0.8372050226472976,0.9741195783904064,1.100308721372732,0.5014966096525523,0.29207407178086175,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ] Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ]",1,image_29.png +-1.40787817,0.16281815769810043,0.029896317954142267,0.13463643962491514,-0.648498896,"Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ] Bösartige Neubildung: Tonsille, mehrere Teilbereiche überlappend[C09.8 ]",1,image_30.png +0.8078494136349023,-0.11449202,0.087105811,0.9299897685604038,1.6934782478194783,,1,image_31.png +-0.138283645,1.1864684326166455,0.18123494219161268,0.057889462,0.7180924802384169,Bösartige Neubildung der Glottis[C32.0 L],1,image_32.png +0.18717858533399576,0.17979165086317853,-0.825171797,-0.379583634,0.6354161714637024,Halslymphknotenmetastasen[C77.0 L] Tonsillenkarzinom[C09.9 L],0,image_33.png +-0.386658136,1.5164416176334854,0.6268387669020417,1.0539531037975378,-0.178935605,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 R],1,image_34.png +1.6590487297919927,-1.63403202,-0.612697847,0.9963133530846225,-0.075254355,Bösartige Neubildung der Seitenwand des Oropharynx[C10.2 L],1,image_35.png 
+-2.047069131,1.7819709009957654,0.035742336,-2.109193418,0.8856980875005901,,0,image_36.png +1.3993169863061743,-0.617727701,-2.195962611,-1.519115605,1.3022907171692861,,1,image_37.png +-0.679007121,-1.086399293,-0.863282689,0.12983462557747558,1.4001751255531105,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 ],1,image_38.png +1.5289851345959593,-0.044753328,0.5350993630394176,-0.065237868,-0.046387298,"Bösartige Neubildung: Tonsille, mehrere Teilbereiche überlappend[C09.8 ]",1,image_39.png +1.2212159632534685,-1.16370242,-1.123463295,-0.18698505,0.056701348,Bösartige Neubildung: Zungenrand[C02.1 ],1,image_40.png +1.0149885167826984,0.5689394950042188,0.33396509272135233,0.8239324982417736,0.33108021768184237,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 ],1,image_41.png +0.8281299823497013,-0.462643856,-0.507230455,2.0154984149609985,-0.261780894,Bösartige Neubildung Oropharynx Seitenwand[C10.2 ] Bösartige Neubildung Hypopharynx mehrere Teilbereiche überlappend[C13.8 ] Bösartige Neubildung Nasopharynx Seitenwand[C11.2 ],1,image_42.png +2.2662927082101914,0.28735510483252746,0.9901501711281284,-1.031448133,-0.70926492,"Bösartige Neubildung: Larynx, mehrere Teilbereiche überlappend[C32.8 ]",1,image_43.png +-0.59495567,-0.977338711,1.2579439688721659,-0.067517776,0.16645918553019404,Bösartige Neubildung Hypopharynx mehrere Teilbereiche überlappend[C13.8 ],0,image_44.png +-0.581269538,-0.222436258,0.17762491871456892,-0.025381558,0.47332265126602674,Karzinom Gaumen und Uvula[C05.8 ],1,image_45.png +-0.655894146,0.096931768,-0.014867416,0.2052931948319547,0.5237041770156639,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 B] Halslymphknotenmetastasen[C77.0 B],0,image_46.png +0.9251488533568732,-0.188384349,0.040436114,-1.223639731,-0.075535658,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",1,image_47.png +-1.299161344,1.5916021583729445,-0.02144477,0.46742987833997945,-0.658039055,,0,image_48.png 
+1.0111668672107617,0.5858108272747818,1.1916565829964165,0.036701016,0.4992660132510926,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 R] Dyspnoe onA[R06.0 ],0,image_49.png \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test-data/train.csv Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,177 @@ +feature_0,feature_1,feature_2,feature_3,feature_4,text,target,image_path +0.4967141530112327,-1.415370742,0.3577873603482833,-0.828995011,-1.594427659,Glottiskarzinom[C32.0 L],1,image_0.png +-0.138264301,-0.420645323,0.5607845263682344,-0.56018104,-0.599375023,Bösartige Neubildung: Zungenrand[C02.1 ],1,image_1.png +0.6476885381006925,-0.342714517,1.083051243175277,0.7472936051232618,0.0052437,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ],1,image_2.png +1.5230298564080254,-0.802277269,1.053802052034903,0.6103702654334648,0.046980593764742055,Neubildung unsicheren oder unbekannten Verhaltens: Sonstige näher bezeichnete Lokalisationen[D48.7 ] Rachenmandelkarzinom[C11.1 R] Leukoplakie Stimmlippe[J38.3 L],0,image_3.png +-0.234153375,-0.161285712,-1.377669368,-0.020901594,-0.450065471,Bösartige Neubildung der Supraglottis[C32.1 R],1,image_4.png +-0.234136957,0.4040508568145384,-0.93782504,0.117327383,0.6228499323474987,Uvulakarzinom[C05.2 ],0,image_5.png +1.5792128155073915,1.8861859012105302,0.5150352672086598,1.277664895788425,-1.067620429,"Bösartige Neubildung: Zunge, mehrere Teilbereiche überlappend[C02.8 ] Bösartige Neubildung im Bereich der Zungengruben[C10.0 ]",0,image_6.png +0.7674347291529088,0.17457781283183896,0.5137859509122088,-0.591571389,-0.142379485,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ],1,image_7.png +-0.469474386,0.25755039072276437,0.5150476863060479,0.5470973811700379,0.12029563171189886,Neubildung bösartig Hypopharynx sonstige[C13.8 L],0,image_8.png +0.5425600435859647,-0.074445916,3.852731490654721,-0.202192652,0.514438834,Bösartige Neubildung: Glottis[C32.0 ],0,image_9.png +-0.463417693,-1.918771215,0.570890511,-0.217681203,0.7116148780888898,Unsichere Neubildung der Übergangsregion des Oropharynx[D37.0 ],0,image_10.png 
+-0.465729754,-0.026513875,1.135565640180599,1.098776852,-1.124642092,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 R],0,image_11.png +0.24196227156603412,0.06023021,0.9540017634932023,0.8254163489880298,-1.534114171,Bösartige Neubildung: Glottis[C32.0 ],1,image_12.png +-1.913280245,2.463242112485286,0.651391251,0.8135096360006385,1.277676821898509,"Bösartige Neubildung: Hypopharynx, mehrere Teilbereiche überlappend[C13.8 ]",0,image_13.png +-1.724917833,-0.192360965,-0.315269245,1.305478807154329,0.33231401197959165,"Bösartige Neubildung: Hypopharynx, mehrere Teilbereiche überlappend[C13.8 ]",1,image_14.png +-0.562287529,0.30154734233361247,0.7589692204932674,0.021003842,-0.748486537,Halslymphknotenmetastasen[C77.0 B],1,image_15.png +-1.01283112,-0.03471177,-0.772825215,0.6819529712949639,1.551151975522523,Bösartige Neubildung: Glottis[C32.0 ] Bösartige Neubildung der Glottis[C32.0 B],0,image_16.png +0.3142473325952739,-1.168678038,-0.236818607,-0.310266757,0.11567463429285867,Karzinom des Zungenrandes und der Zungenspitze[C02.1 ],1,image_17.png +-0.908024076,1.1428228145150205,-0.485363548,0.3241663524884421,1.1792971840638264,CUP [Cancer of Unknown Primary][C80.0 ] Halsmetastasen[C77.0 ],1,image_18.png +-1.412303701,0.7519330326867741,0.081874139,-0.130143054,0.067518481,Zungenrandkarzinom[C02.1 ],1,image_19.png +1.465648768921554,0.7910319470430469,2.3146585666735087,0.096995965,2.060747924881987,Bösartige Neubildung: Zungenrand[C02.1 ] Neubildung bösartig Mundboden sonstige[C04.8 R] Halslymphknotenmetastasen[C77.0 B],0,image_20.png +-0.2257763,-0.909387455,-1.867265193,0.5951570254369136,1.7553408424432044,Tonsillenkarzinom[C09.9 ] Halslymphknotenmetastasen[C77.0 ],1,image_21.png +0.067528205,1.4027943109360992,0.6862601903745135,-0.818220683,-0.248964148,Bösartige Neubildung: Glottis[C32.0 ],1,image_22.png +-1.424748186,-1.401851063,-1.612715871,2.0923872756854602,0.9715709509543554,Glottiskarzinom[C32.0 L],1,image_23.png 
+-0.544382725,0.5868570938002703,-0.471931866,-1.006017381,0.6453759495851475,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 L],0,image_24.png +0.11092258970986608,2.1904556258099785,1.088950596967366,-1.214188613,1.3686315575323487,Bösartige Neubildung: Glottis[C32.0 ] Neubildung bösartig Kehlkopf sonstige[C32.8 L],1,image_25.png +-1.150993577,-0.990536325,0.064280019,1.1581108735000678,-0.964923461,"Bösartige Neubildung: Oropharynx, nicht näher bezeichnet[C10.9 ]",0,image_26.png +0.37569801834567196,-0.56629773,-1.077744778,0.7916626939629359,0.6860514599984393,"Bösartige Neubildung: Hypopharynx, nicht näher bezeichnet[C13.9 ]",0,image_27.png +-0.60063869,0.099651365,-0.715303709,0.6241198170521551,1.0584244868495878,Bösartige Neubildung des inneren Larynx[C32.0 L],0,image_28.png +-0.29169375,-0.503475654,0.6795977489346758,0.6283455092642799,-1.758739486,Bösartige Neubildung des postkrikoidalen Hypopharynx[C13.0 ] Neubildung bösartig Hypopharynx sonstige[C13.8 ],0,image_29.png +-0.601706612,-1.550663431,-0.730366632,-0.012246773,-1.183258513,Karzinom bei unbekanntem Primärtumor (CUP)[C80.0 ],0,image_30.png +1.8522781845089378,0.068562975,0.21645858958197486,-0.897254371,-2.039232178,Bösartige Neubildung: Seitenwand des Oropharynx[C10.2 ],0,image_31.png +-0.013497225,-1.062303714,0.045571839903813784,0.075804558,-0.269406834,Nachblutung[T81.0 B],1,image_32.png +-1.057710929,0.4735924306351816,-0.651600348,-0.677161712,0.7175422557959623,Unsichere Neubildung des seitlichen Zungenrandes[D37.0 R],0,image_33.png +0.822544912,-0.919424234,2.1439440893253257,0.9751197334177512,1.502357052096028,Bösartige Neubildung des Mundrachenraums[C14.8 ],1,image_34.png +-1.22084365,1.5499344050175394,0.6339190223180112,-0.147057382,0.07409478,Karzinom Oropharynx Seitenwand[C10.2 R] Neubildung bösartig sekundär und onA Lymphknoten Kopf Gesicht Hals[C77.0 R],1,image_35.png 
+0.2088635950047554,-0.783253292,-2.025142587,-0.825497197,1.6286155455712918,Zungenrandkarzinom[C02.1 ],0,image_36.png +-1.959670124,-0.322061516,0.18645431476942764,-0.321385842,-1.380101458,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 B],1,image_37.png +-1.328186049,0.8135172173696698,-0.661786465,0.41293145427562433,-1.703382439,Bösartige Neubildung des weichen Gaumens[C05.1 R],1,image_38.png +0.19686123586912352,-1.230864316,0.852433335,-0.563724553,-0.055547699,,1,image_39.png +0.7384665799954104,0.22745993460412942,-0.792520738,-0.822220396,0.3840654489393073,Stimmlippenkarzinom[C32.0 L] Struma nodosa euthyreot[E04.1 ],1,image_40.png +0.1713682811899705,1.307142754282428,-0.114736441,0.2436872114919123,-0.032694748,Glottiskarzinom[C32.0 L],0,image_41.png +-0.115648282,-1.607483235,0.5049872789804571,0.24496657110872277,-2.0674421,Bösartige Neubildung: Glottis[C32.0 ],1,image_42.png +-0.301103696,0.1846338585323042,0.8657551941701215,-0.506943175,-0.08912004,Neubildung unsicheren oder unbekannten Verhaltens: Larynx[D38.0 ] Neubildung bösartig Kehlkopf sonstige[C32.8 L],0,image_43.png +-1.47852199,0.25988279424842353,-1.200296407,-0.471038306,-1.304469501,"Bösartige Neubildung: Larynx, mehrere Teilbereiche überlappend[C32.8 ]",1,image_44.png +-0.719844208,0.7818228717773104,-0.334501236,0.2320499373576363,0.6696725488300385,Bösartige Neubildung der Glottis[C32.0 ],1,image_45.png +-0.460638771,-1.236950711,-0.474945311,-1.448084341,0.36659824609684827,Bösartige Neubildung des Zungenrandes[C02.1 ],0,image_46.png +1.0571222262189157,-1.320456613,-0.653329233,-1.407463774,-0.939879786,"Neubildung unsicheren oder unbekannten Verhaltens: Lippe, Mundhöhle und Pharynx[D37.0 ] Funktionsstörung des Tracheostomas[J95.0 ]",1,image_47.png +0.3436182895684614,0.5219415656168976,1.7654542402810969,-0.718444221,-0.513866917,Stimmlippenkarzinom[C32.0 R],1,image_48.png +-1.763040155,0.29698467323318606,0.40498171096095553,-0.213447152,-1.059213522,Karzinom 
des Mundrachenraums[C14.8 R],1,image_49.png +0.324083969,0.25049285034587654,-1.260883954,0.3109075655980046,-0.062679097,"Bösartige Neubildung: Larynx, mehrere Teilbereiche überlappend[C32.8 ]",0,image_50.png +-0.38508228,0.3464482094969757,0.9178619470547761,1.475356216949552,0.9551423205012383,,1,image_51.png +-0.676922,-0.680024722,2.1221561970126332,0.8576596232020194,-0.985726046,Karzinom Mundboden mehrere Teilbereiche überlappend[C04.8 L],0,image_52.png +0.6116762888408679,0.23225369716100355,1.0324652605511468,-0.15993853,0.5040465155178444,Bösartige Neubildung: Glottis[C32.0 ],1,image_53.png +1.030999522495951,0.29307247329868125,-1.519369966,-0.019016208,-0.530257618,Bösartige Neubildung der Supraglottis[C32.1 B],0,image_54.png +0.9312801191161986,-0.714351418,-0.484234073,-1.002529365,-0.792872832,Bösartige Neubildung Hypopharynx mehrere Teilbereiche überlappend[C13.8 L],0,image_55.png +-0.839217523,1.8657745111447566,1.2669111491866227,-0.018513136,-0.10703036,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 ],1,image_56.png +-0.309212376,0.4738329209117875,-0.707669466,-0.288658639,-1.035242322,"Bösartige Neubildung der Gaumenbogentonsillen[C09.1 L] Lymphknotenvergrößerung, umschrieben[R59.0 ]",1,image_57.png +0.33126343140356396,-1.191303497,0.44381942814622843,0.3227185603380895,-0.553649305,Bösartige Neubildung des Zungenrandes und der Zungenspitze[C02.1 ],0,image_58.png +0.9755451271223592,0.6565536086338297,0.7746340534293368,-0.827230944,-1.197877893,Karzinom der Plica aryepiglottica[C13.1 L],0,image_59.png +-0.479174238,-0.97468167,-0.926930472,0.5193465142411723,1.9647251329163893,Bösartige Neubildung: Glottis[C32.0 ],1,image_60.png +-0.185658977,0.787084604,-0.059525356,1.5327389130025777,0.035263552,"Bösartige Neubildung: Tonsille, mehrere Teilbereiche überlappend[C09.8 ]",0,image_61.png +-1.106334974,1.158595579007404,-3.24126734,-0.108760148,-0.699725508,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 
R],1,image_62.png +-1.196206624,-0.820682318,-1.024387641,0.40171172209894146,0.213979911,Bösartige Neubildung: Glottis[C32.0 ] Glottiskarzinom[C32.0 R],1,image_63.png +0.812525822,0.9633761292443218,-0.252568151,0.6901439917111125,-0.11232805,Glottiskarzinom[C32.0 L] Bösartige Neubildung der Glottis[C32.0 L],1,image_64.png +1.356240028570823,0.4127809269364983,-1.247783182,-0.401220472,-0.2209696,,1,image_65.png +-0.072010122,0.82206016,1.6324113039316352,0.22409248181041677,0.6141667000434252,Bösartige Neubildung der Zungenunterfläche[C02.2 R],0,image_66.png +1.0035328978920242,1.8967929826539474,-1.430141378,0.012592401,0.7575077100473051,Bösartige Neubildung Hypopharynx mehrere Teilbereiche überlappend[C13.8 L] Neubildung bösartig sekundär und onA Lymphknoten Kopf Gesicht Hals[C77.0 B],1,image_67.png +0.36163602504763415,-0.245388116,-0.440044487,0.097676099,-0.530501148,Rachenmandelkarzinom[C11.1 ],0,image_68.png +-0.645119755,-0.753736164,0.13074057728609134,-0.773009784,-0.575818241,"Bösartige Neubildung: Hypopharynx, mehrere Teilbereiche überlappend[C13.8 ] Halslymphknotenmetastasen[C77.0 B]",1,image_69.png +0.36139560550841393,-0.88951443,1.4412732890661155,0.024510174258942714,-0.275051697,"Bösartige Neubildung: Oropharynx, nicht näher bezeichnet[C10.9 ]",0,image_70.png +1.5380365664659692,-0.815810285,-1.435862151,0.49799829124544975,-2.301921165,Bösartige Neubildung: Supraglottis[C32.1 ],1,image_71.png +-0.035826039,-0.077101709,1.1631637521549596,1.4511436077950417,-1.515191062,Karzinom Hypopharynx mehrere Teilbereiche überlappend[C13.8 L],0,image_72.png +1.5646436558140062,0.3411519748166439,0.010233061019587049,0.9592708260852069,1.3668742674445247,Bösartige Neubildung Oropharynx mehrere Teilbereiche überlappend[C10.8 ],0,image_73.png +-2.619745104,0.27669079933001905,-0.981508651,2.1531824575115563,1.6449677135012837,Carcinoma in situ der Epiglottis[D02.0 ],1,image_74.png 
+0.8219025043752238,0.8271832490360238,0.46210347426327075,-0.767347563,-0.24903604,Bösartige Neubildung der Gaumenbogentonsillen[C09.1 R],1,image_75.png +0.087047068,0.013001892,0.19905969557347003,0.8723206367206782,0.5765569630557664,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",1,image_76.png +-0.29900735,1.4535340771573169,-0.600216877,0.18334200573835174,0.3112501545435361,Uvulakarzinom[C05.2 ],1,image_77.png +0.091760777,-0.264656833,0.069802085,2.1898029332176723,3.0788808084552377,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",0,image_78.png +-1.987568915,2.720169166589619,-0.385313597,-0.808298285,1.1195749114345768,Bösartige Neubildung: Glottis[C32.0 ],1,image_79.png +-0.219671888,0.6256673477650062,0.11351734525124804,-0.839721842,-0.127917591,Bösartige Neubildung der Glottis[C32.0 ],0,image_80.png +0.3571125715117464,-0.857157556,0.6621306745210467,-0.599392645,-0.955540441,Stimmbandkarzinom[C32.0 L],0,image_81.png +1.477894044741516,-1.070892498,1.586016816145352,-2.123895724,-1.60644632,,1,image_82.png +-0.518270218,0.4824724152431853,-1.237815499,-0.525755022,0.2034636358672231,Zungenrandkarzinom[C02.1 ],0,image_83.png +-0.808493603,-0.223462785,2.1330333746562666,-0.759132662,-0.756350745,Bösartige Neubildung: Zungenrand[C02.1 ] Bösartige Neubildung: Zungenrand[C02.1 ],0,image_84.png +-0.501757044,0.714000494,-1.9520878,0.15039378647620763,-1.42225371,"Neubildung unsicheren oder unbekannten Verhaltens: Lippe, Mundhöhle und Pharynx[D37.0 ]",0,image_85.png +0.9154021177020741,0.47323762457354485,-0.151785095,0.34175597577715944,-0.646572884,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 R],0,image_86.png +0.32875110965968446,-0.072828913,0.5883172064845765,1.8761708392158862,-1.081548004,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 R] Halslymphknotenmetastasen[C77.0 B],1,image_87.png 
+-0.529760204,-0.846793718,0.28099186773503265,0.9504238381860503,1.6871416350725648,,1,image_88.png +0.5132674331133561,-1.514847225,-0.62269952,-0.576903656,0.8816397569494505,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ] Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ],1,image_89.png +0.097077549,-0.446514952,-0.20812225,-0.898414671,-0.007972641,,1,image_90.png +0.9686449905328892,0.8563987943234723,-0.493000935,0.4919191715065057,1.4799441388900259,Bösartige Neubildung: Glottis[C32.0 ],0,image_91.png +-0.702053094,0.21409374413020396,-0.589364757,-1.320233207,0.077368308,,0,image_92.png +-0.327662147,-1.245738779,0.8496020970210246,1.8314587658543537,-0.861284201,Bösartige Neubildung des Zungenrandes[C02.1 R],1,image_93.png +-0.392108153,0.173180926,0.35701548596504734,1.179440120721287,1.5231240772696573,Supraglottisches Karzinom[C32.1 R],1,image_94.png +-1.463514948,0.3853173797288368,-0.692909595,-0.469175652,0.5389100436846587,Karzinom Hypopharynx mehrere Teilbereiche überlappend[C13.8 ] Karzinom Hypopharynx mehrere Teilbereiche überlappend[C13.8 ],1,image_95.png +0.29612027706457605,-0.883857436,0.8995998754332507,-1.713134529,-1.037246154,,0,image_96.png +0.26105527217988933,0.1537251059455279,0.30729952087660933,1.3538723741654128,-0.190338678,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 B],1,image_97.png +0.005113457,0.058208718445999896,0.8128621188389601,-0.114539845,-0.875618253,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",1,image_98.png +-0.234587133,-1.142970298,0.6296288419236122,1.2378163119734618,-1.382799731,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ] Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 L]",1,image_99.png +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, 
+,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,,, +,,,,,Glottiskarzinom[C32.0 L],, +,,,,,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",, +,,,,,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 L],, +,,,,,Larynxkarzinom[C32.9 ],, +,,,,,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",, +,,,,,Bösartige Neubildung: Zungenrand[C02.1 ] Zungenrandkarzinom[C02.1 ],, +,,,,,Bösartige Neubildung: Glottis[C32.0 ],, +,,,,,"Bösartige Neubildung: Tonsille, mehrere Teilbereiche überlappend[C09.8 ]",, +,,,,,,, +,,,,,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 B],, +,,,,,Bösartige Neubildung der Supraglottis[C32.1 ],, +,,,,,Karzinom Tonsille mehrere Teilbereiche überlappend[C09.8 R],, +,,,,,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 ],, +,,,,,Bösartige Neubildung der Glottis[C32.0 L],, +,,,,,"Bösartige Neubildung: Zunge, nicht näher bezeichnet[C02.9 ]",, +,,,,,"Bösartige Neubildung: Zunge, mehrere Teilbereiche überlappend[C02.8 ]",, +,,,,,"Bösartige Neubildung: Tonsille, mehrere Teilbereiche überlappend[C09.8 ] Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",, +,,,,,Bösartige Neubildung der Fossa tonsillaris[C09.0 R],, +,,,,,"Neubildung unsicheren oder unbekannten Verhaltens: Lippe, Mundhöhle und Pharynx[D37.0 ] Bösartige Neubildung der Zungenunterfläche[C02.2 ]",, +,,,,,Neubildung bösartig Hypopharynx sonstige[C13.8 ],, +,,,,,Karzinom der ventralen Zunge[C02.2 ],, +,,,,,Bösartige Neubildung des weichen Gaumens[C05.1 L],, +,,,,,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 R],, +,,,,,Halsmetastase[C79.88 R],, +,,,,,Karzinom Oropharynx mehrere Teilbereiche überlappend[C10.8 ],, +,,,,,"Bösartige Neubildung: Hypopharynx, nicht näher bezeichnet[C13.9 ]",, +,,,,,Bösartige Neubildung: Zungenrand[C02.1 ],, +,,,,,Glottiskarzinom[C32.0 R],, +,,,,,Supraglottisches Karzinom[C32.1 B] 
Halslymphknotenmetastasen[C77.0 B],, +,,,,,CUP [Cancer of Unknown Primary][C80.0 L],, +,,,,,Bösartige Neubildung: Zungenrand[C02.1 ],, +,,,,,"Bösartige Neubildung: Hypopharynx, mehrere Teilbereiche überlappend[C13.8 ]",, +,,,,,Uvulakarzinom[C05.2 ],, +,,,,,Karzinom Uvula[C05.2 ],, +,,,,,Karzinom Hypopharynx mehrere Teilbereiche überlappend[C13.8 R],, +,,,,,"Bösartige Neubildung: Oropharynx, mehrere Teilbereiche überlappend[C10.8 ]",, +,,,,,cT3 OrpharynxCa links[H93.1 ] Bösartige Neubildung Oropharynx mehrere Teilbereiche überlappend[C10.8 L],, +,,,,,Bösartige Neubildung: Uvula[C05.2 ],, +,,,,,"Bösartige Neubildung: Tonsille, mehrere Teilbereiche überlappend[C09.8 ]",, +,,,,,"Bösartige Neubildung: Larynx, mehrere Teilbereiche überlappend[C32.8 ]",, +,,,,,Bösartige Neubildung: Vallecula epiglottica[C10.0 ],, +,,,,,Bösartige Neubildung: Zungenrand[C02.1 ] Bösartige Neubildung Zunge mehrere Teilbereiche überlappend[C02.8 L],, +,,,,,Bösartige Neubildung Larynx mehrere Teilbereiche überlappend[C32.8 ],, \ No newline at end of file
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/test_pipeline.py Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,74 @@ +from __future__ import annotations + +import logging +from typing import Dict, Optional + +import pandas as pd +from plot_logic import infer_problem_type +from training_pipeline import evaluate_predictor_all_splits, fit_summary_safely + +logger = logging.getLogger(__name__) + + +def run_autogluon_test_experiment( + predictor, + data_ctx: Dict[str, pd.DataFrame], + target_column: str, + eval_metric: Optional[str] = None, + ag_config: Optional[dict] = None, + problem_type: Optional[str] = None, +) -> Dict[str, object]: + """ + Evaluate a trained predictor on train/val/test splits using prepared data_ctx. + + data_ctx is typically the context returned by ``run_autogluon_experiment``: + { + "train": df_train, + "val": df_val, + "test_internal": df_test_internal, + "test_external": df_test_external, + "threshold": threshold, + } + """ + if predictor is None: + raise ValueError("predictor is required for evaluation.") + if data_ctx is None: + raise ValueError("data_ctx is required; usually from run_autogluon_experiment.") + + df_train = data_ctx.get("train") + df_val = data_ctx.get("val") + df_test_internal = data_ctx.get("test_internal") + df_test_external = data_ctx.get("test_external") + threshold = None + if ag_config is not None: + threshold = ag_config.get("threshold", threshold) + threshold = data_ctx.get("threshold", threshold) + + if problem_type is None: + # Prefer inferring from training data and predictor metadata + base_df = df_train if df_train is not None else df_test_external + problem_type = infer_problem_type(predictor, base_df, target_column) + + df_test_final = df_test_external if df_test_external is not None else df_test_internal + raw_metrics, ag_by_split = evaluate_predictor_all_splits( + predictor=predictor, + df_train=df_train, + df_val=df_val, + df_test=df_test_final, + label_col=target_column, + problem_type=problem_type, + 
eval_metric=eval_metric, + threshold_test=threshold, + df_test_external=df_test_external, + ) + + summary = fit_summary_safely(predictor) + + result = { + "problem_type": problem_type, + "raw_metrics": raw_metrics, + "ag_eval": ag_by_split, + "fit_summary": summary, + } + logger.info("Evaluation complete; splits: %s", list(raw_metrics.keys())) + return result
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/training_pipeline.py Tue Dec 09 23:49:47 2025 +0000 @@ -0,0 +1,551 @@ +from __future__ import annotations + +import contextlib +import importlib +import io +import json +import logging +import os +import tempfile +import uuid +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd +import torch +from autogluon.multimodal import MultiModalPredictor +from metrics_logic import compute_metrics_for_split, evaluate_all_transparency +from packaging.version import Version + +logger = logging.getLogger(__name__) + +# ---------------------- small utilities ---------------------- + + +def load_user_hparams(hp_arg: Optional[str]) -> dict: + """Parse --hyperparameters (inline JSON or path to .json).""" + if not hp_arg: + return {} + try: + s = hp_arg.strip() + if s.startswith("{"): + return json.loads(s) + with open(s, "r") as f: + return json.load(f) + except Exception as e: + logger.warning(f"Could not parse --hyperparameters: {e}. Ignoring.") + return {} + + +def deep_update(dst: dict, src: dict) -> dict: + """Recursive dict update (src overrides dst).""" + for k, v in (src or {}).items(): + if isinstance(v, dict) and isinstance(dst.get(k), dict): + deep_update(dst[k], v) + else: + dst[k] = v + return dst + + +@contextlib.contextmanager +def suppress_stdout_stderr(): + """Silence noisy prints from AG internals (fit_summary).""" + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + yield + + +def ag_evaluate_safely(predictor, df: Optional[pd.DataFrame], metrics: Optional[List[str]] = None) -> Dict[str, float]: + """ + Call predictor.evaluate and normalize the output to a dict. 
+ """ + if df is None or len(df) == 0: + return {} + try: + res = predictor.evaluate(df, metrics=metrics) + except TypeError: + if metrics and len(metrics) == 1: + res = predictor.evaluate(df, metrics[0]) + else: + res = predictor.evaluate(df) + if isinstance(res, (int, float, np.floating)): + name = (metrics[0] if metrics else "metric") + return {name: float(res)} + if isinstance(res, dict): + return {k: float(v) for k, v in res.items()} + return {"metric": float(res)} + + +# ---------------------- hparams & training ---------------------- +def build_mm_hparams(args, df_train: pd.DataFrame, image_columns: Optional[List[str]]) -> dict: + """ + Build hyperparameters for MultiModalPredictor. + Handles text checkpoints for torch<2.6 and merges user overrides. + """ + inferred_text_cols = [ + c for c in df_train.columns + if c != args.label_column + and str(df_train[c].dtype) == "object" + and df_train[c].notna().any() + ] + text_cols = inferred_text_cols + + ag_version = None + try: + ag_mod = importlib.import_module("autogluon") + ag_ver = getattr(ag_mod, "__version__", None) + if ag_ver: + ag_version = Version(str(ag_ver)) + except Exception: + ag_mod = None + + def _log_missing_support(key: str) -> None: + logger.info( + "AutoGluon version %s does not expose '%s'; skipping override.", + ag_version or "unknown", + key, + ) + + hp = {} + + # Setup environment + hp["env"] = { + "seed": int(args.random_seed) + } + + # Set eval metric through model config + model_block = hp.setdefault("model", {}) + if args.eval_metric: + model_block.setdefault("metric_learning", {})["metric"] = str(args.eval_metric) + + if text_cols and Version(torch.__version__) < Version("2.6"): + safe_ckpt = "distilbert-base-uncased" + logger.warning(f"Forcing HF text checkpoint with safetensors: {safe_ckpt}") + hp["model.hf_text.checkpoint_name"] = safe_ckpt + hp.setdefault( + "model.names", + ["hf_text", "timm_image", "numerical_mlp", "categorical_mlp", "fusion_mlp"], + ) + + def 
_is_valid_hp_dict(d) -> bool: + if not isinstance(d, dict): + logger.warning("User-supplied hyperparameters must be a dict; received %s", type(d).__name__) + return False + return True + + user_hp = args.hyperparameters if isinstance(args.hyperparameters, dict) else load_user_hparams(args.hyperparameters) + if user_hp and _is_valid_hp_dict(user_hp): + hp = deep_update(hp, user_hp) + + # Map CLI knobs into AutoMM optimization hyperparameters when provided. + # We set multiple common key names (nested dicts and dotted flat keys) to + # maximize compatibility across AutoMM/AutoGluon versions. + try: + if any(getattr(args, param, None) is not None for param in ["epochs", "learning_rate", "batch_size"]): + if getattr(args, "epochs", None) is not None: + hp["optim.max_epochs"] = int(args.epochs) + hp["optim.epochs"] = int(args.epochs) + if getattr(args, "learning_rate", None) is not None: + hp["optim.learning_rate"] = float(args.learning_rate) + hp["optim.lr"] = float(args.learning_rate) + if getattr(args, "batch_size", None) is not None: + hp["optim.batch_size"] = int(args.batch_size) + hp["optim.per_device_train_batch_size"] = int(args.batch_size) + + # Also set dotted flat keys for max compatibility (e.g., 'optimization.max_epochs') + if getattr(args, "epochs", None) is not None: + hp["optimization.max_epochs"] = int(args.epochs) + hp["optimization.epochs"] = int(args.epochs) + if getattr(args, "learning_rate", None) is not None: + hp["optimization.learning_rate"] = float(args.learning_rate) + hp["optimization.lr"] = float(args.learning_rate) + if getattr(args, "batch_size", None) is not None: + hp["optimization.batch_size"] = int(args.batch_size) + hp["optimization.per_device_train_batch_size"] = int(args.batch_size) + except Exception: + logger.warning("Failed to attach epochs/learning_rate/batch_size to mm_hparams; continuing without them.") + + # Map backbone selections into mm_hparams if provided + try: + has_text_cols = bool(text_cols) + has_image_cols = False + 
model_names_cache: Optional[List[str]] = None + model_names_modified = False + + def _dedupe_preserve(seq: List[str]) -> List[str]: + seen = set() + ordered = [] + for item in seq: + if item in seen: + continue + seen.add(item) + ordered.append(item) + return ordered + + def _get_model_names() -> List[str]: + nonlocal model_names_cache + if model_names_cache is not None: + return model_names_cache + names = model_block.get("names") + if isinstance(names, list): + model_names_cache = list(names) + else: + model_names_cache = [] + if has_text_cols: + model_names_cache.append("hf_text") + if has_image_cols: + model_names_cache.append("timm_image") + model_names_cache.extend(["numerical_mlp", "categorical_mlp"]) + model_names_cache.append("fusion_mlp") + return model_names_cache + + def _set_model_names(new_names: List[str]) -> None: + nonlocal model_names_cache, model_names_modified + model_names_cache = new_names + model_names_modified = True + + if has_text_cols and getattr(args, "backbone_text", None): + text_choice = str(args.backbone_text) + model_block.setdefault("hf_text", {})["checkpoint_name"] = text_choice + hp["model.hf_text.checkpoint_name"] = text_choice + if has_image_cols and getattr(args, "backbone_image", None): + image_choice = str(args.backbone_image) + model_block.setdefault("timm_image", {})["checkpoint_name"] = image_choice + hp["model.timm_image.checkpoint_name"] = image_choice + if model_names_modified and model_names_cache is not None: + model_block["names"] = model_names_cache + except Exception: + logger.warning("Failed to attach backbone selections to mm_hparams; continuing without them.") + + if ag_version: + logger.info(f"Detected AutoGluon version: {ag_version}; applied robust hyperparameter mappings.") + + return hp + + +def train_predictor( + args, + df_train: pd.DataFrame, + df_val: pd.DataFrame, + image_columns: Optional[List[str]], + mm_hparams: dict, +): + """ + Train a MultiModalPredictor, honoring common knobs (presets, 
eval_metric, etc.). + """ + logger.info("Starting AutoGluon MultiModal training...") + predictor = MultiModalPredictor(label=args.label_column, path=None) + column_types = {} + + mm_fit_kwargs = dict( + train_data=df_train, + time_limit=args.time_limit, + seed=int(args.random_seed), + hyperparameters=mm_hparams, + ) + if df_val is not None and not df_val.empty: + mm_fit_kwargs["tuning_data"] = df_val + if column_types: + mm_fit_kwargs["column_types"] = column_types + + preset_mm = getattr(args, "presets", None) + if preset_mm is None: + preset_mm = getattr(args, "preset", None) + if preset_mm is not None: + mm_fit_kwargs["presets"] = preset_mm + + predictor.fit(**mm_fit_kwargs) + return predictor + + +# ---------------------- evaluation ---------------------- +def evaluate_predictor_all_splits( + predictor, + df_train: Optional[pd.DataFrame], + df_val: Optional[pd.DataFrame], + df_test: Optional[pd.DataFrame], + label_col: str, + problem_type: str, + eval_metric: Optional[str], + threshold_test: Optional[float], + df_test_external: Optional[pd.DataFrame] = None, +) -> Tuple[Dict[str, Dict[str, float]], Dict[str, Dict[str, float]]]: + """ + Returns (raw_metrics, ag_scores_by_split) + - raw_metrics: our transparent suite (threshold applied to Test/External Test only inside metrics_logic) + - ag_scores_by_split: AutoGluon's evaluate() per split for the chosen eval_metric (or default) + """ + metrics_req = None if (eval_metric is None or str(eval_metric).lower() == "auto") else [eval_metric] + ag_by_split: Dict[str, Dict[str, float]] = {} + + if df_train is not None and len(df_train): + ag_by_split["Train"] = ag_evaluate_safely(predictor, df_train, metrics=metrics_req) + if df_val is not None and len(df_val): + ag_by_split["Validation"] = ag_evaluate_safely(predictor, df_val, metrics=metrics_req) + + df_test_effective = df_test_external if df_test_external is not None else df_test + if df_test_effective is not None and len(df_test_effective): + ag_by_split["Test"] = 
ag_evaluate_safely(predictor, df_test_effective, metrics=metrics_req) + + # Transparent suite (threshold on Test handled inside metrics_logic) + _, raw_metrics = evaluate_all_transparency( + predictor=predictor, + train_df=df_train, + val_df=df_val, + test_df=df_test_effective, + target_col=label_col, + problem_type=problem_type, + threshold=threshold_test, + ) + + if df_test_external is not None and df_test_external is not df_test and len(df_test_external): + raw_metrics["Test (external)"] = compute_metrics_for_split( + predictor, df_test_external, label_col, problem_type, threshold=threshold_test + ) + ag_by_split["Test (external)"] = ag_evaluate_safely(predictor, df_test_external, metrics=metrics_req) + + return raw_metrics, ag_by_split + + +def fit_summary_safely(predictor) -> Optional[dict]: + """Get fit summary without printing misleading one-liners.""" + with suppress_stdout_stderr(): + try: + return predictor.fit_summary() + except Exception: + return None + + +# ---------------------- image helpers ---------------------- +_PLACEHOLDER_PATH = None + + +def _create_placeholder() -> str: + global _PLACEHOLDER_PATH + if _PLACEHOLDER_PATH and os.path.exists(_PLACEHOLDER_PATH): + return _PLACEHOLDER_PATH + + dir_ = Path(tempfile.mkdtemp(prefix="ag_placeholder_")) + file_ = dir_ / f"placeholder_{uuid.uuid4().hex}.png" + + try: + from PIL import Image + Image.new("RGB", (64, 64), (180, 180, 180)).save(file_) + except Exception: + import matplotlib.pyplot as plt + import numpy as np + plt.imsave(file_, np.full((64, 64, 3), 180, dtype=np.uint8)) + plt.close("all") + + _PLACEHOLDER_PATH = str(file_) + logger.info(f"Placeholder image created: {file_}") + return _PLACEHOLDER_PATH + + +def _is_valid_path(val) -> bool: + if pd.isna(val): + return False + s = str(val).strip() + return s and os.path.isfile(s) + + +def handle_missing_images( + df: pd.DataFrame, + image_columns: List[str], + strategy: str = "false", +) -> pd.DataFrame: + if not image_columns or df.empty: + 
return df + + remove = str(strategy).lower() == "true" + masks = [~df[col].apply(_is_valid_path) for col in image_columns if col in df.columns] + if not masks: + return df + + any_missing = pd.concat(masks, axis=1).any(axis=1) + n_missing = int(any_missing.sum()) + + if n_missing == 0: + return df + + if remove: + result = df[~any_missing].reset_index(drop=True) + logger.info(f"Dropped {n_missing} rows with missing images → {len(result)} remain") + else: + placeholder = _create_placeholder() + result = df.copy() + for col in image_columns: + if col in result.columns: + result.loc[~result[col].apply(_is_valid_path), col] = placeholder + logger.info(f"Filled {n_missing} missing images with placeholder") + + return result + + +# ---------------------- AutoGluon config helpers ---------------------- +def autogluon_hyperparameters( + threshold, + time_limit, + random_seed, + epochs, + learning_rate, + batch_size, + backbone_image, + backbone_text, + preset, + eval_metric, + hyperparameters, +): + """ + Build a MultiModalPredictor configuration (fit kwargs + hyperparameters) from CLI inputs. + The returned dict separates what should be passed to predictor.fit (under ``fit``) + from the model/optimization configuration (under ``hyperparameters``). Threshold is + preserved for downstream evaluation but not passed into AutoGluon directly. + """ + + def _prune_empty(d: dict) -> dict: + cleaned = {} + for k, v in (d or {}).items(): + if isinstance(v, dict): + nested = _prune_empty(v) + if nested: + cleaned[k] = nested + elif v is not None: + cleaned[k] = v + return cleaned + + # Base hyperparameters following the structure described in the AutoGluon + # customization guide (env / optimization / model). 
+ env_cfg = {} + if random_seed is not None: + env_cfg["seed"] = int(random_seed) + if batch_size is not None: + env_cfg["per_gpu_batch_size"] = int(batch_size) + + optim_cfg = {} + if epochs is not None: + optim_cfg["max_epochs"] = int(epochs) + if learning_rate is not None: + optim_cfg["learning_rate"] = float(learning_rate) + if batch_size is not None: + bs = int(batch_size) + optim_cfg["per_device_train_batch_size"] = bs + optim_cfg["train_batch_size"] = bs + + model_cfg = {} + if eval_metric: + model_cfg.setdefault("metric_learning", {})["metric"] = str(eval_metric) + if backbone_image: + model_cfg.setdefault("timm_image", {})["checkpoint_name"] = str(backbone_image) + if backbone_text: + model_cfg.setdefault("hf_text", {})["checkpoint_name"] = str(backbone_text) + + hp = { + "env": env_cfg, + "optimization": optim_cfg, + "model": model_cfg, + } + + # Also expose the most common dotted aliases for robustness across AG versions. + if epochs is not None: + hp["optimization.max_epochs"] = int(epochs) + hp["optim.max_epochs"] = int(epochs) + if learning_rate is not None: + lr_val = float(learning_rate) + hp["optimization.learning_rate"] = lr_val + hp["optimization.lr"] = lr_val + hp["optim.learning_rate"] = lr_val + hp["optim.lr"] = lr_val + if batch_size is not None: + bs_val = int(batch_size) + hp["optimization.per_device_train_batch_size"] = bs_val + hp["optimization.batch_size"] = bs_val + hp["optim.per_device_train_batch_size"] = bs_val + hp["optim.batch_size"] = bs_val + hp["env.per_gpu_batch_size"] = bs_val + if backbone_image: + hp["model.timm_image.checkpoint_name"] = str(backbone_image) + if backbone_text: + hp["model.hf_text.checkpoint_name"] = str(backbone_text) + + # Merge user-provided hyperparameters (inline JSON or path) last so they win. 
+ if isinstance(hyperparameters, dict): + user_hp = hyperparameters + else: + user_hp = load_user_hparams(hyperparameters) + hp = deep_update(hp, user_hp) + hp = _prune_empty(hp) + + fit_cfg = {} + if time_limit is not None: + fit_cfg["time_limit"] = time_limit + if random_seed is not None: + fit_cfg["seed"] = int(random_seed) + if preset: + fit_cfg["presets"] = preset + + config = { + "fit": fit_cfg, + "hyperparameters": hp, + } + if threshold is not None: + config["threshold"] = float(threshold) + + return config + + +def run_autogluon_experiment( + train_dataset: pd.DataFrame, + test_dataset: Optional[pd.DataFrame], + target_column: str, + image_columns: Optional[List[str]], + ag_config: dict, +): + """ + Launch an AutoGluon MultiModal training run using the config from + autogluon_hyperparameters(). Returns (predictor, context dict) so callers + can evaluate downstream with the chosen threshold. + """ + if ag_config is None: + raise ValueError("ag_config is required to launch AutoGluon training.") + + hyperparameters = ag_config.get("hyperparameters") or {} + fit_cfg = dict(ag_config.get("fit") or {}) + threshold = ag_config.get("threshold") + + if "split" not in train_dataset.columns: + raise ValueError("train_dataset must contain a 'split' column. 
Did you call split_dataset?") + + df_train = train_dataset[train_dataset["split"] == "train"].copy() + df_val = train_dataset[train_dataset["split"].isin(["val", "validation"])].copy() + df_test_internal = train_dataset[train_dataset["split"] == "test"].copy() + + predictor = MultiModalPredictor(label=target_column, path=None) + column_types = {c: "image_path" for c in (image_columns or [])} + + fit_kwargs = { + "train_data": df_train, + "hyperparameters": hyperparameters, + } + fit_kwargs.update(fit_cfg) + if not df_val.empty: + fit_kwargs.setdefault("tuning_data", df_val) + if column_types: + fit_kwargs.setdefault("column_types", column_types) + + logger.info( + "Fitting AutoGluon with %d train / %d val rows (internal test rows: %d, external test provided: %s)", + len(df_train), + len(df_val), + len(df_test_internal), + (test_dataset is not None and not test_dataset.empty), + ) + predictor.fit(**fit_kwargs) + + return predictor, { + "train": df_train, + "val": df_val, + "test_internal": df_test_internal, + "test_external": test_dataset, + "threshold": threshold, + }
import json
import logging
import os
import random
import sys
import tempfile
import zipfile
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import torch

LOG = logging.getLogger(__name__)


def str2bool(val) -> bool:
    """Parse common truthy strings ("1", "true", "yes", "y", any case) to bool."""
    return str(val).strip().lower() in ("1", "true", "yes", "y")


def load_user_hparams(hp_arg: Optional[str]) -> dict:
    """Parse --hyperparameters (inline JSON or path to .json).

    Returns {} on any parse/IO failure (logged as a warning) so a malformed
    override never aborts the run.
    """
    if not hp_arg:
        return {}
    try:
        s = hp_arg.strip()
        if s.startswith("{"):
            return json.loads(s)
        with open(s, "r") as f:
            return json.load(f)
    except Exception as e:
        LOG.warning(f"Could not parse --hyperparameters: {e}. Ignoring.")
        return {}


def set_seeds(seed: int = 42):
    """Seed the python, numpy and torch (incl. all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def ensure_local_tmp():
    """Make sure /tmp exists (some minimal containers lack it)."""
    os.makedirs("/tmp", exist_ok=True)


def enable_tensor_cores_if_available():
    """Opt into faster ("high" precision) float32 matmul when a GPU is present."""
    if torch.cuda.is_available():
        torch.set_float32_matmul_precision("high")


def enable_deterministic_mode(seed: Optional[int] = None):
    """
    Force deterministic algorithms where possible to reduce run-to-run variance.

    Every knob is applied best-effort: failures are logged (or silently
    skipped for the TF32 flags), never raised, since not every torch build
    supports every setting.
    """
    if seed is not None:
        set_seeds(seed)
        os.environ.setdefault("PYTHONHASHSEED", str(int(seed)))
    # cuBLAS determinism (required by torch.use_deterministic_algorithms on CUDA)
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    try:
        torch.use_deterministic_algorithms(True)
    except Exception as e:
        LOG.warning(f"Could not enable torch deterministic algorithms: {e}")
    try:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except Exception as e:
        LOG.warning(f"Could not enforce deterministic cuDNN settings: {e}")
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
    except Exception:
        pass
    try:
        torch.backends.cudnn.allow_tf32 = False
    except Exception:
        pass


def load_file(path: str) -> Optional[pd.DataFrame]:
    """Load a delimited dataset; the separator is sniffed by pandas.

    Returns None for an empty/falsy path; raises FileNotFoundError when the
    path is given but does not exist.  (FIX: the annotation was pd.DataFrame
    although the function can return None.)
    """
    if not path:
        return None
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Dataset not found: {path}")
    return pd.read_csv(path, sep=None, engine="python")


def prepare_image_search_dirs(args) -> Optional[Path]:
    """Extract all --images_zip archives into one fresh temp dir; return its root."""
    if not args.images_zip:
        return None

    root = Path(tempfile.mkdtemp(prefix="autogluon_images_"))
    LOG.info(f"Extracting {len(args.images_zip)} image ZIP(s) to {root}")

    for zip_path in args.images_zip:
        path = Path(zip_path)
        if not path.exists():
            raise FileNotFoundError(f"Image ZIP not found: {zip_path}")
        # NOTE(review): extractall trusts member names; a hostile ZIP could
        # escape `root` (zip-slip).  Acceptable for trusted Galaxy inputs,
        # but sanitize member paths if archives ever come from untrusted users.
        with zipfile.ZipFile(path, 'r') as z:
            z.extractall(root)
        LOG.info(f"Extracted {path.name}")

    return root


def absolute_path_expander(df: pd.DataFrame, extracted_root: Optional[Path], image_columns: Optional[List[str]]) -> List[str]:
    """
    Resolve image paths to absolute paths (df is mutated in place).

    If no image_columns are provided, infers candidate columns whose values
    resolve to existing files (checking absolute paths first, then paths
    relative to the extracted_root).  Returns the image columns that were used.

    NOTE(review): relative paths are only tried under extracted_root, never
    the current working directory — confirm that is intended.
    """
    if df is None or df.empty:
        return []

    image_columns = [c for c in (image_columns or []) if c in df.columns]

    def resolve(p):
        # Map one cell to an absolute, existing path — or None when it
        # cannot be located.
        if pd.isna(p):
            return None
        orig = Path(str(p).strip())
        candidates = []
        if orig.is_absolute():
            candidates.append(orig)
        if extracted_root is not None:
            # Try the full relative path under root, then the bare filename.
            candidates.extend([extracted_root / orig, extracted_root / orig.name])
        for cand in candidates:
            if cand.exists():
                return str(cand.resolve())
        return None

    # Infer image columns if none were provided
    if not image_columns:
        obj_cols = [c for c in df.columns if str(df[c].dtype) == "object"]
        inferred = []
        for col in obj_cols:
            # Sample up to 50 values per column to keep inference cheap.
            sample = df[col].dropna().head(50)
            if sample.empty:
                continue
            resolved_sample = sample.apply(resolve)
            if resolved_sample.notna().any():
                inferred.append(col)
        image_columns = inferred
        if image_columns:
            LOG.info(f"Inferred image columns: {image_columns}")

    for col in image_columns:
        df[col] = df[col].apply(resolve)

    return image_columns


def verify_outputs(paths):
    """Check each (path, description) output exists; sys.exit(1) if any is missing.

    Existing outputs are chmod'ed to 0o644 so the caller (Galaxy) can read them.
    """
    ok = True
    for p, desc in paths:
        if os.path.exists(p):
            size = os.path.getsize(p)
            LOG.info(f"✓ Output {desc}: {p} ({size:,} bytes)")
            os.chmod(p, 0o644)
        else:
            LOG.error(f"✗ Output {desc} MISSING: {p}")
            ok = False
    if not ok:
        LOG.error("Some outputs are missing!")
        sys.exit(1)
