Skip to content

Instantly share code, notes, and snippets.

@aimerib
Created September 16, 2024 16:00
Show Gist options
  • Save aimerib/7b3c338a5d9496bf636a2c7b951a9f86 to your computer and use it in GitHub Desktop.
Save aimerib/7b3c338a5d9496bf636a2c7b951a9f86 to your computer and use it in GitHub Desktop.
Update Oobabooga StableDiffusion extension to use SwarmUI instead
# Either change extensions/sd_api_pictures/script.py or copy the extensions/sd_api_pictures folder as second folder like sd_api_pictures_swarm and change extensions/sd_api_pictures_swarm/script.py
import base64
import io
import json
import re
from datetime import date
from pathlib import Path
from PIL import Image
import gradio as gr
import requests
import torch
import os
import hashlib
from modules import shared
from modules.models import reload_model, unload_model
from modules.ui import create_refresh_button
# Disable the Torch JIT profiling executor (startup tweak carried over from
# the original sd_api_pictures extension).
torch._C._jit_set_profiling_mode(False)

# parameters which can be customized in settings.json of webui
gen_params = {
    "address": "http://127.0.0.1:7801",  # SwarmUI API base URL
    "mode": 0, # modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks for words to trigger), 2 (Picturebook Adventure - Always on)
    "manage_VRAM": False,
    "save_img": False,  # True: keep full-size originals on disk; False: inline a JPEG thumbnail (see get_SD_pictures)
    "prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful",
    "negative_prompt": "(worst quality, low quality:1.3)",
    "width": 512,
    "height": 512,
    "denoising_strength": 0.61,  # sent as SwarmUI "initimagecreativity"
    "seed": -1,  # -1 conventionally means "random seed" in SD backends
    "sampler_name": "dpmpp_2m_sde_gpu",
    "scheduler_name": "karras",
    "aspect_ratio": "1:1",  # recomputed from width/height in update_gen_params
    "steps": 32,
    "cfg_scale": 7,
    "textgen_prefix": "Please provide a detailed and vivid description of [subject]",
    "sd_checkpoint": " ",
    "checkpoint_list": [" "],
    "session_id": "",  # filled in lazily by get_session_id()
}
def update_gen_params(new_params):
    """Merge *new_params* into the global gen_params dict.

    Only keys already present in gen_params are accepted. A new value must
    match the stored value's type; as a backward-compatible generalization,
    an int is also accepted where a float is stored (UI widgets often
    deliver whole numbers) and is promoted with float().

    Raises ValueError if *new_params* is not a dict. Returns the updated
    gen_params dict.
    """
    global gen_params
    # Validate the input
    if not isinstance(new_params, dict):
        raise ValueError("Input must be a dictionary")
    # Update only the keys that exist in gen_params
    for key, value in new_params.items():
        if key not in gen_params:
            print(f"Warning: '{key}' is not a valid parameter and will be ignored.")
            continue
        # Type checking to ensure we're not changing the type of a parameter
        if isinstance(gen_params[key], type(value)):
            gen_params[key] = value
        elif (
            isinstance(gen_params[key], float)
            and isinstance(value, int)
            and not isinstance(value, bool)
        ):
            # Promote ints for float-typed params (e.g. denoising_strength=1).
            gen_params[key] = float(value)
        else:
            print(
                f"Warning: Type mismatch for key '{key}'. Expected {type(gen_params[key])}, got {type(value)}. Value not updated."
            )
    # Keep the aspect-ratio string in sync when a dimension was supplied.
    if "width" in new_params or "height" in new_params:
        gen_params["aspect_ratio"] = f"{gen_params['width']}:{gen_params['height']}"
    # Return the updated dictionary
    return gen_params
# SwarmUI API requires a session_id for every request we send. Let's get one
def get_session_id():
    """Fetch a fresh SwarmUI session token and store it in gen_params.

    On failure returns a gr.Textbox update carrying an error label; the
    success path returns None (unchanged from the original contract).
    """
    global gen_params
    print("Requesting new SwarmUI session token...")
    msg = "✔️ SwarmUI API Session Token updated"
    try:
        response = requests.post(
            url=f'{gen_params["address"]}/API/GetNewSession',
            data=json.dumps({}),
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        r = response.json()
        update_gen_params({"session_id": r["session_id"]})
        print(msg)
    except Exception as e:
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit) and the failure is now logged
        # instead of silently discarded.
        msg = "❌ No SD API endpoint on:"
        print(f"{msg} {gen_params['address']} - {e}")
        return gr.Textbox.update(label=msg)
def make_api_request(endpoint, payload):
    """POST *payload* as JSON to SwarmUI's /API/<endpoint>.

    Injects the cached session_id into the payload, fetching a new one
    first when none is cached. Returns (parsed_json, response.ok) on
    success and (error_message, 500) on failure, so every caller can
    uniformly unpack the pair and test `status >= 400`.
    """
    print(f"Making SwarmUI API request to endpoint: {endpoint}")
    msg = f"✔️ SwarmUI API request to: {endpoint} - Success"
    if gen_params["session_id"] == "":
        get_session_id()
    try:
        payload["session_id"] = gen_params["session_id"]
        response = requests.post(
            url=f'{gen_params["address"]}/API/{endpoint}',
            data=json.dumps(payload),
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        r = response.json()
        print(msg)
        return (r, response.ok)
    except Exception as e:
        msg = f"❌ Request to endpoint: {endpoint} returned - {e}"
        print(msg)
        # Bug fix: the original returned None here, which crashed every
        # caller that unpacks `(response, status) = make_api_request(...)`.
        return (msg, 500)
def get_file_hash(file_path):
    """Return the MD5 hex digest of the file at *file_path*.

    Reads in 4 KiB chunks so arbitrarily large files never have to fit
    in memory at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while True:
            chunk = fh.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def download_and_save_image(url, save_dir):
    """Download an image from the SwarmUI server and cache it under *save_dir*.

    *url* is a server-relative image path (as returned by
    GenerateText2Image); the cache filename is the MD5 of the full URL, so
    the same server image maps to one local file. Returns the local path.
    """
    # Create a filename based on the URL
    url = f"{gen_params['address']}/{url}"
    filename = os.path.join(save_dir, hashlib.md5(url.encode()).hexdigest() + ".png")
    output_file = Path(filename)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    # Check if file already exists and hasn't changed
    if os.path.exists(filename):
        # NOTE(review): HTTP ETags are usually quoted and are not guaranteed
        # to be an MD5 of the body, so this comparison will rarely match and
        # the image is then re-downloaded — confirm the SwarmUI server emits
        # bare-MD5 ETags before relying on this cache check.
        etag = requests.head(url).headers.get("ETag")
        if etag and etag == get_file_hash(filename):
            return filename
    # Send a GET request to the URL
    response = requests.get(url, stream=True)
    response.raise_for_status()
    # Open the image using PIL (the .png filename makes PIL re-encode as PNG)
    image = Image.open(io.BytesIO(response.content))
    # Save the image to disk
    image.save(filename)
    return filename
# Global flag toggled by toggle_generation() and checked in state_modifier()
# and output_modifier().
picture_response = (
    False # specifies if the next model response should appear as a picture
)
def remove_surrounded_chars(string):
    """Strip *emote* spans: any text between two asterisks, or between an
    asterisk and the end of the string."""
    # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
    # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
    # Raw string fixes the invalid "\*" escape (DeprecationWarning, and a
    # SyntaxError in future Python versions).
    return re.sub(r"\*[^\*]*?(\*|$)", "", string)
def triggers_are_in(string):
    """Return True when the text (minus *emote* spans) asks for a picture.

    Matches send/mail/message/me followed later in the text by a whole
    word image/pic(ture)/photo/snap(shot)/selfie/meme, optionally plural.
    (?aims) = ASCII-only \\b, ignore case, multiline, dot-matches-newline.
    """
    cleaned = remove_surrounded_chars(string)
    found = re.search(
        "(?aims)(send|mail|message|me)\\b.+?\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b",
        cleaned,
    )
    return found is not None
def state_modifier(state):
    """Disable token streaming whenever a picture reply is pending."""
    if not picture_response:
        return state
    state["stream"] = False
    return state
def input_modifier(string):
    """
    This function is applied to your text inputs before
    they are fed into the model.

    In Immersive/Interactive mode (mode 1), trigger phrases arm picture
    generation and the user input is rewritten into a description request
    using gen_params["textgen_prefix"].
    """
    global gen_params
    if gen_params["mode"] != 1:  # if not in immersive/interactive mode, do nothing
        return string
    if triggers_are_in(string):  # if we're in it, check for trigger words
        toggle_generation(True)
        string = string.lower()
        # Bug fix: the original substring test `"of" in string` also matched
        # "of" inside words like "profile"; require a whole word instead.
        match = re.search(r"\bof\b", string)
        if match:
            # Everything after the first whole-word 'of' becomes the subject.
            subject = string[match.end():]
            string = gen_params["textgen_prefix"].replace("[subject]", subject)
        else:
            string = gen_params["textgen_prefix"].replace(
                "[subject]",
                "your appearance, your surroundings and what you are doing right now",
            )
    return string
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):
    """Generate an image from *description* via SwarmUI and return HTML.

    The description (stripped of <audio> tags) becomes the prompt, combined
    with gen_params settings. Returns an <img> tag: a link to the saved
    file when save_img is on, otherwise an inlined base64 JPEG thumbnail.
    """
    global gen_params
    # Drop any TTS <audio> elements a previous extension may have injected.
    description = re.sub("<audio.*?</audio>", " ", description)
    description = f"({description}:1)"
    # SwarmUI GenerateText2Image parameters, mapped from gen_params.
    payload = {
        "prompt": gen_params["prompt_prefix"] + description,
        "negativeprompt": gen_params["negative_prompt"],
        "model": gen_params["sd_checkpoint"],
        "images": "1",
        "seed": gen_params["seed"],
        "steps": gen_params["steps"],
        "cfgscale": gen_params["cfg_scale"],
        "aspectratio": gen_params["aspect_ratio"],
        "width": gen_params["width"],
        "height": gen_params["height"],
        "sampler": gen_params["sampler_name"],
        "scheduler": gen_params["scheduler_name"],
        "initimagecreativity": gen_params["denoising_strength"],
        "initimageresettonorm": "0",
        "maskblur": "4",
        "initimagerecompositemask": True,
        "internalbackendtype": "Any",
        "modelspecificenhancements": True,
        "freeuapplyto": "Both",
        "freeuversion": "2",
        "freeublockone": "1.1",
        "freeublocktwo": "1.2",
        "freeuskipone": "0.9",
        "freeuskiptwo": "0.2",
        "automaticvae": True,
    }
    print(f'Prompting the image generator via the API on {gen_params["address"]}...')
    # NOTE(review): make_api_request returns None on failure, so this unpack
    # raises TypeError before the status check below can run — confirm.
    # Also, status is response.ok (a bool) on success, so `status >= 400`
    # only trips if a numeric error code is returned.
    (response, status) = make_api_request("GenerateText2Image", payload)
    if status >= 400:
        raise Exception(response)
    image_path = download_and_save_image(
        response["images"][0], "extensions/sd_api_pictures_swarmui/outputs/"
    )
    visible_result = ""
    if gen_params["save_img"]:
        # Link the full-size saved file into the chat.
        visible_result = (
            visible_result
            + f'<img src="/file/{image_path}" alt="{description}" style="max-width: unset; max-height: unset;">\n'
        )
    else:
        image_buffer = requests.get(f"{gen_params['address']}/{response['images'][0]}")
        image_buffer.raise_for_status()  # Raise an exception for bad status codes
        # Open the image using PIL
        image = Image.open(io.BytesIO(image_buffer.content))
        # lower the resolution of received images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
        image.thumbnail((300, 300))
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        buffered.seek(0)
        image_bytes = buffered.getvalue()
        img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
        visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
    return visible_result
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string, state):
    """
    This function is applied to the model outputs.

    When a picture response is pending, the cleaned model output is used
    as the Stable Diffusion prompt and the generated image HTML is
    prepended to the visible reply.
    """
    global picture_response, gen_params
    if not picture_response:
        return string
    if len(string) < 1 and shared.model is None:  # idiom fix: `is None`, not `== None`
        return "Model not loaded - Load a model first"
    # Strip emotes, quotes and newlines so only a plain description remains.
    string = remove_surrounded_chars(string)
    string = string.replace('"', "")
    string = string.replace("“", "")
    string = string.replace("\n", " ")
    string = string.strip()
    if string == "":
        # NOTE(review): picture_response intentionally stays True here so a
        # regenerate still produces a picture — confirm this is the intent.
        string = "no viable description in reply, try regenerating"
        return string
    text = ""
    if gen_params["mode"] < 2:
        # Manual/Interactive: one-shot — disarm after this picture.
        toggle_generation(False)
        text = f"*Sends a picture which portrays: “{string}”*"
    else:
        text = string
    string = get_SD_pictures(string) + "\n" + text
    return string
def bot_prefix_modifier(string):
    """Chat-mode hook for biasing the bot via its reply prefix.

    This extension does not alter the prefix; the input is returned
    unchanged.
    """
    return string
def toggle_generation(*args):
    """Flip (no args) or set (first positional arg) the pending-picture
    flag, and update the UI busy message to match."""
    global picture_response, shared
    picture_response = args[0] if args else not picture_response
    if picture_response:
        shared.processing_message = "*Is sending a picture...*"
    else:
        shared.processing_message = "*Is typing...*"
def filter_address(address):
    """Normalize a user-entered endpoint: trim whitespace, strip trailing
    slashes, and prepend http:// when no scheme is given."""
    address = address.strip()
    # Bug fix: the original pattern "\/$" removed only ONE trailing slash
    # (despite the "remove trailing /s" comment) and used an invalid
    # escape sequence; r"/+$" removes them all.
    address = re.sub(r"/+$", "", address)
    if not address.startswith("http"):
        address = "http://" + address
    return address
def SD_api_address_update(address):
    """Switch the extension to a new SwarmUI endpoint and repopulate the
    checkpoint dropdown from its model list."""
    global gen_params
    msg = "✔️ SD API is found on:"
    cleaned = filter_address(address)
    gen_params.update({"address": cleaned})
    try:
        (response, status) = make_api_request("ListModels", {"path": "", "depth": 2})
        if status >= 400:
            raise Exception(response)
        model_names = [model["name"] for model in response["files"]]
        update_gen_params({"checkpoint_list": model_names})
        return gr.update(
            choices=gen_params["checkpoint_list"],
            value=gen_params["checkpoint_list"][0],
        )
    except Exception as e:
        print(f"Error while connecting to SwarmUI - {e}")
        msg = "❌ No SD API endpoint on:"
        return gr.Textbox.update(label=msg)
def custom_css():
    """Return the extension's stylesheet (style.css next to this script)."""
    path_to_css = Path(__file__).parent.resolve() / "style.css"
    # read_text() closes its handle; the original open(...).read() leaked it.
    return path_to_css.read_text()
def load_checkpoint(checkpoint):
    """Ask SwarmUI to load *checkpoint*. Best-effort: failures are logged
    and swallowed so the UI stays responsive."""
    payload = {"model": checkpoint}
    try:
        make_api_request("SelectModel", payload)
    except Exception as e:
        # Narrowed from a silent bare `except: pass` — keep the best-effort
        # contract but surface the reason in the log.
        print(f"Warning: SelectModel request failed - {e}")
def get_samplers():
    """Fetch the sampler names SwarmUI exposes via ListT2IParams.

    Returns an empty list on any failure or when no "Sampler" parameter
    is present in the response.
    """
    samplers = []
    try:
        (response, status) = make_api_request("ListT2IParams", {})
        if status >= 400:
            raise Exception(response)
        for item in response["list"]:
            if item["name"] == "Sampler":
                samplers = item["values"]
                break
        if not samplers:
            samplers = []
    except Exception as e:
        print(e)
        samplers = []
    return samplers
def ui():
    """Build the Gradio settings panel and wire widget events to gen_params.

    NOTE(review): the pasted source lost its indentation; the nesting of
    the gr.Row/gr.Column/gr.Accordion context managers below is
    reconstructed from the widget order — confirm against the upstream
    sd_api_pictures extension layout.
    """
    global gen_params
    # Gradio elements
    with gr.Accordion("SwarmUI Parameters", open=True, elem_classes="SDAP"):
        with gr.Row():
            address = gr.Textbox(
                placeholder=lambda: gen_params["address"],
                value=lambda: gen_params["address"],
                label="SwarmUI address",
            )
            modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"]
            # type="index" makes the .select event deliver the integer mode.
            mode = gr.Dropdown(
                modes_list,
                value=lambda: modes_list[gen_params["mode"]],
                label="Mode of operation",
                type="index",
            )
            with gr.Column(scale=1, min_width=300):
                save_img = gr.Checkbox(
                    value=lambda: gen_params["save_img"],
                    label="Keep original images and use them in chat",
                )
                force_pic = gr.Button("Force the picture response")
                suppr_pic = gr.Button("Suppress the picture response")
        with gr.Row():
            checkpoint = gr.Dropdown(
                gen_params["checkpoint_list"],
                value=lambda: gen_params["sd_checkpoint"],
                label="Checkpoint",
                type="value",
            )
        with gr.Accordion("Generation parameters", open=False):
            prompt_prefix = gr.Textbox(
                placeholder=lambda: gen_params["prompt_prefix"],
                value=lambda: gen_params["prompt_prefix"],
                label="Prompt Prefix (best used to describe the look of the character)",
            )
            textgen_prefix = gr.Textbox(
                placeholder=lambda: gen_params["textgen_prefix"],
                value=lambda: gen_params["textgen_prefix"],
                label="textgen prefix (type [subject] where the subject should be placed)",
            )
            negative_prompt = gr.Textbox(
                placeholder=lambda: gen_params["negative_prompt"],
                value=lambda: gen_params["negative_prompt"],
                label="Negative Prompt",
            )
            with gr.Row():
                with gr.Column():
                    width = gr.Slider(
                        64, 2048, value=lambda: gen_params["width"], step=64, label="Width"
                    )
                    height = gr.Slider(
                        64, 2048, value=lambda: gen_params["height"], step=64, label="Height"
                    )
                with gr.Column(variant="compact", elem_id="sampler_col"):
                    with gr.Row(elem_id="sampler_row"):
                        sampler_name = gr.Dropdown(
                            value=lambda: gen_params["sampler_name"],
                            allow_custom_value=True,
                            label="Sampling method",
                            elem_id="sampler_box",
                        )
                        # Refresh button repopulates the sampler choices
                        # from the live SwarmUI server.
                        create_refresh_button(
                            sampler_name,
                            lambda: None,
                            lambda: {"choices": get_samplers()},
                            "refresh-button",
                        )
                    steps = gr.Slider(
                        1,
                        150,
                        value=lambda: gen_params["steps"],
                        step=1,
                        label="Sampling steps",
                        elem_id="steps_box",
                    )
            with gr.Row():
                seed = gr.Number(
                    label="Seed", value=lambda: gen_params["seed"], elem_id="seed_box"
                )
                cfg_scale = gr.Number(
                    label="CFG Scale", value=lambda: gen_params["cfg_scale"], elem_id="cfg_box"
                )
            denoising_strength = gr.Slider(
                0,
                1,
                value=lambda: gen_params["denoising_strength"],
                step=0.01,
                label="Denoising strength",
            )

    # Event functions to update the parameters in the backend
    address.change(
        lambda x: gen_params.update({"address": filter_address(x)}), address, None
    )
    mode.select(lambda x: gen_params.update({"mode": x}), mode, None)
    # Mode 2 (Picturebook) keeps picture generation permanently armed.
    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
    save_img.change(lambda x: gen_params.update({"save_img": x}), save_img, None)
    # Submitting the address also refreshes the checkpoint dropdown.
    address.submit(fn=SD_api_address_update, inputs=address, outputs=checkpoint)
    prompt_prefix.change(
        lambda x: gen_params.update({"prompt_prefix": x}), prompt_prefix, None
    )
    textgen_prefix.change(
        lambda x: gen_params.update({"textgen_prefix": x}), textgen_prefix, None
    )
    negative_prompt.change(
        lambda x: gen_params.update({"negative_prompt": x}), negative_prompt, None
    )
    width.change(lambda x: gen_params.update({"width": x}), width, None)
    height.change(lambda x: gen_params.update({"height": x}), height, None)
    denoising_strength.change(
        lambda x: gen_params.update({"denoising_strength": x}), denoising_strength, None
    )
    checkpoint.change(
        lambda x: gen_params.update({"sd_checkpoint": x}), checkpoint, None
    )
    # Changing the checkpoint also tells SwarmUI to load it.
    checkpoint.change(load_checkpoint, checkpoint, None)
    sampler_name.change(
        lambda x: gen_params.update({"sampler_name": x}), sampler_name, None
    )
    steps.change(lambda x: gen_params.update({"steps": x}), steps, None)
    seed.change(lambda x: gen_params.update({"seed": x}), seed, None)
    cfg_scale.change(lambda x: gen_params.update({"cfg_scale": x}), cfg_scale, None)
    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment