view flux.py @ 0:0d0561746128 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 71b3dacee16dc999cb4fa113858d6ace1781c71c
author bgruening
date Mon, 14 Oct 2024 16:50:38 +0000
parents
children 7933bed1ffab
line wrap: on
line source

import sys

import torch
from diffusers import FluxPipeline

# Per-model sampling settings, passed straight through to FluxPipeline.__call__.
# FLUX.1-schnell is timestep-distilled: it is meant to be sampled in ~4 steps
# with a T5 prompt of at most 256 tokens.  FLUX.1-dev is guidance-distilled and
# needs a normal-length schedule; sampling it with schnell's 4 steps yields
# under-denoised images.  Values follow the diffusers / model-card examples.
MODEL_SETTINGS = {
    "black-forest-labs/FLUX.1-dev": {
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    },
    "black-forest-labs/FLUX.1-schnell": {
        "num_inference_steps": 4,
        "max_sequence_length": 256,
    },
}

# Fixed seed so a given model + prompt always reproduces the same image.
SEED = 42
OUTPUT_PATH = "output.png"


def _fail(message: str):
    """Print *message* and terminate the process with exit status 1."""
    print(message)
    sys.exit(1)


def read_prompt(prompt_type: str, value: str) -> str:
    """Resolve the prompt from the CLI prompt type and its argument.

    Args:
        prompt_type: ``"file"`` if *value* is a path to a text file holding the
            prompt, ``"text"`` if *value* is the prompt itself.
        value: The path or the literal prompt, depending on *prompt_type*.

    Returns:
        The prompt string. File contents are stripped of surrounding
        whitespace; literal text is returned unchanged.

    Raises:
        ValueError: If *prompt_type* is neither ``"file"`` nor ``"text"``.
        OSError: If the prompt file cannot be opened.
    """
    if prompt_type == "file":
        # Explicit UTF-8: the locale default encoding may be ASCII on minimal
        # job runners, which would reject non-ASCII prompts.
        with open(value, "r", encoding="utf-8") as f:
            return f.read().strip()
    if prompt_type == "text":
        return value
    raise ValueError(
        f"Invalid prompt type {prompt_type!r}: expected 'file' or 'text'"
    )


def main(argv=None):
    """Generate one image with FLUX and save it to ``output.png``.

    Command line: ``flux.py MODEL {file|text} PROMPT``. MODEL must be one of
    the keys of ``MODEL_SETTINGS``.

    Args:
        argv: Argument list excluding the program name; defaults to
            ``sys.argv[1:]``.

    Exits with status 1 (after printing a message) on bad arguments.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 3:
        _fail("Usage: flux.py MODEL {file|text} PROMPT")
    model, prompt_type, prompt_arg = args[:3]

    # Validate all inputs before the expensive model download/load, so bad
    # arguments fail immediately instead of after loading the pipeline.
    settings = MODEL_SETTINGS.get(model)
    if settings is None:
        _fail("Invalid model!")
    try:
        prompt = read_prompt(prompt_type, prompt_arg)
    except ValueError as err:
        _fail(str(err))

    pipe = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
    # Low-VRAM configuration: offload submodules to CPU and decode with the
    # VAE in slices/tiles.
    pipe.enable_sequential_cpu_offload()
    pipe.vae.enable_slicing()
    pipe.vae.enable_tiling()
    # Cast to float16 only after offloading is configured, mirroring the
    # diffusers FLUX FP16 example. Per those docs, casting in from_pretrained
    # would load every model into CPU RAM at once.
    # NOTE(review): diffusers notes FP16 FLUX output differs from BF16 —
    # confirm this cast is wanted on the target GPUs.
    pipe.to(torch.float16)

    image = pipe(
        prompt,
        generator=torch.Generator("cpu").manual_seed(SEED),
        **settings,
    ).images[0]
    image.save(OUTPUT_PATH)


if __name__ == "__main__":
    main()