Mercurial > repos > bgruening > black_forest_labs_flux
view flux.py @ 2:3e22bda128be draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/flux commit 61ef69066b5fb4ff882b0de4e14f54b97b384f2a
author | bgruening |
---|---|
date | Wed, 30 Oct 2024 13:27:35 +0000 |
parents | 7933bed1ffab |
children |
line wrap: on
line source
import os import sys import torch from diffusers import FluxPipeline model_path = sys.argv[1] prompt_type = sys.argv[2] if prompt_type == "file": with open(sys.argv[3], "r") as f: prompt = f.read().strip() elif prompt_type == "text": prompt = sys.argv[3] if "dev" in model_path: num_inference_steps = 20 elif "schnell" in model_path: num_inference_steps = 4 else: print("Invalid model!") sys.exit(1) snapshots = [] for d in os.listdir(os.path.join(model_path, "snapshots")): snapshots.append(os.path.join(model_path, "snapshots", d)) latest_snapshot_path = max(snapshots, key=os.path.getmtime) pipe = FluxPipeline.from_pretrained(latest_snapshot_path, torch_dtype=torch.bfloat16) pipe.enable_sequential_cpu_offload() pipe.vae.enable_slicing() pipe.vae.enable_tiling() pipe.to(torch.float16) image = pipe( prompt, num_inference_steps=num_inference_steps, generator=torch.Generator("cpu").manual_seed(42), ).images[0] image.save("output.png")