Skip to content

Instantly share code, notes, and snippets.

@thesephist
Created August 2, 2024 09:57
Show Gist options
  • Save thesephist/40094c8eb188d18461b36d8a64d8c730 to your computer and use it in GitHub Desktop.
Save thesephist/40094c8eb188d18461b36d8a64d8c730 to your computer and use it in GitHub Desktop.
Flux 1 Schnell — note: the script loads the FLUX.1-dev checkpoint (revision refs/pr/1) while using schnell-style sampling settings (4 steps, guidance 0.0)
import torch
from diffusers import FluxPipeline, AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
# Checkpoint configuration. NOTE(review): the gist is titled "Schnell" but this
# loads the FLUX.1-dev repo; refs/pr/1 was the PR revision carrying the
# diffusers-format weights at the time — confirm it is still needed.
ckpt_id = "black-forest-labs/FLUX.1-dev"
revision = "refs/pr/1"

# CLIP text encoder (produces the pooled prompt embedding), bf16 on GPU.
text_encoder = CLIPTextModel.from_pretrained(
    ckpt_id, revision=revision, subfolder="text_encoder", torch_dtype=torch.bfloat16
).to("cuda")

# T5 text encoder (produces the sequence prompt embeddings), bf16 on GPU.
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id,
    subfolder="text_encoder_2",
    revision=revision,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Tokenizers matching the two text encoders above.
tokenizer = CLIPTokenizer.from_pretrained(
    ckpt_id,
    subfolder="tokenizer",
    revision=revision,
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    ckpt_id,
    subfolder="tokenizer_2",
    revision=revision,
)

# VAE kept as a separate handle so we can decode latents manually below.
vae = AutoencoderKL.from_pretrained(
    ckpt_id,
    subfolder="vae",
    revision=revision,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Assemble the pipeline from the pre-loaded components; the transformer is
# loaded by the pipeline itself and spread across devices by device_map.
pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=text_encoder,
    text_encoder_2=text_encoder_2,
    tokenizer=tokenizer,
    tokenizer_2=tokenizer_2,
    vae=vae,
    revision=revision,  # was a duplicated hard-coded "refs/pr/1"
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)
# Output resolution; presumably must be divisible by the compression factor
# below — TODO confirm against the installed diffusers version.
height, width = 1080, 1080

# Loop invariants hoisted out of the prompt loop (they were rebuilt every
# iteration). NOTE(review): the exponent is len(block_out_channels) with no
# "- 1" here, matching FluxPipeline's own convention at the time this was
# written (Flux packs 2x2 latent patches) — verify against your diffusers
# version before changing it.
vae_scale_factor = 2 ** len(vae.config.block_out_channels)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

while True:
    prompt = input("Enter prompt: ")

    # Encode the prompt separately so the pipeline call below receives
    # pre-computed embeddings.
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
            max_sequence_length=256,
        )

    # 4 steps / guidance 0.0 are schnell-style (distilled) sampling settings.
    # output_type="latent" skips the pipeline's own VAE decode; we do it
    # manually below.
    latents = pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=4,
        guidance_scale=0.0,
        height=height,
        width=width,
        output_type="latent",
    ).images

    # Manual decode: unpack Flux's patch-packed latents, undo the VAE's
    # scaling/shift, then decode to pixel space.
    with torch.no_grad():
        latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
        latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
        image = vae.decode(latents, return_dict=False)[0]

    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("image.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment