Skip to content

Instantly share code, notes, and snippets.

@shirayu
Last active August 23, 2022 11:38
Show Gist options
  • Save shirayu/514db3c873d7713fc49d82d3b6c4e4d1 to your computer and use it in GitHub Desktop.
Save shirayu/514db3c873d7713fc49d82d3b6c4e4d1 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import sys
from pathlib import Path

import torch
from diffusers import StableDiffusionPipeline
from torch.amp.autocast_mode import autocast
def get_max_index(p: Path) -> int:
    """Return the largest numeric file-name prefix among ``*.png`` files in *p*.

    The prefix is everything before the first ``.`` in the file name
    (``"000042.png"`` -> ``42``). Files whose prefix ``int()`` cannot parse
    are skipped. The result is never below 0, so a directory with no
    numbered PNGs yields 0.
    """

    def _indices():
        # Yield the parsed prefix of every PNG whose prefix is an integer.
        for png_file in p.glob("*.png"):
            stem, _, _ = png_file.name.partition(".")
            try:
                number = int(stem)
            except ValueError:
                continue
            yield number

    # Seeding the candidates with 0 keeps the floor: negative prefixes never win.
    return max([0, *_indices()])
# Hugging Face model repository to load, and the torch device to run it on.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the half-precision weights (the repository's "fp16" revision, loaded as
# torch.float16) using the locally saved Hugging Face token.
# NOTE(review): newer diffusers releases deprecate `revision="fp16"` and
# `use_auth_token`; confirm these arguments against the installed version.
print("Loading...")
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
)
# Move the pipeline to the configured device. This previously hard-coded
# "cuda" and so ignored `device`.
pipe = pipe.to(device)
# Read prompts from standard input, one per line, and save one PNG per prompt
# into ./img. The whole generation loop runs under CUDA autocast
# (mixed precision).
with autocast("cuda"):
    print("Ready!")
    path_out = Path("img")
    path_out.mkdir(exist_ok=True, parents=True)
    # get_max_index() returns the highest index already on disk (0 when there
    # is none). Start one past it: starting *at* it would overwrite that file,
    # and every run into an empty directory would rewrite 000000.png.
    idx = get_max_index(path_out) + 1
    # sys.stdin instead of /dev/stdin: the latter does not exist on Windows.
    for line in sys.stdin:
        prompt: str = line.strip()
        if not prompt:  # skip blank lines
            continue
        # NOTE(review): `["sample"]` matches the dict-style pipeline output of
        # the diffusers version this was written for; newer releases expose
        # `.images` instead. Confirm against the installed version.
        image = pipe(prompt)["sample"][0]
        oname: Path = path_out / f"{idx:06}.png"
        idx += 1
        image.save(str(oname))
        print(f"Saved to {oname}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment