from clip_text_custom_embedder_xl import text_embeddings_sdxl
from diffusers import StableDiffusionXLPipeline
import torch
# Usage example: build weighted prompt embeddings with text_embeddings_sdxl and
# hand them to the SDXL pipeline in place of raw prompt strings.
# NOTE: `model_id` is not defined in this snippet; assign it before running.
# The pipeline is loaded with float16 weights and moved to the 'cuda' device.
pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to('cuda')
# Prompts use the attention syntax understood by parse_prompt_attention:
# "(text)" / "((text))" raise the weight, "(text:1.4)" sets an explicit multiplier.
prompt = "((masterpiece, best quality)), white background, close-up, 1girl, litte smile"
neg_prompt = "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), partial face, partial head, cropped head"
# Returns four tensors: per-token and pooled embeddings for the positive and the negative prompt.
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = text_embeddings_sdxl(pipe, prompt=prompt, neg_prompt=neg_prompt, clip_skip=2)
# Generate one image from the precomputed embeddings and keep the first result.
image = pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
# NOTE: `seed` is not defined in this snippet; assign an int before running.
generator=torch.manual_seed(seed)
).images[0]
Created
August 28, 2024 04:58
-
-
Save missionfloyd/72758aec9d714d59f2fddb4785db24ea to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import torch | |
from typing import Optional | |
from transformers import CLIPTokenizer | |
from diffusers import StableDiffusionXLPipeline | |
from diffusers.loaders import TextualInversionLoaderMixin | |
# Adapted from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion_xl.py

# Tokenizer for the attention syntax. Compiled in verbose mode, so whitespace and
# newlines inside the pattern are ignored and every line is one alternative.
# Alternatives are tried in order: escaped brackets/backslashes first, then the
# structural tokens, then "(:weight)" closers (group 1 captures the number),
# then runs of plain text, and finally a lone ':'.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.X,
)

# Matches the standalone keyword BREAK together with any surrounding whitespace.
re_break = re.compile(r"\s*\bBREAK\b\s*", re.S)


def parse_prompt_attention(text):
    r"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.

    Accepted tokens are:
      (abc) - increases attention to abc by a multiplier of 1.1
      (abc:3.12) - increases attention to abc by a multiplier of 3.12
      [abc] - decreases attention to abc by a multiplier of 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    The keyword BREAK inside plain text is emitted as its own ["BREAK", -1]
    entry. Brackets that are still open at the end of the string are treated
    as if they were closed there.

    Args:
        text (str)
            The prompt to parse.

    Returns:
        list
            A list of [text, weight] pairs; adjacent pairs with equal weight
            are merged. An input that yields no tokens returns [["", 1.0]].

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\\(literal\\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    res = []
    # Stacks of indices into `res` marking where each still-open bracket started.
    round_brackets = []
    square_brackets = []

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every entry emitted since the bracket at start_position was opened.
        for entry in res[start_position:]:
            entry[1] *= multiplier

    for m in re_attention.finditer(text):
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            # Escaped character: keep it as literal text without the backslash.
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and round_brackets:
            # "(...:1.3)" closes the innermost round bracket with an explicit weight.
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and round_brackets:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and square_brackets:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            # Plain text (or an unmatched closer): split on BREAK and emit a
            # marker entry between the pieces.
            for i, part in enumerate(re_break.split(token)):
                if i > 0:
                    res.append(["BREAK", -1])
                res.append([part, 1.0])

    # Brackets never closed behave as if closed at the end of the prompt.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)
    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if not res:
        return [["", 1.0]]

    # Merge runs of identical weights in a single pass.
    merged = [res[0]]
    for entry in res[1:]:
        if entry[1] == merged[-1][1]:
            merged[-1][0] += entry[0]
        else:
            merged.append(entry)
    return merged
def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
    """
    Get prompt token ids and weights, this function works for both prompt and negative prompt

    Args:
        clip_tokenizer (CLIPTokenizer)
            The tokenizer used to turn each parsed text segment into token ids
        prompt (str)
            A prompt string with weights

    Returns:
        text_tokens (list)
            A flat list of token ids for the whole prompt (no length limit)
        text_weights (list)
            A list of the same length holding the weight of each token id

    Example:
        from transformers import CLIPTokenizer

        clip_tokenizer = CLIPTokenizer.from_pretrained(
            "stablediffusionapi/deliberate-v2",
            subfolder="tokenizer",
        )

        token_id_list, token_weight_list = get_prompts_tokens_with_weights(
            clip_tokenizer=clip_tokenizer,
            prompt="a (red:1.5) cat" * 70,
        )
    """
    text_tokens, text_weights = [], []
    # NOTE(review): parse_prompt_attention emits ["BREAK", -1] marker segments;
    # they are not special-cased here and are tokenized like ordinary text with
    # weight -1. Confirm whether BREAK should force a chunk boundary instead.
    for word, weight in parse_prompt_attention(prompt):
        # truncation=False lets a segment of any length through; the first and
        # last ids (the start/end tokens added by the tokenizer) are dropped.
        token = clip_tokenizer(word, truncation=False).input_ids[1:-1]
        # Extend in place rather than rebuilding both lists for every segment.
        text_tokens.extend(token)
        # A segment carries one weight; repeat it for each token it produced.
        text_weights.extend([weight] * len(token))
    return text_tokens, text_weights
def group_tokens_and_weights(token_ids: list, weights: list, pad_last_block=False):
    """
    Produce tokens and weights in groups and pad the missing tokens

    The flat lists are cut into chunks of 75 entries. Every chunk is wrapped
    with a leading bos id and a trailing eos id, and the matching weight chunk
    is wrapped with 1.0 on both ends. The input lists are not modified.

    Args:
        token_ids (list)
            The token ids from tokenizer
        weights (list)
            The weights list from function get_prompts_tokens_with_weights
        pad_last_block (bool)
            Control if fill the last token list to 75 tokens with eos

    Returns:
        new_token_ids (2d list)
        new_weights (2d list)
            At least one chunk is always returned; empty input yields a single
            chunk holding only bos and eos (plus padding if pad_last_block).

    Example:
        token_groups, weight_groups = group_tokens_and_weights(
            token_ids=token_id_list,
            weights=token_weight_list,
        )
    """
    # Hard-coded ids, named bos/eos by the original code; presumably the CLIP
    # tokenizer's start/end-of-text ids -- confirm against the tokenizer in use.
    bos, eos = 49406, 49407
    chunk_size = 75

    # this will be a 2d list
    new_token_ids = []
    new_weights = []

    # Full chunks: slice instead of pop(0) so the caller's lists stay intact
    # and the work is linear in the prompt length.
    full_end = len(token_ids) - len(token_ids) % chunk_size
    for start in range(0, full_end, chunk_size):
        stop = start + chunk_size
        new_token_ids.append([bos] + list(token_ids[start:stop]) + [eos])
        new_weights.append([1.0] + list(weights[start:stop]) + [1.0])

    # Remainder. Also emitted when nothing was produced above, so an empty
    # prompt still yields one chunk for the text encoders to consume.
    tail_tokens = list(token_ids[full_end:])
    tail_weights = list(weights[full_end:])
    if tail_tokens or not new_token_ids:
        padding_len = chunk_size - len(tail_tokens) if pad_last_block else 0
        new_token_ids.append([bos] + tail_tokens + [eos] * padding_len + [eos])
        new_weights.append([1.0] + tail_weights + [1.0] * padding_len + [1.0])

    return new_token_ids, new_weights
def _pad_prompt_pair(tokens, weights, neg_tokens, neg_weights, pad_id):
    """Right-pad the shorter of a positive/negative token list (and its weights) to equal length."""
    gap = len(tokens) - len(neg_tokens)
    if gap > 0:
        # The negative prompt is shorter: pad it with pad_id at weight 1.0.
        neg_tokens = neg_tokens + [pad_id] * gap
        neg_weights = neg_weights + [1.0] * gap
    else:
        # The positive prompt is shorter (or equal, in which case nothing is added).
        tokens = tokens + [pad_id] * -gap
        weights = weights + [1.0] * -gap
    return tokens, weights, neg_tokens, neg_weights


def _encode_weighted_chunk(pipe, token_ids, token_ids_2, weights, hidden_state_index, device):
    """Encode one token chunk with both text encoders and apply the per-token weights.

    Returns the concatenated hidden states with a leading batch dimension of 1
    and the first output of the second text encoder (used as the pooled embedding).
    """
    token_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
    weight_tensor = torch.tensor(weights, dtype=torch.float16, device=device)
    token_tensor_2 = torch.tensor([token_ids_2], dtype=torch.long, device=device)

    encoded_1 = pipe.text_encoder(token_tensor, output_hidden_states=True)
    encoded_2 = pipe.text_encoder_2(token_tensor_2, output_hidden_states=True)
    pooled = encoded_2[0]

    # Join both encoders' hidden states along the feature axis and drop the batch axis.
    token_embedding = torch.concat(
        [encoded_1.hidden_states[hidden_state_index], encoded_2.hidden_states[hidden_state_index]],
        dim=-1,
    ).squeeze(0)

    # Move each weighted token towards or away from the embedding at the last
    # position, which serves as the reference point of the interpolation.
    for j in range(len(weight_tensor)):
        if weight_tensor[j] != 1.0:
            token_embedding[j] = (
                token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
            )

    return token_embedding.unsqueeze(0), pooled


def text_embeddings_sdxl(
    pipe: StableDiffusionXLPipeline,
    prompt: str = "",
    prompt_2: Optional[str] = None,
    neg_prompt: str = "",
    neg_prompt_2: Optional[str] = None,
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
    clip_skip: Optional[int] = None,
):
    """
    This function can process long prompt with weights, no length limitation
    for Stable Diffusion XL

    Both prompts are parsed for attention weights, tokenized with both of the
    pipeline's tokenizers, padded to a common length, split into chunks by
    group_tokens_and_weights, and encoded chunk by chunk. The chunk embeddings
    are concatenated along the sequence axis.

    Args:
        pipe (StableDiffusionXLPipeline)
            Supplies tokenizer, tokenizer_2, text_encoder and text_encoder_2
        prompt (str)
        prompt_2 (str)
            If given, appended to prompt (separated by a space); the combined
            text is fed to both encoders
        neg_prompt (str)
        neg_prompt_2 (str)
            If given, appended to neg_prompt in the same way
        num_images_per_prompt (int)
            How many times each returned embedding is repeated
        device (torch.device)
            Defaults to pipe._execution_device
        clip_skip (int)
            Extra layers to skip for the positive prompt: hidden state
            -(clip_skip + 2) is used instead of -2. The negative prompt
            always uses hidden state -2.

    Returns:
        prompt_embeds (torch.Tensor)
        neg_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)
            The pooled tensors come from the last encoded chunk.
    """
    device = device or pipe._execution_device

    if prompt_2:
        prompt = f"{prompt} {prompt_2}"
    if neg_prompt_2:
        neg_prompt = f"{neg_prompt} {neg_prompt_2}"

    prompt_t1 = prompt_t2 = prompt
    neg_prompt_t1 = neg_prompt_t2 = neg_prompt

    # Let the pipeline rewrite the prompts per tokenizer when it supports
    # textual inversion.
    if isinstance(pipe, TextualInversionLoaderMixin):
        prompt_t1 = pipe.maybe_convert_prompt(prompt_t1, pipe.tokenizer)
        neg_prompt_t1 = pipe.maybe_convert_prompt(neg_prompt_t1, pipe.tokenizer)
        prompt_t2 = pipe.maybe_convert_prompt(prompt_t2, pipe.tokenizer_2)
        neg_prompt_t2 = pipe.maybe_convert_prompt(neg_prompt_t2, pipe.tokenizer_2)

    eos = pipe.tokenizer.eos_token_id

    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, prompt_t1)
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(pipe.tokenizer, neg_prompt_t1)

    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, prompt_t2)
    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(pipe.tokenizer_2, neg_prompt_t2)

    # Pad the shorter of positive/negative so both produce the same chunks.
    prompt_tokens, prompt_weights, neg_prompt_tokens, neg_prompt_weights = _pad_prompt_pair(
        prompt_tokens, prompt_weights, neg_prompt_tokens, neg_prompt_weights, eos
    )
    # Same for token set 2. The padded weights are derived from prompt_weights_2
    # (the previous code mistakenly started from the set-1 weights here).
    prompt_tokens_2, prompt_weights_2, neg_prompt_tokens_2, neg_prompt_weights_2 = _pad_prompt_pair(
        prompt_tokens_2, prompt_weights_2, neg_prompt_tokens_2, neg_prompt_weights_2, eos
    )

    # Copies are passed so the grouping step can never alter the lists above.
    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy(), prompt_weights.copy()
    )
    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy(), neg_prompt_weights.copy()
    )
    # Only the token groups of set 2 are needed; weights come from set 1.
    prompt_token_groups_2, _ = group_tokens_and_weights(prompt_tokens_2.copy(), prompt_weights_2.copy())
    neg_prompt_token_groups_2, _ = group_tokens_and_weights(
        neg_prompt_tokens_2.copy(), neg_prompt_weights_2.copy()
    )

    # "2" because SDXL always indexes from the penultimate layer.
    hidden_state_index = -2 if clip_skip is None else -(clip_skip + 2)

    embeds = []
    neg_embeds = []

    # NOTE: if the grouping step yields no chunks (both prompts tokenize to
    # nothing), the torch.cat calls below fail on an empty list.
    for i, (chunk_tokens, chunk_weights) in enumerate(zip(prompt_token_groups, prompt_weight_groups)):
        # positive prompt embeddings with weights
        token_embedding, pooled_prompt_embeds = _encode_weighted_chunk(
            pipe, chunk_tokens, prompt_token_groups_2[i], chunk_weights, hidden_state_index, device
        )
        embeds.append(token_embedding)

        # negative prompt embeddings with weights; clip_skip is not applied here
        neg_token_embedding, negative_pooled_prompt_embeds = _encode_weighted_chunk(
            pipe,
            neg_prompt_token_groups[i],
            neg_prompt_token_groups_2[i],
            neg_prompt_weight_groups[i],
            -2,
            device,
        )
        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    seq_len = negative_prompt_embeds.shape[1]
    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
    negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, -1
    )
    negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1).view(
        bs_embed * num_images_per_prompt, -1
    )

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment