import glob
import logging
import os
import re
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List, Union

import numpy as np
import torch
from controlnet_aux import LineartAnimeDetector
from controlnet_aux.processor import MODELS
from controlnet_aux.processor import Processor as ControlnetPreProcessor
from controlnet_aux.util import HWC3, ade_palette
from controlnet_aux.util import resize_image as aux_resize_image
from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetImg2ImgPipeline,
                       StableDiffusionPipeline)
from PIL import Image
from tqdm.rich import tqdm
from transformers import (AutoImageProcessor, CLIPImageProcessor,
                          CLIPTextModel, CLIPTokenizer,
                          UperNetForSemanticSegmentation)

from animatediff import get_dir
from animatediff.dwpose import DWposeDetector
from animatediff.models.clip import CLIPSkipTextModel
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines import AnimationPipeline, load_text_embeddings
from animatediff.pipelines.pipeline_controlnet_img2img_reference import \
    StableDiffusionControlNetImg2ImgReferencePipeline
from animatediff.schedulers import get_scheduler
from animatediff.settings import InferenceConfig, ModelConfig
from animatediff.utils.convert_from_ckpt import convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatediff.utils.model import (ensure_motion_modules,
                                     get_checkpoint_weights)
from animatediff.utils.util import (get_resized_image, get_resized_image2,
                                    get_resized_images,
                                    get_tensor_interpolation_method,
                                    prepare_dwpose, prepare_ip_adapter,
                                    prepare_motion_module, save_frames,
                                    save_imgs, save_video)

try:
    import onnxruntime
    onnxruntime_installed = True
except ImportError:  # narrowed from a bare except: only a missing package is expected here
    onnxruntime_installed = False

logger = logging.getLogger(__name__)

data_dir = get_dir("data")
default_base_path = data_dir.joinpath("models/huggingface/stable-diffusion-v1-5")

re_clean_prompt = re.compile(r"[^\w\-, ]")

# cache of instantiated preprocessors, keyed by controlnet type string
controlnet_preprocessor = {}

def load_safetensors_lora(text_encoder, unet, lora_path, alpha=0.75, is_animatediff=True):
    from safetensors.torch import load_file

    from animatediff.utils.lora_diffusers import (LoRANetwork,
                                                  create_network_from_weights)

    sd = load_file(lora_path)

    logger.info("create LoRA network")
    lora_network: LoRANetwork = create_network_from_weights(text_encoder, unet, sd, multiplier=alpha, is_animatediff=is_animatediff)
    logger.info("load LoRA network weights")
    lora_network.load_state_dict(sd, False)
    lora_network.merge_to(alpha)
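
# Example (added for illustration; the path below is hypothetical): merging a
# style LoRA into an already-constructed pipeline at 75% strength:
#
#   load_safetensors_lora(pipe.text_encoder, pipe.unet,
#                         Path("data/lora/flat_color.safetensors"), alpha=0.75)
#
# merge_to() folds the LoRA deltas directly into the module weights, so the
# merge cannot be undone without reloading the base weights.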

def load_tensors(path: Path, framework="pt", device="cpu"):
    tensors = {}
    if path.suffix == ".safetensors":
        from safetensors import safe_open
        with safe_open(path, framework=framework, device=device) as f:
            for k in f.keys():
                tensors[k] = f.get_tensor(k)  # loads the full tensor for a given key
    else:
        from torch import load
        tensors = load(path, device)
        if "state_dict" in tensors:
            tensors = tensors["state_dict"]
    return tensors

def load_motion_lora(unet, lora_path: Path, alpha=1.0):
    state_dict = load_tensors(lora_path)

    # directly update weight in diffusers model
    for key in state_dict:
        # only process lora down key
        if "up." in key:
            continue

        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = unet
        try:
            while len(layer_infos) > 0:
                temp_name = layer_infos.pop(0)
                curr_layer = curr_layer.__getattr__(temp_name)
        except AttributeError:  # narrowed from a bare except: only a missing attribute is expected
            logger.info(f"{model_key} not found")
            continue

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
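
# Illustration (added; key names are representative, not from the original):
# a motion-LoRA key such as
#   "...attn1.processor.to_q_lora.down.weight"
# is rewritten by the replace chain above to the model key
#   "...attn1.to_q.weight",
# and the in-place update is the standard low-rank LoRA delta
#   W += alpha * (up @ down).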

class SegPreProcessor:
    def __init__(self):
        self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
        self.processor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)
        input_array = aux_resize_image(input_array, detect_resolution)

        pixel_values = self.image_processor(input_array, return_tensors="pt").pixel_values

        with torch.no_grad():
            outputs = self.processor(pixel_values.to(self.processor.device))
            outputs.loss = outputs.loss.to("cpu") if outputs.loss is not None else outputs.loss
            outputs.logits = outputs.logits.to("cpu") if outputs.logits is not None else outputs.logits
            outputs.hidden_states = outputs.hidden_states.to("cpu") if outputs.hidden_states is not None else outputs.hidden_states
            outputs.attentions = outputs.attentions.to("cpu") if outputs.attentions is not None else outputs.attentions

        seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[input_image.size[::-1]])[0]
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)  # height, width, 3

        for label, color in enumerate(ade_palette()):
            color_seg[seg == label, :] = color

        color_seg = color_seg.astype(np.uint8)
        color_seg = aux_resize_image(color_seg, image_resolution)

        color_seg = Image.fromarray(color_seg)
        return color_seg

class NullPreProcessor:
    def __call__(self, input_image, **kwargs):
        return input_image

class BlurPreProcessor:
    def __call__(self, input_image, sigma=5.0, **kwargs):
        import cv2

        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)

        dst = cv2.GaussianBlur(input_array, (0, 0), sigma)

        return Image.fromarray(dst)

class TileResamplePreProcessor:
    def resize(self, input_image, resolution):
        import cv2

        H, W, C = input_image.shape
        H = float(H)
        W = float(W)
        k = float(resolution) / min(H, W)
        H *= k
        W *= k
        img = cv2.resize(input_image, (int(W), int(H)), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
        return img

    def __call__(self, input_image, down_sampling_rate=1.0, **kwargs):
        input_array = np.array(input_image, dtype=np.uint8)
        input_array = HWC3(input_array)

        H, W, C = input_array.shape
        target_res = min(H, W) / down_sampling_rate

        dst = self.resize(input_array, target_res)

        return Image.fromarray(dst)
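
# Note (added for clarity): with down_sampling_rate > 1 the TileResamplePreProcessor
# shrinks the conditioning image below the generation resolution, which lets
# controlnet_tile re-synthesize fine detail rather than copy the input
# pixel-for-pixel.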

def create_controlnet_model(type_str):
    if type_str == "controlnet_tile":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11f1e_sd15_tile')
    elif type_str == "controlnet_lineart_anime":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15s2_lineart_anime')
    elif type_str == "controlnet_ip2p":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11e_sd15_ip2p')
    elif type_str == "controlnet_openpose":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_openpose')
    elif type_str == "controlnet_softedge":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_softedge')
    elif type_str == "controlnet_shuffle":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11e_sd15_shuffle')
    elif type_str == "controlnet_depth":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11f1p_sd15_depth')
    elif type_str == "controlnet_canny":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_canny')
    elif type_str == "controlnet_inpaint":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_inpaint')
    elif type_str == "controlnet_lineart":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_lineart')
    elif type_str == "controlnet_mlsd":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_mlsd')
    elif type_str == "controlnet_normalbae":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_normalbae')
    elif type_str == "controlnet_scribble":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_scribble')
    elif type_str == "controlnet_seg":
        return ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15_seg')
    elif type_str == "qr_code_monster_v1":
        return ControlNetModel.from_pretrained('monster-labs/control_v1p_sd15_qrcode_monster')
    elif type_str == "qr_code_monster_v2":
        return ControlNetModel.from_pretrained('monster-labs/control_v1p_sd15_qrcode_monster', subfolder='v2')
    elif type_str == "controlnet_mediapipe_face":
        return ControlNetModel.from_pretrained("CrucibleAI/ControlNetMediaPipeFace", subfolder="diffusion_sd15")
    else:
        raise ValueError(f"unknown controlnet type {type_str}")

default_preprocessor_table = {
    "controlnet_lineart_anime": "lineart_anime",
    "controlnet_openpose": "dwpose" if onnxruntime_installed else "openpose_full",
    "controlnet_softedge": "softedge_hedsafe",
    "controlnet_shuffle": "shuffle",
    "controlnet_depth": "depth_midas",
    "controlnet_canny": "canny",
    "controlnet_lineart": "lineart_realistic",
    "controlnet_mlsd": "mlsd",
    "controlnet_normalbae": "normal_bae",
    "controlnet_scribble": "scribble_pidsafe",
    "controlnet_seg": "upernet_seg",
    "controlnet_mediapipe_face": "mediapipe_face",
}

def create_preprocessor_from_name(pre_type):
    if pre_type == "dwpose":
        prepare_dwpose()
        return DWposeDetector()
    elif pre_type == "upernet_seg":
        return SegPreProcessor()
    elif pre_type == "blur":
        return BlurPreProcessor()
    elif pre_type == "tile_resample":
        return TileResamplePreProcessor()
    elif pre_type == "none":
        return NullPreProcessor()
    elif pre_type in MODELS:
        return ControlnetPreProcessor(pre_type)
    else:
        raise ValueError(f"unknown controlnet preprocessor type {pre_type}")

def create_default_preprocessor(type_str):
    if type_str in default_preprocessor_table:
        pre_type = default_preprocessor_table[type_str]
    else:
        pre_type = "none"

    return create_preprocessor_from_name(pre_type)

def get_preprocessor(type_str, device_str, preprocessor_map):
    if type_str not in controlnet_preprocessor:
        if preprocessor_map:
            controlnet_preprocessor[type_str] = create_preprocessor_from_name(preprocessor_map["type"])

    if type_str not in controlnet_preprocessor:
        controlnet_preprocessor[type_str] = create_default_preprocessor(type_str)

    # move the underlying model to the requested device, if it supports it
    if hasattr(controlnet_preprocessor[type_str], "processor"):
        if hasattr(controlnet_preprocessor[type_str].processor, "to"):
            if device_str:
                controlnet_preprocessor[type_str].processor.to(device_str)
    elif hasattr(controlnet_preprocessor[type_str], "to"):
        if device_str:
            controlnet_preprocessor[type_str].to(device_str)

    return controlnet_preprocessor[type_str]

def clear_controlnet_preprocessor(type_str=None):
    global controlnet_preprocessor
    if type_str is None:
        for t in controlnet_preprocessor:
            controlnet_preprocessor[t] = None
        controlnet_preprocessor = {}
        torch.cuda.empty_cache()
    else:
        controlnet_preprocessor[type_str] = None
        torch.cuda.empty_cache()

def get_preprocessed_img(type_str, img, use_preprocessor, device_str, preprocessor_map):
    if use_preprocessor:
        param = {}
        if preprocessor_map:
            param = preprocessor_map.get("param", {})
        return get_preprocessor(type_str, device_str, preprocessor_map)(img, **param)
    else:
        return img

def create_pipeline(
    base_model: Union[str, PathLike] = default_base_path,
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
) -> AnimationPipeline:
    """Create an AnimationPipeline from a pretrained model.
    Uses the base_model argument to load or download the pretrained reference pipeline model."""

    # make sure motion_module is a Path and exists
    logger.info("Checking motion module...")
    motion_module = data_dir.joinpath(model_config.motion_module)
    if not (motion_module.exists() and motion_module.is_file()):
        prepare_motion_module()
        if not (motion_module.exists() and motion_module.is_file()):
            # check for safetensors version
            motion_module = motion_module.with_suffix(".safetensors")
            if not (motion_module.exists() and motion_module.is_file()):
                # download from HuggingFace Hub if not found
                ensure_motion_modules()
            if not (motion_module.exists() and motion_module.is_file()):
                # this should never happen, but just in case...
                raise FileNotFoundError(f"Motion module {motion_module} does not exist or is not a file!")

    logger.info("Loading tokenizer...")
    tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
    logger.info("Loading text encoder...")
    text_encoder: CLIPSkipTextModel = CLIPSkipTextModel.from_pretrained(base_model, subfolder="text_encoder")
    logger.info("Loading VAE...")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
    logger.info("Loading UNet...")
    unet: UNet3DConditionModel = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path=base_model,
        motion_module_path=motion_module,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )
    feature_extractor = CLIPImageProcessor.from_pretrained(base_model, subfolder="feature_extractor")

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    # Load the checkpoint weights into the pipeline
    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():
            logger.debug("Loading from single checkpoint file")
            unet_state_dict, tenc_state_dict, vae_state_dict = get_checkpoint_weights(model_path)
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            temp_pipeline = StableDiffusionPipeline.from_pretrained(model_path)
            unet_state_dict, tenc_state_dict, vae_state_dict = (
                temp_pipeline.unet.state_dict(),
                temp_pipeline.text_encoder.state_dict(),
                temp_pipeline.vae.state_dict(),
            )
            del temp_pipeline
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")

        # Load into the unet, TE, and VAE
        logger.info("Merging weights into UNet...")
        _, unet_unex = unet.load_state_dict(unet_state_dict, strict=False)
        if len(unet_unex) > 0:
            raise ValueError(f"UNet has unexpected keys: {unet_unex}")
        tenc_missing, _ = text_encoder.load_state_dict(tenc_state_dict, strict=False)
        if len(tenc_missing) > 0:
            raise ValueError(f"TextEncoder has missing keys: {tenc_missing}")
        vae_missing, _ = vae.load_state_dict(vae_state_dict, strict=False)
        if len(vae_missing) > 0:
            raise ValueError(f"VAE has missing keys: {vae_missing}")
    else:
        logger.info("Using base model weights (no checkpoint/LoRA)")

    if model_config.vae_path:
        vae_path = data_dir.joinpath(model_config.vae_path)
        logger.info(f"Loading vae from {vae_path}")
        if vae_path.is_dir():
            vae = AutoencoderKL.from_pretrained(vae_path)
        else:
            tensors = load_tensors(vae_path)
            tensors = convert_ldm_vae_checkpoint(tensors, vae.config)
            vae.load_state_dict(tensors)

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        unet.enable_xformers_memory_efficient_attention()

    # lora
    for l in model_config.lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            logger.info(f"Loading lora {lora_path}")
            logger.info(f"alpha = {model_config.lora_map[l]}")
            load_safetensors_lora(text_encoder, unet, lora_path, alpha=model_config.lora_map[l])

    # motion lora
    for l in model_config.motion_lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            logger.info(f"Loading motion lora {lora_path}")
            logger.info(f"alpha = {model_config.motion_lora_map[l]}")
            load_motion_lora(unet, lora_path, alpha=model_config.motion_lora_map[l])

    logger.info("Creating AnimationPipeline...")
    pipeline = AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
        controlnet_map=None,
    )

    # Load TI embeddings
    load_text_embeddings(pipeline)

    return pipeline
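
# Note (added): from_pretrained_2d above is this project's loader that inflates
# the 2D Stable Diffusion U-Net into a 3D UNet3DConditionModel and injects the
# AnimateDiff motion-module weights from motion_module_path; the base SD 1.5
# weights may then be overwritten by the checkpoint/LoRA merging that follows.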

def load_controlnet_models(pipe: AnimationPipeline, model_config: ModelConfig = ...):
    # controlnet
    controlnet_map = {}
    if model_config.controlnet_map:
        c_image_dir = data_dir.joinpath(model_config.controlnet_map["input_image_dir"])
        for c in model_config.controlnet_map:
            item = model_config.controlnet_map[c]
            if isinstance(item, dict):
                if item["enable"]:
                    # only load a controlnet if numbered condition images exist for it
                    img_dir = c_image_dir.joinpath(c)
                    cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                    if len(cond_imgs) > 0:
                        logger.info(f"loading {c=} model")
                        controlnet_map[c] = create_controlnet_model(c)

    if not controlnet_map:
        controlnet_map = None

    pipe.controlnet_map = controlnet_map

def unload_controlnet_models(pipe: AnimationPipeline):
    from animatediff.utils.util import show_gpu

    show_gpu("before unload controlnet")
    pipe.controlnet_map = None
    torch.cuda.empty_cache()
    show_gpu("after unload controlnet")

def create_us_pipeline(
    model_config: ModelConfig = ...,
    infer_config: InferenceConfig = ...,
    use_xformers: bool = True,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
) -> DiffusionPipeline:

    # set up scheduler
    sched_kwargs = infer_config.noise_scheduler_kwargs
    scheduler = get_scheduler(model_config.scheduler, sched_kwargs)
    logger.info(f'Using scheduler "{model_config.scheduler}" ({scheduler.__class__.__name__})')

    controlnet = []
    if use_controlnet_tile:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11f1e_sd15_tile'))
    if use_controlnet_line_anime:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11p_sd15s2_lineart_anime'))
    if use_controlnet_ip2p:
        controlnet.append(ControlNetModel.from_pretrained('lllyasviel/control_v11e_sd15_ip2p'))

    if len(controlnet) == 1:
        controlnet = controlnet[0]
    elif len(controlnet) == 0:
        controlnet = None

    # Load the checkpoint weights into the pipeline
    pipeline: DiffusionPipeline

    if model_config.path is not None:
        model_path = data_dir.joinpath(model_config.path)
        logger.info(f"Loading weights from {model_path}")
        if model_path.is_file():

            def is_empty_dir(path):
                return len(os.listdir(path)) == 0

            # cache the converted checkpoint as a Diffusers directory, keyed by name and file size
            save_path = data_dir.joinpath("models/huggingface/" + model_path.stem + "_" + str(model_path.stat().st_size))
            save_path.mkdir(exist_ok=True)
            if save_path.is_dir() and is_empty_dir(save_path):
                # StableDiffusionControlNetImg2ImgPipeline.from_single_file does not exist in version 18.2
                logger.debug("Loading from single checkpoint file")
                tmp_pipeline = StableDiffusionPipeline.from_single_file(
                    pretrained_model_link_or_path=str(model_path.absolute())
                )
                tmp_pipeline.save_pretrained(save_path, safe_serialization=True)
                del tmp_pipeline

            if use_controlnet_ref:
                pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
                    save_path,
                    controlnet=controlnet,
                    local_files_only=False,
                    load_safety_checker=False,
                    safety_checker=None,
                )
            else:
                pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                    save_path,
                    controlnet=controlnet,
                    local_files_only=False,
                    load_safety_checker=False,
                    safety_checker=None,
                )
        elif model_path.is_dir():
            logger.debug("Loading from Diffusers model directory")
            if use_controlnet_ref:
                pipeline = StableDiffusionControlNetImg2ImgReferencePipeline.from_pretrained(
                    model_path,
                    controlnet=controlnet,
                    local_files_only=True,
                    load_safety_checker=False,
                    safety_checker=None,
                )
            else:
                pipeline = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
                    model_path,
                    controlnet=controlnet,
                    local_files_only=True,
                    load_safety_checker=False,
                    safety_checker=None,
                )
        else:
            raise FileNotFoundError(f"model_path {model_path} is not a file or directory")
    else:
        raise ValueError("model_config.path is invalid")

    pipeline.scheduler = scheduler

    # enable xformers if available
    if use_xformers:
        logger.info("Enabling xformers memory-efficient attention")
        pipeline.enable_xformers_memory_efficient_attention()

    # lora
    for l in model_config.lora_map:
        lora_path = data_dir.joinpath(l)
        if lora_path.is_file():
            logger.info(f"Loading lora {lora_path}")
            logger.info(f"alpha = {model_config.lora_map[l]}")
            load_safetensors_lora(pipeline.text_encoder, pipeline.unet, lora_path, alpha=model_config.lora_map[l], is_animatediff=False)

    # Load TI embeddings
    load_text_embeddings(pipeline)

    return pipeline

def seed_everything(seed):
    import random

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed % (2**32))  # numpy seeds must fit in 32 bits
    random.seed(seed)

def controlnet_preprocess(
    controlnet_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
    device_str: str = None,
):
    if not controlnet_map:
        return None, None, None

    out_dir = Path(out_dir)  # ensure out_dir is a Path

    # { 0 : { "type_str" : IMAGE, "type_str2" : IMAGE } }
    controlnet_image_map = {}

    controlnet_type_map = {}

    c_image_dir = data_dir.joinpath(controlnet_map["input_image_dir"])
    save_detectmap = controlnet_map.get("save_detectmap", True)
    preprocess_on_gpu = controlnet_map.get("preprocess_on_gpu", True)
    device_str = device_str if preprocess_on_gpu else None

    for c in controlnet_map:
        if c == "controlnet_ref":
            continue

        item = controlnet_map[c]
        processed = False

        if isinstance(item, dict):
            if item["enable"]:
                preprocessor_map = item.get("preprocessor", {})

                img_dir = c_image_dir.joinpath(c)
                cond_imgs = sorted(glob.glob(os.path.join(img_dir, "[0-9]*.png"), recursive=False))
                if len(cond_imgs) > 0:
                    controlnet_type_map[c] = {
                        "controlnet_conditioning_scale": item["controlnet_conditioning_scale"],
                        "control_guidance_start": item["control_guidance_start"],
                        "control_guidance_end": item["control_guidance_end"],
                        "control_scale_list": item["control_scale_list"],
                        "guess_mode": item.get("guess_mode", False),
                    }

                    use_preprocessor = item.get("use_preprocessor", True)

                    for img_path in tqdm(cond_imgs, desc=f"Preprocessing images ({c})"):
                        frame_no = int(Path(img_path).stem)
                        if frame_no < duration:
                            if frame_no not in controlnet_image_map:
                                controlnet_image_map[frame_no] = {}
                            controlnet_image_map[frame_no][c] = get_preprocessed_img(c, get_resized_image2(img_path, 512), use_preprocessor, device_str, preprocessor_map)
                            processed = True

        if save_detectmap and processed:
            det_dir = out_dir.joinpath(f"{0:02d}_detectmap/{c}")
            det_dir.mkdir(parents=True, exist_ok=True)
            for frame_no in tqdm(controlnet_image_map, desc=f"Saving Preprocessed images ({c})"):
                save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                if c in controlnet_image_map[frame_no]:
                    controlnet_image_map[frame_no][c].save(save_path)

        clear_controlnet_preprocessor(c)

    clear_controlnet_preprocessor()

    controlnet_ref_map = None

    if "controlnet_ref" in controlnet_map:
        r = controlnet_map["controlnet_ref"]
        if r["enable"]:
            org_name = data_dir.joinpath(r["ref_image"]).stem
            # ref_image = get_resized_image(data_dir.joinpath(r["ref_image"]), width, height)
            ref_image = get_resized_image2(data_dir.joinpath(r["ref_image"]), 512)

            if ref_image is not None:
                controlnet_ref_map = {
                    "ref_image": ref_image,
                    "style_fidelity": r["style_fidelity"],
                    "attention_auto_machine_weight": r["attention_auto_machine_weight"],
                    "gn_auto_machine_weight": r["gn_auto_machine_weight"],
                    "reference_attn": r["reference_attn"],
                    "reference_adain": r["reference_adain"],
                    "scale_pattern": r["scale_pattern"],
                }

                if save_detectmap:
                    det_dir = out_dir.joinpath(f"{0:02d}_detectmap/controlnet_ref")
                    det_dir.mkdir(parents=True, exist_ok=True)
                    save_path = det_dir.joinpath(f"{org_name}.png")
                    ref_image.save(save_path)

    return controlnet_image_map, controlnet_type_map, controlnet_ref_map
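
# Illustration (added; values are hypothetical): for a 16-frame run with tile
# and openpose enabled, the return values look roughly like
#   controlnet_image_map = {0: {"controlnet_tile": <PIL.Image>, "controlnet_openpose": <PIL.Image>}, 1: {...}, ...}
#   controlnet_type_map  = {"controlnet_tile": {"controlnet_conditioning_scale": 1.0, ...}, ...}
#   controlnet_ref_map   = None  (unless controlnet_ref is enabled in the config)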

def ip_adapter_preprocess(
    ip_adapter_config_map: Dict[str, Any] = None,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    out_dir: PathLike = ...,
):
    ip_adapter_map = {}
    processed = False

    if ip_adapter_config_map:
        if ip_adapter_config_map["enable"]:
            out_dir = Path(out_dir)  # ensure out_dir is a Path, as in controlnet_preprocess
            resized_to_square = ip_adapter_config_map.get("resized_to_square", False)
            image_dir = data_dir.joinpath(ip_adapter_config_map["input_image_dir"])
            imgs = sorted(glob.glob(os.path.join(image_dir, "[0-9]*.png"), recursive=False))
            if len(imgs) > 0:
                prepare_ip_adapter()
                ip_adapter_map["scale"] = ip_adapter_config_map["scale"]
                ip_adapter_map["is_plus"] = ip_adapter_config_map["is_plus"]
                ip_adapter_map["is_plus_face"] = ip_adapter_config_map.get("is_plus_face", False)
                ip_adapter_map["is_light"] = ip_adapter_config_map.get("is_light", False)
                ip_adapter_map["images"] = {}
                for img_path in tqdm(imgs, desc="Preprocessing images (ip_adapter)"):
                    frame_no = int(Path(img_path).stem)
                    if frame_no < duration:
                        if resized_to_square:
                            ip_adapter_map["images"][frame_no] = get_resized_image(img_path, 256, 256)
                        else:
                            ip_adapter_map["images"][frame_no] = get_resized_image2(img_path, 256)
                        processed = True

            if ip_adapter_config_map["save_input_image"] and processed:
                det_dir = out_dir.joinpath(f"{0:02d}_ip_adapter/")
                det_dir.mkdir(parents=True, exist_ok=True)
                for frame_no in tqdm(ip_adapter_map["images"], desc="Saving Preprocessed images (ip_adapter)"):
                    save_path = det_dir.joinpath(f"{frame_no:08d}.png")
                    ip_adapter_map["images"][frame_no].save(save_path)

    return ip_adapter_map if processed else None

def save_output(
    pipeline_output,
    frame_dir: PathLike,
    out_file: PathLike,
    output_map: Dict[str, Any] = {},
    no_frames: bool = False,
    save_frames=save_frames,
    save_video=None,
):
    output_format = "gif"
    output_fps = 8
    if output_map:
        output_format = output_map.get("format", output_format)
        output_fps = output_map.get("fps", output_fps)
        if output_format == "mp4":
            output_format = "h264"

    if output_format == "gif":
        out_file = out_file.with_suffix(".gif")

        if no_frames is not True:
            if save_frames:
                save_frames(pipeline_output, frame_dir)

        # generate the output filename and save the video
        if save_video:
            save_video(pipeline_output, out_file, output_fps)
        else:
            pipeline_output[0].save(
                fp=out_file, format="GIF", append_images=pipeline_output[1:], save_all=True, duration=(1 / output_fps * 1000), loop=0
            )
    else:
        if save_frames:
            save_frames(pipeline_output, frame_dir)

        from animatediff.rife.ffmpeg import (FfmpegEncoder, VideoCodec,
                                             codec_extn)

        out_file = out_file.with_suffix(f".{codec_extn(output_format)}")

        logger.info("Creating ffmpeg encoder...")
        encoder = FfmpegEncoder(
            frames_dir=frame_dir,
            out_file=out_file,
            codec=output_format,
            in_fps=output_fps,
            out_fps=output_fps,
            lossless=False,
            param=output_map.get("encode_param", {}),
        )
        logger.info("Encoding interpolated frames with ffmpeg...")
        result = encoder.encode()
        logger.debug(f"ffmpeg result: {result}")

def run_inference(
    pipeline: AnimationPipeline,
    prompt: str = ...,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    guidance_scale: float = 7.5,
    width: int = 512,
    height: int = 512,
    duration: int = 16,
    idx: int = 0,
    out_dir: PathLike = ...,
    context_frames: int = -1,
    context_stride: int = 3,
    context_overlap: int = 4,
    context_schedule: str = "uniform",
    clip_skip: int = 1,
    prompt_map: Dict[int, str] = None,
    controlnet_map: Dict[str, Any] = None,
    controlnet_image_map: Dict[str, Any] = None,
    controlnet_type_map: Dict[str, Any] = None,
    controlnet_ref_map: Dict[str, Any] = None,
    no_frames: bool = False,
    ip_adapter_map: Dict[str, Any] = None,
    output_map: Dict[str, Any] = None,
    is_single_prompt_mode: bool = False,
):
    out_dir = Path(out_dir)  # ensure out_dir is a Path

    seed_everything(seed)

    # guard against controlnet_map being None before the key lookups below
    controlnet_map = controlnet_map or {}

    pipeline_output = pipeline(
        prompt=prompt,
        negative_prompt=n_prompt,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        width=width,
        height=height,
        video_length=duration,
        return_dict=False,
        context_frames=context_frames,
        context_stride=context_stride + 1,
        context_overlap=context_overlap,
        context_schedule=context_schedule,
        clip_skip=clip_skip,
        prompt_map=prompt_map,
        controlnet_type_map=controlnet_type_map,
        controlnet_image_map=controlnet_image_map,
        controlnet_ref_map=controlnet_ref_map,
        controlnet_max_samples_on_vram=controlnet_map.get("max_samples_on_vram", 999),
        controlnet_max_models_on_vram=controlnet_map.get("max_models_on_vram", 99),
        controlnet_is_loop=controlnet_map.get("is_loop", True),
        ip_adapter_map=ip_adapter_map,
        interpolation_factor=1,
        is_single_prompt_mode=is_single_prompt_mode,
    )
    logger.info("Generation complete, saving...")

    # Trim and clean up the prompt for filename use
    prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")]
    prompt_str = "_".join(prompt_tags[:6])[:50]

    frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}")
    out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}")
    save_output(pipeline_output, frame_dir, out_file, output_map, no_frames, save_frames, save_video)

    logger.info(f"Saved sample to {out_file}")

    return pipeline_output

def run_upscale(
    org_imgs: List[str],
    pipeline: DiffusionPipeline,
    prompt_map: Dict[int, str] = None,
    n_prompt: str = ...,
    seed: int = -1,
    steps: int = 25,
    strength: float = 0.5,
    guidance_scale: float = 7.5,
    clip_skip: int = 1,
    us_width: int = 512,
    us_height: int = 512,
    idx: int = 0,
    out_dir: PathLike = ...,
    upscale_config: Dict[str, Any] = None,
    use_controlnet_ref: bool = False,
    use_controlnet_tile: bool = False,
    use_controlnet_line_anime: bool = False,
    use_controlnet_ip2p: bool = False,
    no_frames: bool = False,
    output_map: Dict[str, Any] = None,
):
    from animatediff.utils.lpw_stable_diffusion import lpw_encode_prompt

    pipeline.set_progress_bar_config(disable=True)

    images = get_resized_images(org_imgs, us_width, us_height)

    # values from upscale_config override the function arguments
    steps = steps if "steps" not in upscale_config else upscale_config["steps"]
    # NOTE: the original read `scheduler = scheduler if "scheduler" not in upscale_config ...`,
    # which raises NameError when the key is absent (no local `scheduler` exists yet);
    # guarded here. The value is currently unused in this function.
    scheduler = upscale_config.get("scheduler", None)
    guidance_scale = guidance_scale if "guidance_scale" not in upscale_config else upscale_config["guidance_scale"]
    clip_skip = clip_skip if "clip_skip" not in upscale_config else upscale_config["clip_skip"]
    strength = strength if "strength" not in upscale_config else upscale_config["strength"]

    controlnet_conditioning_scale = []
    guess_mode = []
    control_guidance_start = []
    control_guidance_end = []

    # for controlnet tile
    if use_controlnet_tile:
        controlnet_conditioning_scale.append(upscale_config["controlnet_tile"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_tile"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_tile"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_tile"]["control_guidance_end"])
    # for controlnet line_anime
    if use_controlnet_line_anime:
        controlnet_conditioning_scale.append(upscale_config["controlnet_line_anime"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_line_anime"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_line_anime"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_line_anime"]["control_guidance_end"])
    # for controlnet ip2p
    if use_controlnet_ip2p:
        controlnet_conditioning_scale.append(upscale_config["controlnet_ip2p"]["controlnet_conditioning_scale"])
        guess_mode.append(upscale_config["controlnet_ip2p"]["guess_mode"])
        control_guidance_start.append(upscale_config["controlnet_ip2p"]["control_guidance_start"])
        control_guidance_end.append(upscale_config["controlnet_ip2p"]["control_guidance_end"])

    # for controlnet ref
    ref_image = None
    if use_controlnet_ref:
        if not upscale_config["controlnet_ref"]["use_frame_as_ref_image"] and not upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]:
            ref_image = get_resized_images([data_dir.joinpath(upscale_config["controlnet_ref"]["ref_image"])], us_width, us_height)[0]

    generator = torch.manual_seed(seed)

    seed_everything(seed)

    prompt_embeds_map = {}
    prompt_map = dict(sorted(prompt_map.items()))

    negative = None

    do_classifier_free_guidance = guidance_scale > 1.0

    prompt_list = [prompt_map[key_frame] for key_frame in prompt_map.keys()]

    prompt_embeds, neg_embeds = lpw_encode_prompt(
        pipe=pipeline,
        prompt=prompt_list,
        do_classifier_free_guidance=do_classifier_free_guidance,
        negative_prompt=n_prompt,
    )

    if do_classifier_free_guidance:
        negative = neg_embeds.chunk(neg_embeds.shape[0], 0)
        positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0)
    else:
        negative = [None]
        positive = prompt_embeds.chunk(prompt_embeds.shape[0], 0)

    for i, key_frame in enumerate(prompt_map):
        prompt_embeds_map[key_frame] = positive[i]

    key_first = list(prompt_map.keys())[0]
    key_last = list(prompt_map.keys())[-1]

    def get_current_prompt_embeds(
        center_frame: int = 0,
        video_length: int = 0,
    ):
        # find the keyframes surrounding center_frame, wrapping around the clip
        key_prev = key_last
        key_next = key_first

        for p in prompt_map.keys():
            if p > center_frame:
                key_next = p
                break
            key_prev = p

        dist_prev = center_frame - key_prev
        if dist_prev < 0:
            dist_prev += video_length
        dist_next = key_next - center_frame
        if dist_next < 0:
            dist_next += video_length

        if key_prev == key_next or dist_prev + dist_next == 0:
            return prompt_embeds_map[key_prev]

        rate = dist_prev / (dist_prev + dist_next)

        return get_tensor_interpolation_method()(prompt_embeds_map[key_prev], prompt_embeds_map[key_next], rate)

    line_anime_processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")

    out_images = []

    logger.info(f"{use_controlnet_tile=}")
    logger.info(f"{use_controlnet_line_anime=}")
    logger.info(f"{use_controlnet_ip2p=}")
    logger.info(f"{controlnet_conditioning_scale=}")
    logger.info(f"{guess_mode=}")
    logger.info(f"{control_guidance_start=}")
    logger.info(f"{control_guidance_end=}")

    for i, org_image in enumerate(tqdm(images, desc="Upscaling...")):
        cur_positive = get_current_prompt_embeds(i, len(images))

        # logger.info(f"w {condition_image.size[0]}")
        # logger.info(f"h {condition_image.size[1]}")
        condition_image = []

        if use_controlnet_tile:
            condition_image.append(org_image)
        if use_controlnet_line_anime:
            condition_image.append(line_anime_processor(org_image))
        if use_controlnet_ip2p:
            condition_image.append(org_image)

        if not use_controlnet_ref:
            out_image = pipeline(
                prompt_embeds=cur_positive,
                negative_prompt_embeds=negative[0],
                image=org_image,
                control_image=condition_image,
                width=org_image.size[0],
                height=org_image.size[1],
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0],
                guess_mode=guess_mode[0],
                control_guidance_start=control_guidance_start if len(control_guidance_start) > 1 else control_guidance_start[0],
                control_guidance_end=control_guidance_end if len(control_guidance_end) > 1 else control_guidance_end[0],
            ).images[0]
        else:
            if upscale_config["controlnet_ref"]["use_1st_frame_as_ref_image"]:
                if i == 0:
                    ref_image = org_image
            elif upscale_config["controlnet_ref"]["use_frame_as_ref_image"]:
                ref_image = org_image

            out_image = pipeline(
                prompt_embeds=cur_positive,
                negative_prompt_embeds=negative[0],
                image=org_image,
                control_image=condition_image,
                width=org_image.size[0],
                height=org_image.size[1],
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
                controlnet_conditioning_scale=controlnet_conditioning_scale if len(controlnet_conditioning_scale) > 1 else controlnet_conditioning_scale[0],
                guess_mode=guess_mode[0],
                # control_guidance_start=control_guidance_start,
                # control_guidance_end=control_guidance_end,
                ### for controlnet ref
                ref_image=ref_image,
                attention_auto_machine_weight=upscale_config["controlnet_ref"]["attention_auto_machine_weight"],
                gn_auto_machine_weight=upscale_config["controlnet_ref"]["gn_auto_machine_weight"],
                style_fidelity=upscale_config["controlnet_ref"]["style_fidelity"],
                reference_attn=upscale_config["controlnet_ref"]["reference_attn"],
                reference_adain=upscale_config["controlnet_ref"]["reference_adain"],
            ).images[0]

        out_images.append(out_image)

    # Trim and clean up the prompt for filename use
    prompt_tags = [re_clean_prompt.sub("", tag).strip().replace(" ", "-") for tag in prompt_map[list(prompt_map.keys())[0]].split(",")]
    prompt_str = "_".join(prompt_tags[:6])[:50]

    # generate the output filename and save the video
    out_file = out_dir.joinpath(f"{idx:02d}_{seed}_{prompt_str}")
    frame_dir = out_dir.joinpath(f"{idx:02d}-{seed}-upscaled")
    save_output(out_images, frame_dir, out_file, output_map, no_frames, save_imgs, None)

    logger.info(f"Saved sample to {out_file}")

    return out_images
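
# ---------------------------------------------------------------------------
# End-to-end usage sketch (added for illustration; not part of the original
# gist). The wiring and argument values below are assumptions based on the
# function signatures above, not a confirmed invocation:
#
#   pipe = create_pipeline(model_config=model_config, infer_config=infer_config)
#   pipe = pipe.to("cuda")
#   load_controlnet_models(pipe, model_config)
#   image_map, type_map, ref_map = controlnet_preprocess(
#       model_config.controlnet_map, duration=16, out_dir=out_dir, device_str="cuda")
#   run_inference(
#       pipeline=pipe, prompt="...", n_prompt="...", seed=42, duration=16,
#       out_dir=out_dir, prompt_map={0: "..."},
#       controlnet_map=model_config.controlnet_map,
#       controlnet_image_map=image_map, controlnet_type_map=type_map,
#       controlnet_ref_map=ref_map, output_map={"format": "gif", "fps": 8})
#   unload_controlnet_models(pipe)
# ---------------------------------------------------------------------------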