from clip_text_custom_embedder import text_embeddings
from diffusers import StableDiffusionPipeline
import torch
# Example usage: generate one image from A1111-style weighted prompts.
# `model_id` and `seed` were referenced but never defined in the original snippet,
# which made the script fail with NameError; set them to your own values.
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # any SD 1.x checkpoint id or local path
seed = 0

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')
prompt = "((masterpiece, best quality)), white background, close-up, 1girl, litte smile"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), partial face, "
                   "partial head, cropped head")
# clip_stop_at_last_layers=2 encodes from the second-to-last CLIP hidden layer.
cond, uncond = text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=2)
images = pipe(prompt_embeds=cond,
              negative_prompt_embeds=uncond,
              generator=torch.manual_seed(seed)).images[0]
-
-
Save takuma104/43552b8ec70b63323c57dc9c6fcb9b90 to your computer and use it in GitHub Desktop.
import torch | |
import math | |
import re | |
# copied and customized from automatic1111 sd_hijack.py & prompt_parser.py | |
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/sd_hijack.py#L113 | |
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/ec1924ee5789b72c31c65932b549c59ccae0cdd6/modules/prompt_parser.py#L259 | |
# Tokenizer for A1111 attention syntax (verbose regex; one alternative per line).
# Group 1 captures the numeric weight of a ":<weight>)" closer.
re_attention = re.compile(r"""
\\\(|
\\\{|
\\\)|
\\\}|
\\\[|
\\]|
\\\\|
\\|
\(|
\{|
\[|
:([+-]?[.\d]+)\)|
\)|
\}|
]|
[^\\()\\{}\[\]:]+|
:
""", re.X)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      {abc} - same as (abc)
      (abc:3.12) - multiplies attention to abc by 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Unclosed brackets are applied to everything after them. Adjacent chunks
    with identical weights are merged; an empty prompt yields [['', 1.0]].

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention(r'\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    round_brackets = []   # indices into res where an open "(" / "{" started
    square_brackets = []  # indices into res where an open "[" started
    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every chunk emitted since the matching opening bracket.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)
        if token.startswith('\\'):
            res.append([token[1:], 1.0])
        elif token in ('(', '{'):
            round_brackets.append(len(res))
        elif token == '[':
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            multiply_range(round_brackets.pop(), float(weight))
        elif token in (')', '}') and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == ']' and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text, or a closer/":weight)" with no matching opener.
            res.append([token, 1.0])

    # Brackets left open apply to the rest of the prompt.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        return [["", 1.0]]

    # Merge runs of identical weights in a single linear pass.
    merged = [res[0]]
    for chunk in res[1:]:
        if chunk[1] == merged[-1][1]:
            merged[-1][0] += chunk[0]
        else:
            merged.append(chunk)
    return merged
class CLIPTextCustomEmbedder(object):
    """Encode prompts that use A1111 attention syntax into CLIP text embeddings.

    Each prompt is split into (text, weight) pairs by ``parse_prompt_attention``,
    tokenized without special tokens and padded with EOS to a multiple of 75
    tokens. The token stream is encoded 75 tokens at a time, each chunk wrapped
    in BOS/EOS. For every chunk the per-token weights are multiplied into the
    hidden states and the chunk's mean is rescaled back to its unweighted value.
    Chunk embeddings are concatenated along dim -2.
    """

    # Prompt tokens per encoder call; BOS + EOS are added around each chunk.
    _chunk_length = 75

    def __init__(self, tokenizer, text_encoder, device,
                 clip_stop_at_last_layers=1):
        """
        Args:
            tokenizer: tokenizer exposing ``bos_token_id`` / ``eos_token_id``
                whose call result supports ``["input_ids"]`` (e.g. ``pipe.tokenizer``).
            text_encoder: CLIP text encoder (e.g. ``pipe.text_encoder``).
            device: device the token ids and weights are moved to.
            clip_stop_at_last_layers: 1 uses ``last_hidden_state``; N > 1 uses
                ``hidden_states[-N]`` passed through ``final_layer_norm``.
        """
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.token_mults = {}  # not read anywhere in this class
        self.device = device
        self.clip_stop_at_last_layers = clip_stop_at_last_layers

    def tokenize_line(self, line):
        """Tokenize one prompt and compute a weight for every token.

        Returns:
            ``(tokens, fixes, multipliers, token_count)``. ``tokens`` and
            ``multipliers`` are padded with EOS / 1.0 up to a multiple of 75
            (at least 75). ``fixes`` is always an empty list. ``token_count``
            is the unpadded token count.
        """
        chunk = self._chunk_length
        parsed = parse_prompt_attention(line)
        tokenized = self.tokenizer(
            [text for text, _ in parsed], truncation=False,
            add_special_tokens=False)["input_ids"]

        remade_tokens = []
        multipliers = []
        for tokens, (_, weight) in zip(tokenized, parsed):
            # Every token of a parsed chunk inherits that chunk's weight.
            remade_tokens.extend(tokens)
            multipliers.extend([weight] * len(tokens))

        token_count = len(remade_tokens)
        target_length = math.ceil(max(token_count, 1) / chunk) * chunk
        padding = target_length - token_count
        remade_tokens += [self.tokenizer.eos_token_id] * padding
        multipliers += [1.0] * padding
        return remade_tokens, [], multipliers, token_count

    def process_text(self, texts):
        """Tokenize one prompt or a list of prompts.

        Returns:
            ``(batch_multipliers, remade_batch_tokens)``: one weight list and one
            token-id list per prompt, as produced by ``tokenize_line``. Repeated
            prompts in the batch are tokenized once.
        """
        if isinstance(texts, str):
            texts = [texts]
        cache = {}
        remade_batch_tokens = []
        batch_multipliers = []
        for line in texts:
            if line not in cache:
                remade_tokens, fixes, multipliers, _ = self.tokenize_line(line)
                cache[line] = (remade_tokens, fixes, multipliers)
            remade_tokens, _, multipliers = cache[line]
            remade_batch_tokens.append(remade_tokens)
            batch_multipliers.append(multipliers)
        return batch_multipliers, remade_batch_tokens

    def __call__(self, text):
        """Return weighted embeddings for ``text`` (a prompt or list of prompts).

        Every 75-token chunk contributes 77 positions along dim -2. Prompts
        that run out of tokens earlier than others in the batch are filled
        with EOS chunks (weight 1.0).
        """
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        chunk = self._chunk_length
        eos = self.tokenizer.eos_token_id
        z = None
        while max(map(len, remade_batch_tokens)) != 0:
            tokens = []
            multipliers = []
            for row_tokens, row_mults in zip(remade_batch_tokens, batch_multipliers):
                if row_tokens:
                    tokens.append(row_tokens[:chunk])
                    multipliers.append(row_mults[:chunk])
                else:
                    tokens.append([eos] * chunk)
                    multipliers.append([1.0] * chunk)
            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), dim=-2)
            remade_batch_tokens = [x[chunk:] for x in remade_batch_tokens]
            batch_multipliers = [x[chunk:] for x in batch_multipliers]
        return z

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """Encode one chunk per prompt and apply the per-token weights.

        Each row is truncated to 75 tokens and wrapped in BOS/EOS. BOS and EOS
        get weight 1.0.
        """
        chunk = self._chunk_length
        bos = self.tokenizer.bos_token_id
        eos = self.tokenizer.eos_token_id
        remade_batch_tokens = [[bos] + x[:chunk] + [eos] for x in remade_batch_tokens]
        batch_multipliers = [[1.0] + x[:chunk] + [1.0] for x in batch_multipliers]
        tokens = torch.asarray(remade_batch_tokens).to(self.device)
        outputs = self.text_encoder(
            input_ids=tokens, output_hidden_states=True)
        if self.clip_stop_at_last_layers > 1:
            z = self.text_encoder.text_model.final_layer_norm(
                outputs.hidden_states[-self.clip_stop_at_last_layers])
        else:
            z = outputs.last_hidden_state

        # Pad weight rows to the actual token-row width. The original padded to a
        # fixed 75, which never matched the 77-wide rows (BOS + 75 + EOS) and broke
        # for shorter rows passed in directly.
        width = tokens.shape[1]
        batch_multipliers = torch.asarray(
            [x + [1.0] * (width - len(x)) for x in batch_multipliers]).to(self.device)

        # restoring original mean is likely not correct, but it seems to work well
        # to prevent artifacts that happen otherwise
        original_mean = z.mean()
        z *= batch_multipliers.unsqueeze(-1).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        return z

    def get_text_tokens(self, text):
        """Return ``([[bos] + tokens], [[1.0] + multipliers])`` for the first prompt.

        ``tokens`` / ``multipliers`` are the padded output of ``tokenize_line``;
        no extra EOS is appended here.
        """
        batch_multipliers, remade_batch_tokens = self.process_text(text)
        return [[self.tokenizer.bos_token_id] + remade_batch_tokens[0]], \
            [[1.0] + batch_multipliers[0]]
def text_embeddings_equal_len(text_embedder, prompt, negative_prompt):
    """Embed the prompt and the negative prompt, padding the shorter to match.

    Args:
        text_embedder: callable mapping a prompt to a ``[batch, seq, ...]`` tensor
            (e.g. a ``CLIPTextCustomEmbedder``).
        prompt: positive prompt(s).
        negative_prompt: negative prompt(s).

    Returns:
        ``(cond_embeddings, uncond_embeddings)`` with equal length along dim 1.
        The shorter one is extended with copies of ``text_embedder("")``,
        broadcast to its batch size.
    """
    cond_embeddings = text_embedder(prompt)
    uncond_embeddings = text_embedder(negative_prompt)
    cond_len = cond_embeddings.shape[1]
    uncond_len = uncond_embeddings.shape[1]
    if cond_len == uncond_len:
        return cond_embeddings, uncond_embeddings

    def pad_to(embeddings, target_len):
        # Append enough empty-prompt chunks to reach target_len. The chunk length
        # is taken from the empty embedding itself instead of a hard-coded 77.
        empty = text_embedder("")
        n = (target_len - embeddings.shape[1]) // empty.shape[1]
        # The empty prompt is encoded with batch size 1; broadcast it so torch.cat
        # also works for batched prompts.
        empty = empty.expand(embeddings.shape[0], *empty.shape[1:])
        return torch.cat([embeddings] + [empty] * n, dim=1)

    if cond_len > uncond_len:
        return cond_embeddings, pad_to(uncond_embeddings, cond_len)
    return pad_to(cond_embeddings, uncond_len), uncond_embeddings
def text_embeddings(pipe, prompt, negative_prompt, clip_stop_at_last_layers=1):
    """Build weighted prompt / negative-prompt embeddings for a diffusers pipeline.

    Args:
        pipe: pipeline providing ``tokenizer`` and ``text_encoder``.
        prompt: positive prompt (A1111 attention syntax allowed).
        negative_prompt: negative prompt (same syntax).
        clip_stop_at_last_layers: forwarded to ``CLIPTextCustomEmbedder``; 2
            encodes from the second-to-last CLIP hidden state.

    Returns:
        ``(cond_embeddings, uncond_embeddings)`` with equal sequence length,
        suitable for ``prompt_embeds`` / ``negative_prompt_embeds``.
    """
    encoder = pipe.text_encoder
    embedder = CLIPTextCustomEmbedder(
        tokenizer=pipe.tokenizer,
        text_encoder=encoder,
        device=encoder.device,
        clip_stop_at_last_layers=clip_stop_at_last_layers,
    )
    return text_embeddings_equal_len(embedder, prompt, negative_prompt)
@alexblattner Sorry for the late reply. A1111-style syntax such as `<lora:xxx:1>` is not supported by this parser. As in the A1111 implementation, you need to preprocess the prompt and extract the `<lora:xxx:1>` parts before passing the rest to the parser. You can implement it like this:
def process_prompt(prompt):
    """Strip A1111 ``<lora:name:weight>`` tags out of a prompt.

    Returns:
        ``(prompt_without_tags, [(name, weight), ...])`` with weights as floats,
        in the order the tags appear.
    """
    lora_tag = re.compile(r'<lora:([^:]+):([\d\.\-]+)>')
    loras = [(name, float(weight)) for name, weight in lora_tag.findall(prompt)]
    return lora_tag.sub('', prompt), loras
As for applying LoRA in Diffusers: Diffusers currently does not support the LoRA format commonly used in A1111 (weights trained with kohya-ss/sd-scripts or its derivatives). The closest support is in huggingface/diffusers#3294, which I think should work with these files too.
Personally, I think #3294 might be problematic because it modifies the model's original weights (e.g. the UNet), which may be difficult to revert once applied. I've been experimenting with using hooks instead, so that multiple LoRAs can be swapped freely. If you're interested, feel free to use it.
https://gist.github.com/takuma104/e38d683d72b1e448b8d9b3835f7cfa44
@takuma104 thanks for the answer.
How come diffusers doesn't support Loras when it's such a common format? Also, wasn't the point of Loras that they modify the model for better results?
Also, your code seems very useful to me. I am trying to apply and remove Loras during the denoising process in order to have Loras applied to some parts of the image and not others. Would you mind giving an example of applying and removing Loras a bunch? Is it possible to use your code from inside the denoising loop?
Thanks for everything btw
It seems that the training script (Kohya's implementation) has developed faster, and the LoRA files created with it have rapidly spread on CivitAI, making it difficult for Diffusers' LoRA implementation to catch up. The script I wrote uses hooks, so it's unlikely to be merged into Diffusers. I'm considering creating a non-hook version as well.
Since most LoRA files on CivitAI are applied to both the TextEncoder and the UNet, this might be quite complicated, and you might need to redo the TextEncoder processing inside the denoising loop. If you want to keep it simple, you could write code that only expects changes to the UNet, since at least those should be effective. For example, it should be possible to change the LoRA state in the denoising loop of `StableDiffusionPipeline`, just before the UNet inference. You should be able to use `apply_lora` / `remove_lora` without any restrictions, but since they are not very fast, I recommend dynamically adjusting the `alpha` values instead.
In terms of code, you can execute the following before the denoising loop:
lora1 = self.apply_lora('lora1.safetensors', 0.0)
lora2 = self.apply_lora('lora2.safetensors', 0.0)
Then, inside the denoising loop, adjust the weights dynamically with assignments such as `lora1.alpha = 1.0`.
@takuma104 thank you very much for your answer. I didn't know that the LoRA implementations in Diffusers and A1111 were that different.
I love the fact that diffusers is so organized and doesn't require me to install a shitty webui so I have stuck to it. For my project I do need to use the latest tech though. Hopefully diffusers catches up soon. I am also creating a pipeline that should make latent couple (two shots) usable with multicontrolnet. I made latent couples work, I am applying controlnet and dynamic Loras on top of it now.
Does the code you gave:
lora1 = self.apply_lora('lora1.safetensors', 0.0)
lora2 = self.apply_lora('lora2.safetensors', 0.0)
Work with your code? If yes, then you saved me a fuckton of time and thanks a lot for it. If not, I'll try to do something about it assuming I manage to make my pipeline work with multicontrolnet.
@alexblattner Indeed, the Diffusers code is well organized. This code should also work in `StableDiffusionControlNetPipeline`, since LoRAs are only applied to `pipe.unet` and `pipe.text_encoder`. I hope your latent couple project goes smoothly!
thank you @takuma104 !
I fixed an issue with the torch dtype:
import math
import safetensors
import torch
from diffusers import DiffusionPipeline
# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17
class LoRAModule(torch.nn.Module):
    """Low-rank adapter branch for one Linear or Conv2d layer.

    ``forward`` returns only the LoRA delta
    ``multiplier * (alpha / lora_dim) * lora_up(lora_down(x))``; the wrapped
    layer's own output is not included.
    """

    def __init__(
        self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_dim = lora_dim
        # isinstance instead of a class-name comparison so Conv2d subclasses are
        # handled as convolutions; the name check sent them to the Linear branch,
        # where `in_features` does not exist.
        if isinstance(org_module, torch.nn.Conv2d):
            self.lora_down = torch.nn.Conv2d(
                org_module.in_channels,
                self.lora_dim,
                org_module.kernel_size,
                org_module.stride,
                org_module.padding,
                bias=False,
            )
            self.lora_up = torch.nn.Conv2d(
                self.lora_dim, org_module.out_channels, (1, 1), (1, 1), bias=False
            )
        else:
            self.lora_down = torch.nn.Linear(
                org_module.in_features, self.lora_dim, bias=False
            )
            self.lora_up = torch.nn.Linear(
                self.lora_dim, org_module.out_features, bias=False
            )

        if alpha is None or alpha == 0:
            self.alpha = self.lora_dim
        else:
            if isinstance(alpha, torch.Tensor):
                # .float(): without casting, bf16 causes error.
                # .cpu(): .numpy() raises on tensors that live on a GPU.
                alpha = alpha.detach().float().cpu().numpy()
            self.register_buffer("alpha", torch.tensor(alpha))  # Treatable as a constant.

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier

    def forward(self, x):
        scale = self.alpha / self.lora_dim
        down = self.lora_down(x)
        up = self.lora_up(down)
        return self.multiplier * scale * up
class LoRAModuleContainer(torch.nn.Module):
    """Own all LoRAModules built from one LoRA state dict and attach them to hooks.

    Every ``<name>.lora_down...`` key creates a LoRAModule registered under
    ``<name>``. It uses ``<name>.alpha`` when present and the wrapped layer
    ``hooks[<name>].orig_module``. The whole state dict is then loaded, and
    each module is appended to ``hooks[<name>]``.
    """

    def __init__(self, hooks, state_dict, multiplier):
        super().__init__()
        self.multiplier = multiplier

        # Build one LoRAModule per lora_down entry.
        for key, weight in state_dict.items():
            if "lora_down" not in key:
                continue
            lora_name = key.split(".")[0]
            alpha_key = f"{lora_name}.alpha"
            alpha = state_dict[alpha_key].item() if alpha_key in state_dict else None
            target = hooks[lora_name].orig_module
            self.register_module(
                lora_name,
                LoRAModule(
                    target,
                    lora_dim=weight.size()[0],
                    alpha=alpha,
                    multiplier=multiplier,
                ),
            )

        # Load the complete LoRA weights into the freshly built modules.
        self.load_state_dict(state_dict)

        # Hand each module to the hook wrapping the matching layer.
        for name, module in self._lora_children():
            hooks[name].append_lora(module)

    def _lora_children(self):
        """Yield ``(name, module)`` for every LoRAModule inside this container."""
        return (
            (name, module)
            for name, module in self.named_modules()
            if module.__class__.__name__ == "LoRAModule"
        )

    @property
    def alpha(self):
        """Current multiplier shared by all contained LoRA modules."""
        return self.multiplier

    @alpha.setter
    def alpha(self, multiplier):
        self.multiplier = multiplier
        for _, module in self._lora_children():
            module.multiplier = multiplier

    def remove_from_hooks(self, hooks):
        """Detach every contained LoRAModule from its hook."""
        for name, module in self._lora_children():
            hooks[name].remove_lora(module)
class LoRAHook(torch.nn.Module):
    """
    replaces forward method of the original Linear,
    instead of replacing the original Linear module.

    ``install`` swaps ``orig_module.forward`` for this hook's ``forward``. That
    forward returns the original output plus the sum of all appended LoRA
    modules' outputs on the same input. ``uninstall`` restores the original
    forward.
    """

    def __init__(self):
        super().__init__()
        # Plain list: LoRA modules are owned by their container, not registered here.
        self.lora_modules = []

    def install(self, orig_module):
        assert not hasattr(self, "orig_module")
        self.orig_module = orig_module
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def uninstall(self):
        assert hasattr(self, "orig_module")
        self.orig_module.forward = self.orig_forward
        del self.orig_forward
        del self.orig_module

    def append_lora(self, lora_module):
        self.lora_modules.append(lora_module)

    def remove_lora(self, lora_module):
        self.lora_modules.remove(lora_module)

    def forward(self, x, *args, **kwargs):
        """Return ``orig_forward(x, *args, **kwargs)`` plus the summed LoRA deltas.

        Extra positional/keyword arguments are passed through to the original
        forward (previously they raised TypeError). LoRA modules receive only ``x``.
        """
        if not self.lora_modules:
            return self.orig_forward(x, *args, **kwargs)
        lora = torch.sum(torch.stack([m(x) for m in self.lora_modules]), dim=0)
        return self.orig_forward(x, *args, **kwargs) + lora
class LoRAHookInjector(object):
    """Install LoRAHooks on a pipeline's text encoder and UNet and load LoRA files onto them."""

    def __init__(self):
        super().__init__()
        self.hooks = {}     # kohya-style LoRA key prefix -> LoRAHook
        self.device = None  # set from pipe.device in install_hooks
        self.dtype = None   # set from pipe.unet.dtype in install_hooks

    def _get_target_modules(self, root_module, prefix, target_replace_modules):
        """Collect (lora_name, layer) for every Linear/Conv2d inside target modules.

        ``lora_name`` is ``prefix.<module path>.<child path>`` with dots replaced
        by underscores. Modules whose path contains "transformer_blocks" are
        skipped. Layers are matched by exact class name, so subclasses of
        Linear/Conv2d are not collected.
        """
        target_modules = []
        for name, module in root_module.named_modules():
            if (module.__class__.__name__ not in target_replace_modules
                    or "transformer_blocks" in name):  # to adapt latest diffusers
                continue
            for child_name, child_module in module.named_modules():
                if child_module.__class__.__name__ in ("Linear", "Conv2d"):
                    lora_name = f"{prefix}.{name}.{child_name}".replace(".", "_")
                    target_modules.append((lora_name, child_module))
        return target_modules

    def install_hooks(self, pipe):
        """Install LoRAHook to the pipe."""
        assert len(self.hooks) == 0
        text_encoder_targets = self._get_target_modules(
            pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]
        )
        unet_targets = self._get_target_modules(
            pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]
        )
        for name, target_module in text_encoder_targets + unet_targets:
            hook = LoRAHook()
            hook.install(target_module)
            self.hooks[name] = hook

        self.device = pipe.device
        self.dtype = pipe.unet.dtype

    def uninstall_hooks(self):
        """Uninstall LoRAHook from the pipe."""
        for hook in self.hooks.values():
            hook.uninstall()
        self.hooks = {}

    def apply_lora(self, filename, alpha=1.0, dtype=None):
        """Load LoRA weights from a .safetensors file and apply LoRA to the pipe.

        Args:
            filename: path to the LoRA .safetensors file.
            alpha: initial multiplier; change later via ``container.alpha``.
            dtype: dtype for the LoRA layers. Defaults to the UNet dtype recorded
                by ``install_hooks``, so the LoRA layers match the hooked layers.
                The previous fixed float32 default caused dtype mismatches on
                float16 pipelines, and it also overwrote the recorded dtype.

        Returns:
            The LoRAModuleContainer; pass it to ``remove_lora`` to detach it.
        """
        # `import safetensors` alone does not import the `safetensors.torch` submodule.
        import safetensors.torch

        assert len(self.hooks) != 0
        target_dtype = self.dtype if dtype is None else dtype
        state_dict = safetensors.torch.load_file(filename)
        container = LoRAModuleContainer(self.hooks, state_dict, alpha)
        container.to(self.device, target_dtype)
        return container

    def remove_lora(self, container):
        """Remove the individual LoRA from the pipe."""
        container.remove_from_hooks(self.hooks)
def install_lora_hook(pipe: DiffusionPipeline):
    """Hook every LoRA-target layer of ``pipe`` and expose LoRA helpers on it.

    After this call ``pipe.lora_injector`` holds the injector, and
    ``pipe.apply_lora`` / ``pipe.remove_lora`` are its bound methods. Asserts
    that none of these attributes already exist.
    """
    for attr in ("lora_injector", "apply_lora", "remove_lora"):
        assert not hasattr(pipe, attr)

    injector = LoRAHookInjector()
    injector.install_hooks(pipe)

    pipe.lora_injector = injector
    pipe.apply_lora = injector.apply_lora
    pipe.remove_lora = injector.remove_lora
def uninstall_lora_hook(pipe: DiffusionPipeline):
    """Undo ``install_lora_hook``: restore hooked layers and drop the helper attributes."""
    pipe.lora_injector.uninstall_hooks()
    for attr in ("lora_injector", "apply_lora", "remove_lora"):
        delattr(pipe, attr)
@takuma104 can you update clip_text_custom_embedder for SDXL?
@takuma104 can you update clip_text_custom_embedder for SDXL please?
@adhikjoshi @zoezhu Here's an SDXL version.
https://gist.github.com/missionfloyd/72758aec9d714d59f2fddb4785db24ea
Would this also work with LoRAs? I am on a quest to make LoRAs work with paint-with-words and ControlNet for more accurate generations. If I can manage to get LoRAs into the embedding, I could modify this for LoRAs.
I know it doesn't make much sense, considering that LoRAs aren't textual inversions (TIs), but I am looking for any advice at all. I would appreciate your input, considering you know more than 99% of people on this topic.