Created
September 16, 2024 16:00
-
-
Save aimerib/7b3c338a5d9496bf636a2c7b951a9f86 to your computer and use it in GitHub Desktop.
Update Oobabooga StableDiffusion extension to use SwarmUI instead
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Either change extensions/sd_api_pictures/script.py or copy the extensions/sd_api_pictures folder as second folder like sd_api_pictures_swarm and change extensions/sd_api_pictures_swarm/script.py | |
import base64 | |
import io | |
import json | |
import re | |
from datetime import date | |
from pathlib import Path | |
from PIL import Image | |
import gradio as gr | |
import requests | |
import torch | |
import os | |
import hashlib | |
from modules import shared | |
from modules.models import reload_model, unload_model | |
from modules.ui import create_refresh_button | |
# Disable TorchScript profiling mode via the private torch._C API (runs once at import).
torch._C._jit_set_profiling_mode(False)
# Parameters which can be customized in settings.json of the webui.
gen_params = {
    "address": "http://127.0.0.1:7801",
    # Modes of operation: 0 (Manual only), 1 (Immersive/Interactive - looks
    # for trigger words), 2 (Picturebook Adventure - always on).
    "mode": 0,
    "manage_VRAM": False,
    "save_img": False,  # keep full-size files instead of inline thumbnails
    "prompt_prefix": "(Masterpiece:1.1), detailed, intricate, colorful",
    "negative_prompt": "(worst quality, low quality:1.3)",
    "width": 512,
    "height": 512,
    "denoising_strength": 0.61,
    "seed": -1,
    "sampler_name": "dpmpp_2m_sde_gpu",
    "scheduler_name": "karras",
    "aspect_ratio": "1:1",
    "steps": 32,
    "cfg_scale": 7,
    "textgen_prefix": "Please provide a detailed and vivid description of [subject]",
    "sd_checkpoint": " ",
    "checkpoint_list": [" "],
    "session_id": "",  # SwarmUI session token, fetched lazily
}


def _is_compatible(current, value):
    """Return True when *value* may replace *current* without changing the
    parameter's effective type.

    bool is deliberately kept distinct from int (in Python bool subclasses
    int, so a plain isinstance check would let ints overwrite flags), while
    int and float are interchangeable for numeric parameters because UI
    widgets may emit either.
    """
    if isinstance(current, bool) or isinstance(value, bool):
        return isinstance(current, bool) and isinstance(value, bool)
    if isinstance(current, (int, float)) and isinstance(value, (int, float)):
        return True
    return isinstance(value, type(current))


def update_gen_params(new_params):
    """Merge *new_params* into the global gen_params dict and return it.

    Unknown keys are ignored with a warning; values of an incompatible type
    are rejected with a warning.  aspect_ratio is kept in sync whenever
    width or height is updated.

    Raises ValueError if *new_params* is not a dict.
    """
    global gen_params
    if not isinstance(new_params, dict):
        raise ValueError("Input must be a dictionary")
    for key, value in new_params.items():
        if key not in gen_params:
            print(f"Warning: '{key}' is not a valid parameter and will be ignored.")
            continue
        if _is_compatible(gen_params[key], value):
            gen_params[key] = value
        else:
            print(
                f"Warning: Type mismatch for key '{key}'. Expected {type(gen_params[key])}, got {type(value)}. Value not updated."
            )
    # Keep the derived aspect ratio consistent with the stored dimensions.
    if "width" in new_params or "height" in new_params:
        gen_params["aspect_ratio"] = f"{gen_params['width']}:{gen_params['height']}"
    return gen_params
# SwarmUI API requires a session_id for every request we send. Let's get one.
def get_session_id():
    """Request a fresh SwarmUI session token and store it in gen_params.

    Returns a gradio update for the address textbox label, indicating
    success or failure.
    """
    global gen_params
    print("Requesting new SwarmUI session token...")
    msg = "✔️ SwarmUI API Session Token updated"
    try:
        response = requests.post(
            url=f'{gen_params["address"]}/API/GetNewSession',
            data=json.dumps({}),
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        r = response.json()
        update_gen_params({"session_id": r["session_id"]})
    except Exception as e:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; surface the failure instead of hiding it.
        print(f"Error while requesting SwarmUI session token - {e}")
        msg = "❌ No SD API endpoint on:"
    # gr.update matches the rest of this file; gr.Textbox.update was removed in gradio 4.
    return gr.update(label=msg)
def make_api_request(endpoint, payload):
    """POST *payload* (JSON) to the SwarmUI API *endpoint*.

    Injects the cached session_id (fetching one first if missing) and
    returns a ``(json_body, status_code)`` tuple.  On any request failure
    an ``({"error": ...}, 599)`` tuple is returned instead of None, so
    callers that unpack ``(response, status)`` never crash and their
    ``status >= 400`` checks work.
    """
    print(f"Making SwarmUI API request to endpoint: {endpoint}")
    msg = f"✔️ SwarmUI API request to: {endpoint} - Success"
    if gen_params["session_id"] == "":
        get_session_id()
    try:
        payload["session_id"] = gen_params["session_id"]
        response = requests.post(
            url=f'{gen_params["address"]}/API/{endpoint}',
            data=json.dumps(payload),
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        r = response.json()
        print(msg)
        # Return the real HTTP status code instead of the boolean
        # `response.ok`: callers test `status >= 400`, which a bool
        # (True == 1) can never satisfy.
        return (r, response.status_code)
    except Exception as e:
        msg = f"❌ Request to endpoint: {endpoint} returned - {e}"
        print(msg)
        # Previously fell through returning None, which crashed callers
        # unpacking the tuple.  599 marks a network-level failure.
        return ({"error": str(e)}, 599)
def get_file_hash(file_path):
    """Return the hex MD5 digest of the file at *file_path*.

    Reads in 4 KiB chunks so arbitrarily large files never have to be
    held in memory at once.
    """
    digest = hashlib.md5()
    with open(file_path, "rb") as fh:
        while True:
            chunk = fh.read(4096)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def download_and_save_image(url, save_dir):
    """Download an image served by SwarmUI at relative *url* into *save_dir*.

    The local filename is the MD5 of the full URL.  If the file already
    exists and the server's ETag equals its MD5, the cached copy is
    reused.  Returns the local file path.
    """
    url = f"{gen_params['address']}/{url}"
    filename = os.path.join(save_dir, hashlib.md5(url.encode()).hexdigest() + ".png")
    output_file = Path(filename)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    # Reuse the cached file when the server says it hasn't changed.
    if os.path.exists(filename):
        etag = requests.head(url).headers.get("ETag")
        if etag:
            # ETags are sent quoted per RFC 9110 and may carry a weak
            # validator prefix; normalize before comparing, otherwise the
            # comparison against a bare MD5 hex digest can never match.
            # NOTE(review): this still assumes the server's ETag *is* the
            # MD5 of the body - confirm for SwarmUI.
            etag = etag.lstrip("W/").strip('"')
            if etag == get_file_hash(filename):
                return filename
    # Dropped stream=True: .content loads the whole body anyway, so the
    # flag only delayed the read without saving memory.
    response = requests.get(url)
    response.raise_for_status()
    # Round-trip through PIL so a corrupt download fails here, not later.
    image = Image.open(io.BytesIO(response.content))
    image.save(filename)
    return filename
# Set to True when the next model reply should be rendered as a picture.
picture_response = False
def remove_surrounded_chars(string):
    """Strip *action* spans from chat text.

    Removes the shortest run between two asterisks, or between an
    asterisk and the end of the string (an unterminated action).
    """
    # Raw string: the old "\*" literal relied on invalid escape sequences,
    # which newer Python versions flag as a SyntaxWarning.
    return re.sub(r"\*[^\*]*?(\*|$)", "", string)
def triggers_are_in(string):
    """Return True when *string* asks for a picture.

    After stripping *surrounded* spans, searches for send/mail/message/me
    (matched at the end of a word) followed later by a whole picture-word:
    image, pic(ture), photo, snap(shot), selfie or meme(s).  The (?aims)
    flags make the search ASCII, case-insensitive, multiline and dotall.
    """
    cleaned = remove_surrounded_chars(string)
    pattern = (
        "(?aims)(send|mail|message|me)\\b.+?"
        "\\b(image|pic(ture)?|photo|snap(shot)?|selfie|meme)s?\\b"
    )
    return re.search(pattern, cleaned) is not None
def state_modifier(state):
    """Disable response streaming while a picture reply is pending."""
    if not picture_response:
        return state
    state["stream"] = False
    return state
def input_modifier(string):
    """
    Applied to user text before it is fed to the model.

    In Immersive/Interactive mode (mode == 1) a trigger phrase arms
    picture generation and rewrites the input into a prompt asking the
    model to describe the requested subject; otherwise the text passes
    through unchanged.
    """
    global gen_params
    if gen_params["mode"] != 1:  # only immersive/interactive mode rewrites input
        return string
    if triggers_are_in(string):  # check for trigger words
        toggle_generation(True)
        string = string.lower()
        # Match 'of' as a whole word: the old substring test also hit 'of'
        # inside words like 'profile' or 'offer' and split in the wrong place.
        match = re.search(r"\bof\b", string)
        if match:
            # Everything after the first standalone 'of' is the subject.
            subject = string[match.end():]
            string = gen_params["textgen_prefix"].replace("[subject]", subject)
        else:
            string = gen_params["textgen_prefix"].replace(
                "[subject]",
                "your appearance, your surroundings and what you are doing right now",
            )
    return string
# Get and save the Stable Diffusion-generated picture.
def get_SD_pictures(description):
    """Generate a picture for *description* via the SwarmUI text2image API.

    Returns an HTML <img> snippet: a link to the full-size saved file when
    gen_params["save_img"] is set, otherwise an inline base64 JPEG
    thumbnail (downscaled so the chat log stays manageable).

    Raises Exception when the API call fails or returns no image.
    """
    global gen_params
    # Strip inline <audio> players (e.g. injected by TTS extensions) from the prompt.
    description = re.sub("<audio.*?</audio>", " ", description)
    description = f"({description}:1)"
    payload = {
        "prompt": gen_params["prompt_prefix"] + description,
        "negativeprompt": gen_params["negative_prompt"],
        "model": gen_params["sd_checkpoint"],
        "images": "1",
        "seed": gen_params["seed"],
        "steps": gen_params["steps"],
        "cfgscale": gen_params["cfg_scale"],
        "aspectratio": gen_params["aspect_ratio"],
        "width": gen_params["width"],
        "height": gen_params["height"],
        "sampler": gen_params["sampler_name"],
        "scheduler": gen_params["scheduler_name"],
        "initimagecreativity": gen_params["denoising_strength"],
        "initimageresettonorm": "0",
        "maskblur": "4",
        "initimagerecompositemask": True,
        "internalbackendtype": "Any",
        "modelspecificenhancements": True,
        "freeuapplyto": "Both",
        "freeuversion": "2",
        "freeublockone": "1.1",
        "freeublocktwo": "1.2",
        "freeuskipone": "0.9",
        "freeuskiptwo": "0.2",
        "automaticvae": True,
    }
    print(f'Prompting the image generator via the API on {gen_params["address"]}...')
    result = make_api_request("GenerateText2Image", payload)
    # make_api_request may return None when the HTTP request itself failed;
    # unpacking that directly used to raise an opaque TypeError.
    if result is None:
        raise Exception("SwarmUI API request failed (no response)")
    (response, status) = result
    if status >= 400:
        raise Exception(response)
    # An error payload can come back with HTTP 200; fail loudly instead of a KeyError.
    if "images" not in response or not response["images"]:
        raise Exception(response)
    image_path = download_and_save_image(
        response["images"][0], "extensions/sd_api_pictures_swarmui/outputs/"
    )
    visible_result = ""
    if gen_params["save_img"]:
        visible_result = (
            visible_result
            + f'<img src="/file/{image_path}" alt="{description}" style="max-width: unset; max-height: unset;">\n'
        )
    else:
        image_buffer = requests.get(f"{gen_params['address']}/{response['images'][0]}")
        image_buffer.raise_for_status()  # Raise an exception for bad status codes
        image = Image.open(io.BytesIO(image_buffer.content))
        # JPEG cannot store an alpha channel: RGBA/paletted PNGs crashed the
        # save() below, so normalize to RGB first.
        image = image.convert("RGB")
        # Downscale for the chat log, otherwise the log size gets out of
        # control quickly with all the base64 values in visible history.
        image.thumbnail((300, 300))
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        buffered.seek(0)
        image_bytes = buffered.getvalue()
        img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
        visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
    return visible_result
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string, state):
    """
    Applied to the model outputs.

    When picture mode is armed, cleans the reply into an SD prompt,
    generates the image, and returns HTML plus (in modes 0/1) a short
    caption; otherwise the text passes through unchanged.
    """
    global picture_response, gen_params
    if not picture_response:
        return string
    # NOTE(review): this only fires when the reply is empty AND no model is
    # loaded; the intent may have been `or` - confirm before changing.
    if len(string) < 1 and shared.model is None:
        return "Model not loaded - Load a model first"
    string = remove_surrounded_chars(string)
    string = string.replace('"', "")
    # Strip both smart quotes; the original missed the closing one.
    string = string.replace("“", "").replace("”", "")
    string = string.replace("\n", " ")
    string = string.strip()
    if string == "":
        string = "no viable description in reply, try regenerating"
        return string
    if gen_params["mode"] < 2:
        # Manual / immersive: one-shot - disarm and add a caption.
        toggle_generation(False)
        text = f"*Sends a picture which portrays: “{string}”*"
    else:
        # Picturebook mode stays armed; the reply text is the description itself.
        text = string
    return get_SD_pictures(string) + "\n" + text
def bot_prefix_modifier(string):
    """
    Chat-mode hook for biasing the bot's reply prefix.

    This extension does not alter the prefix; the text is returned as-is.
    """
    return string
def toggle_generation(*args):
    """Arm or disarm picture mode.

    With an argument, sets picture_response to args[0]; with no arguments,
    flips the current value.  Also updates the chat 'typing' indicator.
    """
    # `shared` is an imported module; the old `global shared` declaration
    # was a misleading no-op and has been dropped.
    global picture_response
    if args:
        picture_response = args[0]
    else:
        picture_response = not picture_response
    shared.processing_message = (
        "*Is sending a picture...*" if picture_response else "*Is typing...*"
    )
def filter_address(address):
    """Normalize a user-entered server address.

    Trims whitespace, strips trailing slashes, and prepends http:// when
    no scheme is given.
    """
    address = address.strip()
    # Raw string fixes the invalid '\/' escape; '/+' also handles
    # accidentally repeated trailing slashes.
    address = re.sub(r"/+$", "", address)
    if not address.startswith("http"):
        address = "http://" + address
    return address
def SD_api_address_update(address):
    """Point the extension at a new SwarmUI address and refresh the model list.

    Returns a gradio update for the checkpoint dropdown on success, or a
    label update flagging the failure.
    """
    global gen_params
    msg = "✔️ SD API is found on:"
    address = filter_address(address)
    gen_params.update({"address": address})
    try:
        result = make_api_request("ListModels", {"path": "", "depth": 2})
        # make_api_request may return None when the request itself failed;
        # unpacking that directly used to raise an opaque TypeError.
        if result is None:
            raise Exception("no response from SwarmUI")
        (response, status) = result
        if status >= 400:
            raise Exception(response)
        update_gen_params(
            {"checkpoint_list": [model["name"] for model in response["files"]]}
        )
        return gr.update(
            choices=gen_params["checkpoint_list"],
            value=gen_params["checkpoint_list"][0],
        )
    except Exception as e:
        print(f"Error while connecting to SwarmUI - {e}")
        msg = "❌ No SD API endpoint on:"
        # gr.update matches the success path; gr.Textbox.update is gone in gradio 4.
        return gr.update(label=msg)
def custom_css():
    """Return the extension's stylesheet (style.css next to this script)."""
    path_to_css = Path(__file__).parent.resolve() / "style.css"
    # read_text closes the file; the original leaked the open handle.
    return path_to_css.read_text()
def load_checkpoint(checkpoint):
    """Ask SwarmUI to load *checkpoint*; best-effort, never raises."""
    payload = {"model": checkpoint}
    try:
        make_api_request("SelectModel", payload)
    except Exception as e:
        # Deliberately best-effort, but surface the failure instead of a
        # bare `except: pass` that also swallowed KeyboardInterrupt.
        print(f"Error while selecting model - {e}")
def get_samplers():
    """Fetch the sampler names exposed by the SwarmUI API.

    Returns an empty list when the API is unreachable, reports an error
    status, or its parameter listing has no 'Sampler' entry.
    """
    samplers = []
    try:
        (response, status) = make_api_request("ListT2IParams", {})
        if status >= 400:
            raise Exception(response)
        for item in response["list"]:
            if item["name"] == "Sampler":
                samplers = item["values"]
                break
        if not samplers:
            samplers = []
    except Exception as e:
        print(e)
        samplers = []
    return samplers
def ui():
    """Build the extension's gradio panel and wire each widget's event
    into the module-level gen_params dict.

    Widget default values are passed as lambdas so gradio re-reads the
    current gen_params entry on every page render.
    NOTE(review): indentation was reconstructed from a whitespace-mangled
    paste; the nesting follows the upstream sd_api_pictures layout.
    """
    global gen_params
    # Gradio elements
    with gr.Accordion("SwarmUI Parameters", open=True, elem_classes="SDAP"):
        with gr.Row():
            address = gr.Textbox(
                placeholder=lambda: gen_params["address"],
                value=lambda: gen_params["address"],
                label="SwarmUI address",
            )
            # Index into this list is stored as gen_params["mode"] (type="index").
            modes_list = ["Manual", "Immersive/Interactive", "Picturebook/Adventure"]
            mode = gr.Dropdown(
                modes_list,
                value=lambda: modes_list[gen_params["mode"]],
                label="Mode of operation",
                type="index",
            )
            with gr.Column(scale=1, min_width=300):
                save_img = gr.Checkbox(
                    value=lambda: gen_params["save_img"],
                    label="Keep original images and use them in chat",
                )
                # Manual overrides for arming/disarming the next picture reply.
                force_pic = gr.Button("Force the picture response")
                suppr_pic = gr.Button("Suppress the picture response")
        with gr.Row():
            checkpoint = gr.Dropdown(
                gen_params["checkpoint_list"],
                value=lambda: gen_params["sd_checkpoint"],
                label="Checkpoint",
                type="value",
            )
        with gr.Accordion("Generation parameters", open=False):
            prompt_prefix = gr.Textbox(
                placeholder=lambda: gen_params["prompt_prefix"],
                value=lambda: gen_params["prompt_prefix"],
                label="Prompt Prefix (best used to describe the look of the character)",
            )
            textgen_prefix = gr.Textbox(
                placeholder=lambda: gen_params["textgen_prefix"],
                value=lambda: gen_params["textgen_prefix"],
                label="textgen prefix (type [subject] where the subject should be placed)",
            )
            negative_prompt = gr.Textbox(
                placeholder=lambda: gen_params["negative_prompt"],
                value=lambda: gen_params["negative_prompt"],
                label="Negative Prompt",
            )
            with gr.Row():
                with gr.Column():
                    width = gr.Slider(
                        64, 2048, value=lambda: gen_params["width"], step=64, label="Width"
                    )
                    height = gr.Slider(
                        64, 2048, value=lambda: gen_params["height"], step=64, label="Height"
                    )
                with gr.Column(variant="compact", elem_id="sampler_col"):
                    with gr.Row(elem_id="sampler_row"):
                        sampler_name = gr.Dropdown(
                            value=lambda: gen_params["sampler_name"],
                            allow_custom_value=True,
                            label="Sampling method",
                            elem_id="sampler_box",
                        )
                        # Refresh button re-queries the API for the sampler list.
                        create_refresh_button(
                            sampler_name,
                            lambda: None,
                            lambda: {"choices": get_samplers()},
                            "refresh-button",
                        )
                    steps = gr.Slider(
                        1,
                        150,
                        value=lambda: gen_params["steps"],
                        step=1,
                        label="Sampling steps",
                        elem_id="steps_box",
                    )
            with gr.Row():
                seed = gr.Number(
                    label="Seed", value=lambda: gen_params["seed"], elem_id="seed_box"
                )
                cfg_scale = gr.Number(
                    label="CFG Scale", value=lambda: gen_params["cfg_scale"], elem_id="cfg_box"
                )
            denoising_strength = gr.Slider(
                0,
                1,
                value=lambda: gen_params["denoising_strength"],
                step=0.01,
                label="Denoising strength",
            )
    # Event functions to update the parameters in the backend
    address.change(
        lambda x: gen_params.update({"address": filter_address(x)}), address, None
    )
    mode.select(lambda x: gen_params.update({"mode": x}), mode, None)
    # Picturebook mode (index > 1) keeps picture generation permanently armed.
    mode.select(lambda x: toggle_generation(x > 1), inputs=mode, outputs=None)
    save_img.change(lambda x: gen_params.update({"save_img": x}), save_img, None)
    # Submitting the address re-probes the API and refreshes the checkpoint dropdown.
    address.submit(fn=SD_api_address_update, inputs=address, outputs=checkpoint)
    prompt_prefix.change(
        lambda x: gen_params.update({"prompt_prefix": x}), prompt_prefix, None
    )
    textgen_prefix.change(
        lambda x: gen_params.update({"textgen_prefix": x}), textgen_prefix, None
    )
    negative_prompt.change(
        lambda x: gen_params.update({"negative_prompt": x}), negative_prompt, None
    )
    width.change(lambda x: gen_params.update({"width": x}), width, None)
    height.change(lambda x: gen_params.update({"height": x}), height, None)
    denoising_strength.change(
        lambda x: gen_params.update({"denoising_strength": x}), denoising_strength, None
    )
    checkpoint.change(
        lambda x: gen_params.update({"sd_checkpoint": x}), checkpoint, None
    )
    # Also asks SwarmUI to actually load the newly selected model.
    checkpoint.change(load_checkpoint, checkpoint, None)
    sampler_name.change(
        lambda x: gen_params.update({"sampler_name": x}), sampler_name, None
    )
    steps.change(lambda x: gen_params.update({"steps": x}), steps, None)
    seed.change(lambda x: gen_params.update({"seed": x}), seed, None)
    cfg_scale.change(lambda x: gen_params.update({"cfg_scale": x}), cfg_scale, None)
    force_pic.click(lambda x: toggle_generation(True), inputs=force_pic, outputs=None)
    suppr_pic.click(lambda x: toggle_generation(False), inputs=suppr_pic, outputs=None)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment