Output Grid Code Block for Stable Diffusion Colab
#@title 🌌 Run to start dreaming.{ vertical-output: true, display-mode: "form" }
import IPython
import base64
from io import BytesIO
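# NOTE: the setup cell below ("Main one-time Setup Code") must be run first;
# it defines run(), widget_opt, get_widget_extractor(), clear_output, and torch.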
all_images = []
# Clear sample output
!rm /content/stable-diffusion/outputs/txt2img-samples/samples/*
!rm /content/stable-diffusion/outputs/txt2img-samples/grid*
# Python really needs objects like JS
class Option:
    def __init__(self, seed):
        self.value = seed
    def __str__(self):
        # Cast to str so the value prints cleanly even when it is an int
        return str(self.value)
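# Option duck-types the `.value` interface of an ipywidgets widget, so plain
# values can sit in widgetDict next to real widgets. A minimal sketch:
#   Option(42).value      # -> 42
#   str(Option(42))       # -> '42'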
# Clear code output
clear_output()
### HTML ELEMENTS
# Create an HTML grid to put images in
def createGrid():
    display('create grid..')
    display(IPython.display.Javascript("""
    var seeds = []
    const outputFooter = document.querySelector("#output-footer")
    const grid = document.createElement('div')
    grid.classList.add('stream')
    grid.id = 'image-grid'
    grid.style = `
        display: grid;
        grid-template-columns: repeat(6, 1fr);
        grid-gap: 10px;
        padding: 3em 2em 5em 2em;
    `
    outputFooter.appendChild(grid)
    const showSeeds = document.createElement('div')
    showSeeds.style = `
        border: 2px solid grey;
        background: #333;
        margin: 0 0 2em 0;
        padding: 2em 1em;
        cursor: pointer;
        width: 100%;
    `
    showSeeds.onmouseover = showSeeds.onfocus = () => {
        showSeeds.style.background = '#111'
    }
    showSeeds.onmouseout = showSeeds.onblur = () => {
        showSeeds.style.background = '#333'
    }
    showSeeds.innerText = 'Click to copy seeds: ' + seeds.join(' ')
    showSeeds.onclick = () => {
        navigator.clipboard.writeText(seeds.join(' ')).then(() => {
            console.log('seeds copied to clipboard!')
        }, () => {
            console.log('failed to copy to clipboard')
        })
    }
    outputFooter.appendChild(showSeeds)
    // Listen for the "updateSeeds" event
    window.addEventListener('updateSeeds', function (e) {
        seeds = e.detail
        showSeeds.innerText = 'Click to copy seeds: ' + seeds.join(' ')
    })
    """))
# Add image to grid
def addToGrid(img_base64, seedStr):
    display(IPython.display.Javascript("""
    const grid = document.querySelector('#image-grid')
    const container = document.createElement('div')
    container.style = `
        display: flex;
        flex-direction: column;
    `
    const img = document.createElement('img')
    img.src = `{img_base64}`
    img.style = `
        width: 100%;
        height: 100%;
        object-fit: cover;
    `
    const button = document.createElement('button')
    let active = false
    button.style = `
        transition: all 80ms ease-out;
        border: 4px solid grey;
        border-radius: 4px;
        filter: drop-shadow(0px 0px 0px black);
        padding: 0;
        cursor: pointer;
    `
    button.onmouseover = button.onfocus = () => {{
        button.style.transform = active ? 'scale(1)' : 'scale(1.2)';
        button.style.filter = 'drop-shadow(0px 2px 12px black)';
        button.style.zIndex = '10';
    }}
    button.onmouseout = button.onblur = () => {{
        button.style.transform = active ? 'scale(0.8)' : 'scale(1)';
        button.style.filter = 'drop-shadow(0px 0px 0px black)';
        button.style.zIndex = 'unset';
    }}
    button.onclick = () => {{
        active = !active
        if (active) {{
            seeds.push('{seed}')
            button.style.opacity = '0.5'
            button.style.border = '4px solid #4747f5';
            button.style.transform = 'scale(0.8)';
        }} else {{
            seeds.splice(seeds.indexOf('{seed}'), 1)
            button.style.border = '4px solid grey';
            button.style.opacity = '1'
            button.style.transform = 'scale(1)';
        }}
        // Create new custom event to share the "seeds" array object
        var event = new CustomEvent('updateSeeds', {{detail: seeds}})
        window.dispatchEvent(event)
    }}
    button.appendChild(img)
    container.appendChild(button)
    grid.appendChild(container)
    """.format(img_base64=img_base64, seed=seedStr)))
# Create image grid
createGrid()
### Prepare config
# Get an iterable list from the widget's seeds string
widgetDict = get_widget_extractor(widget_opt)
# Split on commas/spaces; the `if s` filter drops a hanging empty value left by a trailing comma
seeds = [s for s in widgetDict['seeds'].value.strip().replace(',', ' ').split(' ') if s]
iterCount = widgetDict['n_iter'].value
# Generate one sample per run; the loop below handles iteration
widgetDict['n_iter'] = Option(1)
display(IPython.display.Javascript('if (typeof seeds === "undefined") { var seeds = [] } else { seeds = [] };'))
# Define run
def doRun(seedStr):
    # Run inference
    run(widgetDict)
    # Update image grid: encode the newest image as a base64 data URI for the JS side
    buffered = BytesIO()
    all_images.pop().save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    img_base64 = (bytes("data:image/jpeg;base64,", encoding='utf-8') + img_str).decode("utf-8")
    addToGrid(img_base64, seedStr)
# Runs
for baseSeed in seeds:
    for i in range(iterCount):
        # Seed
        seed = int(baseSeed) + i
        seedStr = str(seed)
        # Add seed to dict manually
        widgetDict['seed'] = Option(seed)
        print('Seed:', seedStr)
        doRun(seedStr)
# Get Seeds button at bottom
"""display(IPython.display.Javascript('''
console.log('create get seeds button')
const button = document.createElement('button')
button.innerText = 'Get seeds'
button.onclick = () => {
    console.log('seeds:', seeds)
    const streams = document.querySelectorAll(".stream")
    streams[streams.length - 1].appendChild(document.createTextNode(`Seeds: ${seeds.join(' ')}`))
}
const streams = document.querySelectorAll(".stream")
streams[streams.length - 1].appendChild(button)
'''))
for img in all_images:
    print('show image in grid..')
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    img_base64 = (bytes("data:image/jpeg;base64,", encoding='utf-8') + img_str).decode("utf-8")
    addToGrid(img_base64)
"""
print('Batch complete!')
# Clear VRAM
torch.cuda.empty_cache()
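# Aside: to confirm VRAM was actually released, you can compare
# torch.cuda.memory_allocated() and torch.cuda.memory_reserved() before and after.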
#@title Main one-time Setup Code (replaces txt2img.py from SD repo) { display-mode: "form" }
# Special installs
!pip install diffusers
# Slightly modified version of: https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py
import argparse, os, sys, glob
import torch
import numpy as np
import datetime
from omegaconf import OmegaConf
from PIL import Image
from PIL.PngImagePlugin import PngInfo
#from tqdm.auto import tqdm, trange # NOTE: updated for notebook
from tqdm import tqdm, trange # NOTE: updated for notebook
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
import rich
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from scripts.txt2img import chunk, load_model_from_config
from IPython.display import clear_output
# Code to turn kwargs into Jupyter widgets
import ipywidgets as widgets
from collections import OrderedDict
def load_model(opt):
    """Separates the loading of the model from the inference"""
    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"
    config = OmegaConf.load(f"{opt.config}")
    model = load_model_from_config(config, f"{opt.ckpt}")
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("Warning - running in CPU mode!")
        device = torch.device("cpu")
    model = model.to(device)
    return model
all_images = []
def run_inference(opt, model):
    """Separates the loading of the model from the inference.
    Additionally, slightly modified to display generated images inline.
    """
    seed_everything(opt.seed)
    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)
    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir
    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        #print(f"reading prompts from {opt.from_file}")
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))
    # * Variables for saved image filenames
    # date + time
    datetimeStr = datetime.datetime.now().isoformat()
    # Filename-safe prompt string
    slugPrompt = "".join(c if c.isalnum() else "_" for c in opt.prompt)
    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1
    start_code = None
    if opt.fixed_code:
        device = next(model.parameters()).device  # keep the noise on the model's device
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                all_samples = list()
                for n in range(opt.n_iter):  # trange(opt.n_iter, desc="Sampling"):
                    for prompts in data:  # tqdm(data, desc="data"):
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                         conditioning=c,
                                                         batch_size=opt.n_samples,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=opt.scale,
                                                         unconditional_conditioning=uc,
                                                         eta=opt.ddim_eta,
                                                         x_T=start_code)
                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        if not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                imgPath = f'{slugPrompt[:150]}_{opt.seed}_{datetimeStr}_{base_count:05}.png'
                                #Image.fromarray(x_sample.astype(np.uint8)).save(
                                #    os.path.join(sample_path, imgPath))
                                img = Image.fromarray(x_sample.astype(np.uint8))  # Image.open(imgPath)
                                all_images.append(img)
                                # Embed the generation settings as PNG text chunks
                                metadata = PngInfo()
                                metadata.add_text("artist", 'Chris Hayes')
                                metadata.add_text("copyright", 'Public Domain')
                                metadata.add_text("software", "Stable Diffusion 1.4")
                                metadata.add_text("title", opt.prompt)
                                config = f"prompt: {opt.prompt}, seed: {opt.seed}, steps: {opt.ddim_steps}, CGS: {opt.scale}"
                                metadata.add_text("config", config)
                                img.save(os.path.join(sample_path, imgPath), pnginfo=metadata)
                                base_count += 1
                        if not opt.skip_grid:
                            all_samples.append(x_samples_ddim)
                if not opt.skip_grid:
                    # additionally, save as grid
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)
                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    #Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{slugPrompt[:150]}_{opt.seed}_{datetimeStr}_{grid_count:04}.png'))
                    grid_count += 1
                    # display
                    #if opt.display_inline:
                    #    clear_output()
                    #    display(Image.fromarray(grid.astype(np.uint8)))
                toc = time.time()
                #print(f"Your samples have been saved to: \n{outpath} \n"
                #      f" \nEnjoy.")
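# Aside: the PNG text chunks written above can be read back with Pillow, e.g.:
#   im = Image.open('outputs/txt2img-samples/samples/<file>.png')
#   im.load()   # text chunks are parsed once the file is loaded
#   im.text     # -> {'artist': ..., 'software': ..., 'config': 'prompt: ...', ...}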
def run(opt):
    """If the model parameters changed, reload the model; otherwise, just do inference"""
    #print(f"Creating image ({opt.H},{opt.W}) from prompt:\n\"{opt.prompt}\"\n")
    # FIXME global hack
    global last_config
    global last_ckpt
    global model
    if (opt.config != last_config) or (opt.ckpt != last_ckpt):
        model = load_model(opt)
        # FIXME global hack
        last_config = opt.config
        last_ckpt = opt.ckpt
    run_inference(opt, model)
# FIXME global hack
last_config = ""
last_ckpt = ""
####### Widget GUI code #######
def get_widget_extractor(widget_dict):
    # Allows attribute access after setting; this keeps the diff against the argparse code small
    class WidgetDict(OrderedDict):
        def __getattr__(self, val):
            return self[val].value
    return WidgetDict(widget_dict)
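# Sketch of the resulting interface: attribute access reads a widget's .value
# (argparse-style), while item access returns the widget itself:
#   opt = get_widget_extractor(widget_opt)
#   opt.prompt             # -> widget_opt['prompt'].value
#   opt['prompt'].value    # -> the same string, via the widget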
# Allows long widget descriptions
style = {'description_width': 'initial'}
# Force widget width to max
layout = widgets.Layout(width='100%')
# args from argparse converted to widgets:
# https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py#L48-L177
widget_opt = OrderedDict()
widget_opt['outdir'] = widgets.Text(
    layout=layout, style=style,
    description='dir to write results to',
    value="outputs/txt2img-samples",
    disabled=False
)
widget_opt['skip_grid'] = widgets.Checkbox(
    layout=layout, style=style,
    value=False,
    description='do not save a grid, only individual samples. Helpful when evaluating lots of samples',
    indent=False,
    disabled=False
)
widget_opt['skip_save'] = widgets.Checkbox(
    layout=layout, style=style,
    value=False,
    description='do not save individual samples. For speed measurements.',
    indent=False,
    disabled=False
)
widget_opt['plms'] = widgets.Checkbox(
    layout=layout, style=style,
    value=False,
    description='use plms sampling (not checked = ddim)',
    indent=False,
    disabled=False
)
widget_opt['laion400m'] = widgets.Checkbox(
    layout=layout, style=style,
    value=False,
    description='uses the LAION400M model',
    indent=False,
    disabled=False
)
widget_opt['fixed_code'] = widgets.Checkbox(
    layout=layout, style=style,
    value=False,
    description='if enabled, uses the same starting code across samples',
    indent=False,
    disabled=False
)
widget_opt['ddim_eta'] = widgets.FloatText(
    layout=layout, style=style,
    description='ddim eta (eta=0.0 corresponds to deterministic sampling)',
    value=0.0,
    disabled=False
)
widget_opt['C'] = widgets.IntText(
    layout=layout, style=style,
    description='latent channels',
    value=4,
    disabled=False
)
widget_opt['f'] = widgets.IntText(
    layout=layout, style=style,
    description='downsampling factor',
    value=8,
    disabled=False
)
widget_opt['n_rows'] = widgets.IntText(
    layout=layout, style=style,
    description='rows in the grid (default: n_samples)',
    value=0,
    disabled=False
)
widget_opt['scale'] = widgets.FloatText(
    layout=layout, style=style,
    description='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    value=7.5,
    disabled=False
)
widget_opt['from_file'] = widgets.Text(
    layout=layout, style=style,
    description='if specified, load prompts from this file',
    value='',  # empty string rather than None; a Text widget's value must be a str
    disabled=False
)
widget_opt['config'] = widgets.Text(
    layout=layout, style=style,
    description='path to config which constructs model',
    value="configs/stable-diffusion/v1-inference.yaml",
    disabled=False
)
widget_opt['ckpt'] = widgets.Text(
    layout=layout, style=style,
    description='path to checkpoint of model',
    value="models/ldm/stable-diffusion-v1/model.ckpt",
    disabled=False
)
widget_opt['precision'] = widgets.Combobox(
    layout=layout, style=style,
    description='evaluate at this precision',
    value="autocast",
    options=["full", "autocast"],
    disabled=False
)
# Extra option for the notebook
widget_opt['display_inline'] = widgets.Checkbox(
    layout=layout, style=style,
    value=True,
    description='display output images inline (in addition to saving them)',
    indent=False,
    disabled=False
)
# Common
widget_opt['n_iter'] = widgets.IntText(
    layout=layout, style=style,
    description='sample this often',
    value=1,
    disabled=False
)
widget_opt['n_samples'] = widgets.IntText(
    layout=layout, style=style,
    description='how many samples to produce for each given prompt. A.k.a. batch size',
    value=1,
    disabled=False
)
widget_opt['H'] = widgets.IntText(
    layout=layout, style=style,
    description='image height, in pixel space',
    value=512,
    disabled=False
)
widget_opt['W'] = widgets.IntText(
    layout=layout, style=style,
    description='image width, in pixel space',
    value=512,
    disabled=False
)
widget_opt['prompt'] = widgets.Text(
    layout=layout, style=style,
    description='the prompt to render',
    #value="a painting of a virus monster playing guitar", # script default
    value="a photograph of an astronaut riding a horse", # README default
    disabled=False
)
widget_opt['ddim_steps'] = widgets.IntText(
    layout=layout, style=style,
    description='number of ddim sampling steps',
    value=50,
    disabled=False
)
widget_opt['seeds'] = widgets.Text(
    layout=layout, style=style,
    description='multiple seeds for batch runs (separate with a space)',
    value='42',
    disabled=False
)
# Button that runs the pipeline
# Alternatively, you can just run the following in a new cell:
# run(get_widget_extractor(widget_opt))
run_button = widgets.Button(
    description='CLICK TO DREAM',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click to run (settings will update automatically)',
    icon='check'
)
run_button_out = widgets.Output()
# this doesn't get used
def on_run_button_click(b):
    with run_button_out:
        widgetDict = get_widget_extractor(widget_opt)
        for seed in widgetDict['seeds'].value.split(','):
            #clear_output()
            # Wrap in a .value holder so WidgetDict attribute access still works
            widgetDict['seed'] = widgets.IntText(value=int(seed))
            run(widgetDict)
run_button.on_click(on_run_button_click)
# Package into box and render
#primary_options = ['prompt', 'outdir'] # options to put up top
#secondary_options = [k for k in widget_opt.keys() if k not in primary_options] # rest, ordered by insertion
load_options = ['config', 'ckpt']
inference_options = [k for k in widget_opt.keys() if k not in load_options] # rest, ordered by insertion
assert all([k in inference_options + load_options for k in widget_opt.keys()]) # make sure we didn't miss any options
# Package into box for rendering
gui = widgets.VBox(
    [widget_opt[k] for k in inference_options] + [widget_opt[k] for k in load_options] # + [run_button, run_button_out]
)
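# To render the form and kick off a run manually (minimal sketch; the grid cell
# above drives run() for you):
#   display(gui)
#   run(get_widget_extractor(widget_opt))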