Mercurial > repos > bgruening > black_forest_labs_flux
changeset 1:7933bed1ffab draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 0959e15618c76e05c78e5811218c9c4bc343db27
author | bgruening |
---|---|
date | Tue, 29 Oct 2024 13:53:36 +0000 |
parents | 0d0561746128 |
children | 3e22bda128be |
files | flux.py flux.xml |
diffstat | 2 files changed, 15 insertions(+), 7 deletions(-) [+] |
line wrap: on
line diff
--- a/flux.py Mon Oct 14 16:50:38 2024 +0000 +++ b/flux.py Tue Oct 29 13:53:36 2024 +0000 @@ -1,9 +1,10 @@ +import os import sys import torch from diffusers import FluxPipeline -model = sys.argv[1] +model_path = sys.argv[1] prompt_type = sys.argv[2] if prompt_type == "file": @@ -12,12 +13,20 @@ elif prompt_type == "text": prompt = sys.argv[3] -if model not in ["black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"]: +if "dev" in model_path: + num_inference_steps = 20 +elif "schnell" in model_path: + num_inference_steps = 4 +else: print("Invalid model!") sys.exit(1) +snapshots = [] +for d in os.listdir(os.path.join(model_path, "snapshots")): + snapshots.append(os.path.join(model_path, "snapshots", d)) +latest_snapshot_path = max(snapshots, key=os.path.getmtime) -pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16) +pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16) pipe.enable_sequential_cpu_offload() pipe.vae.enable_slicing() pipe.vae.enable_tiling() @@ -25,7 +34,7 @@ image = pipe( prompt, - num_inference_steps=4, + num_inference_steps=num_inference_steps, generator=torch.Generator("cpu").manual_seed(42), ).images[0]
--- a/flux.xml Mon Oct 14 16:50:38 2024 +0000 +++ b/flux.xml Tue Oct 29 13:53:36 2024 +0000 @@ -2,7 +2,7 @@ <description>text-to-image model</description> <macros> <token name="@TOOL_VERSION@">2024</token> - <token name="@VERSION_SUFFIX@">0</token> + <token name="@VERSION_SUFFIX@">1</token> </macros> <requirements> <requirement type="package" version="3.12">python</requirement> @@ -16,9 +16,8 @@ <requirement type="package" version="0.24.6">huggingface_hub</requirement> </requirements> <command detect_errors="exit_code"><![CDATA[ -export HF_HOME='$flux_models.fields.path' && python '$__tool_directory__/flux.py' -'$flux_models' +'$flux_models.fields.path' '$input_type_selector' '$prompt' ]]></command>