changeset 1:7933bed1ffab draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
author bgruening
date Tue, 29 Oct 2024 13:53:36 +0000
parents 0d0561746128
children 3e22bda128be
files flux.py flux.xml
diffstat 2 files changed, 15 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/flux.py	Mon Oct 14 16:50:38 2024 +0000
+++ b/flux.py	Tue Oct 29 13:53:36 2024 +0000
@@ -1,9 +1,10 @@
+import os
 import sys
 
 import torch
 from diffusers import FluxPipeline
 
-model = sys.argv[1]
+model_path = sys.argv[1]
 
 prompt_type = sys.argv[2]
 if prompt_type == "file":
@@ -12,12 +13,20 @@
 elif prompt_type == "text":
     prompt = sys.argv[3]
 
-if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
+if "dev" in model_path:
+    num_inference_steps = 20
+elif "schnell" in model_path:
+    num_inference_steps = 4
+else:
     print("Invalid model!")
     sys.exit(1)
 
+snapshots = []
+for d in os.listdir(os.path.join(model_path, "snapshots")):
+    snapshots.append(os.path.join(model_path, "snapshots", d))
+latest_snapshot_path = max(snapshots, key=os.path.getmtime)
 
-pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
+pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16)
 pipe.enable_sequential_cpu_offload()
 pipe.vae.enable_slicing()
 pipe.vae.enable_tiling()
@@ -25,7 +34,7 @@
 
 image = pipe(
     prompt,
-    num_inference_steps=4,
+    num_inference_steps=num_inference_steps,
     generator=torch.Generator("cpu").manual_seed(42),
 ).images[0]
 
--- a/flux.xml	Mon Oct 14 16:50:38 2024 +0000
+++ b/flux.xml	Tue Oct 29 13:53:36 2024 +0000
@@ -2,7 +2,7 @@
     <description>text-to-image model</description>
     <macros>
         <token name="@TOOL_VERSION@">2024</token>
-        <token name="@VERSION_SUFFIX@">0</token>
+        <token name="@VERSION_SUFFIX@">1</token>
     </macros>
     <requirements>
         <requirement type="package" version="3.12">python</requirement>
@@ -16,9 +16,8 @@
         <requirement type="package" version="0.24.6">huggingface_hub</requirement>
     </requirements>
     <command detect_errors="exit_code"><![CDATA[
-export HF_HOME='$flux_models.fields.path' &&
 python '$__tool_directory__/flux.py'
-'$flux_models'
+'$flux_models.fields.path'
 '$input_type_selector'
 '$prompt'
     ]]></command>