# HG changeset patch
# User bgruening
# Date 1730210016 0
# Node ID 7933bed1ffab15c41f47125e5b10918fab106f76
# Parent 0d0561746128f453e5e4bca8fd8df83125f3de51
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
diff -r 0d0561746128 -r 7933bed1ffab flux.py
--- a/flux.py Mon Oct 14 16:50:38 2024 +0000
+++ b/flux.py Tue Oct 29 13:53:36 2024 +0000
@@ -1,9 +1,10 @@
+import os
import sys
import torch
from diffusers import FluxPipeline
-model = sys.argv[1]
+model_path = sys.argv[1]
prompt_type = sys.argv[2]
if prompt_type == "file":
@@ -12,12 +13,20 @@
elif prompt_type == "text":
prompt = sys.argv[3]
-if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]:
+if "dev" in model_path:
+ num_inference_steps = 20
+elif "schnell" in model_path:
+ num_inference_steps = 4
+else:
print("Invalid model!")
sys.exit(1)
+snapshots = []
+for d in os.listdir(os.path.join(model_path, "snapshots")):
+ snapshots.append(os.path.join(model_path, "snapshots", d))
+latest_snapshot_path = max(snapshots, key=os.path.getmtime)
-pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
+pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
@@ -25,7 +34,7 @@
image = pipe(
prompt,
- num_inference_steps=4,
+ num_inference_steps=num_inference_steps,
generator=torch.Generator("cpu").manual_seed(42),
).images[0]
diff -r 0d0561746128 -r 7933bed1ffab flux.xml
--- a/flux.xml Mon Oct 14 16:50:38 2024 +0000
+++ b/flux.xml Tue Oct 29 13:53:36 2024 +0000
@@ -2,7 +2,7 @@
text-to-image model
2024
- 0
+ 1
python
@@ -16,9 +16,8 @@
huggingface_hub