view flux.py @ 3:21ee409e6cde draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit f8d8a9399068a8f11582d528e1eb54d36125fb76
author bgruening
date Fri, 22 Nov 2024 13:41:32 +0000
parents 7933bed1ffab
children
line wrap: on
line source

import os
import sys

import torch
from diffusers import FluxPipeline

# Positional command-line arguments:
#   argv[1]  model directory (a "snapshots" subdirectory is read from it below)
#   argv[2]  how to interpret argv[3]: "file" or "text"
#   argv[3]  prompt file path when argv[2] == "file", else the prompt text itself
model_path = sys.argv[1]

prompt_type = sys.argv[2]
if prompt_type == "file":
    # Read the whole prompt file, dropping leading/trailing whitespace and newlines.
    with open(sys.argv[3], "r", encoding="utf-8") as f:
        prompt = f.read().strip()
elif prompt_type == "text":
    prompt = sys.argv[3]
else:
    # Fail fast: without this branch `prompt` stays unbound and the script would
    # only crash with a NameError after the (expensive) model load further down.
    print(
        f"Invalid prompt type: {prompt_type!r} (expected 'file' or 'text')",
        file=sys.stderr,
    )
    sys.exit(1)

def _num_inference_steps_for(path):
    """Return the denoising step count for the FLUX variant named in *path*.

    "dev" maps to 20 steps and "schnell" to 4 (checked in that order).

    The last path component is inspected first, so an unrelated "dev" or
    "schnell" in a parent directory (e.g. ``/home/developer/...``) cannot pick
    the wrong variant. If the last component names neither variant, the whole
    path is checked instead, which is the original behaviour.

    Returns None when neither variant name occurs.
    """
    # normpath strips a trailing separator so basename is not empty.
    model_name = os.path.basename(os.path.normpath(path))
    for text in (model_name, path):
        if "dev" in text:
            return 20
        if "schnell" in text:
            return 4
    return None


num_inference_steps = _num_inference_steps_for(model_path)
if num_inference_steps is None:
    # Errors go to stderr; the exit status is unchanged (1).
    print("Invalid model!", file=sys.stderr)
    sys.exit(1)

# Choose the most recently modified snapshot directory under <model_path>/snapshots.
snapshots_dir = os.path.join(model_path, "snapshots")
# Only directories are candidates: a stray file with a newer mtime would
# otherwise be selected and handed to from_pretrained.
snapshots = [
    os.path.join(snapshots_dir, d)
    for d in os.listdir(snapshots_dir)
    if os.path.isdir(os.path.join(snapshots_dir, d))
]
if not snapshots:
    # max() on an empty list would raise an unhelpful "empty sequence" error.
    print(f"No model snapshots found in {snapshots_dir}", file=sys.stderr)
    sys.exit(1)
latest_snapshot_path = max(snapshots, key=os.path.getmtime)

# Load the FLUX pipeline from the selected snapshot, initially in bfloat16.
pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16)
# NOTE(review): presumably memory-saving options (diffusers API): sequential
# CPU offload of the model components, plus sliced/tiled VAE processing.
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
# Cast to float16 only after offloading is enabled; keep this order.
# NOTE(review): this appears to follow the diffusers FLUX "FP16 inference"
# example, which casts here instead of passing torch_dtype=torch.float16 to
# from_pretrained so that all models are not loaded into CPU memory at once.
pipe.to(torch.float16)

# Fixed seed (42) on a CPU-side generator so repeated runs with the same
# prompt and model are reproducible (to the extent the backend is deterministic).
seeded_generator = torch.Generator("cpu").manual_seed(42)

result = pipe(
    prompt,
    num_inference_steps=num_inference_steps,
    generator=seeded_generator,
)

# Keep the first generated image and write it to the working directory.
image = result.images[0]
image.save("output.png")