From 8b4ef05aa89479616bbc28bb7e74ed9dc37a83c2 Mon Sep 17 00:00:00 2001
From: artspark <play@notreal.co>
Date: Tue, 11 Apr 2023 12:08:26 +0000
Subject: [PATCH] Backend changes
---
backend/ldm/generate.py | 15 +-
backend/ldm/invoke/generator/base.py | 3 +-
.../invoke/generator/diffusers_pipeline.py | 226 +++++++++++++++++-
backend/ldm/invoke/generator/img2img.py | 27 ++-
backend/ldm/invoke/generator/inpaint.py | 12 +
backend/ldm/invoke/generator/txt2img.py | 3 +-
backend/ldm/invoke/model_manager.py | 6 +-
backend/ldm/models/diffusion/ddpm.py | 2 +-
8 files changed, 268 insertions(+), 26 deletions(-)
diff --git a/backend/ldm/generate.py b/backend/ldm/generate.py
index a5649a92..d60046fd 100644
--- a/backend/ldm/generate.py
+++ b/backend/ldm/generate.py
@@ -341,6 +341,7 @@ class Generate:
infill_method = None,
force_outpaint: bool = False,
enable_image_debugging = False,
+ control = None,
**args,
): # eat up additional cruft
@@ -423,9 +424,9 @@ class Generate:
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
assert threshold >= 0.0, '--threshold must be >=0.0'
- assert (
- 0.0 < strength < 1.0
- ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
+ #assert (
+ # 0.0 < strength < 1.0
+ #), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0'
assert (
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'
@@ -517,7 +518,10 @@ class Generate:
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
+ print("control", control)
generator = self.select_generator(init_image, mask_image, embiggen, hires_fix, force_outpaint)
+ #generator.model.control = control
+ #generator.model.control = 'canny'
generator.set_variation(
self.seed, variation_amount, with_variations
@@ -534,6 +538,7 @@ class Generate:
iterations=iterations,
seed=self.seed,
sampler=self.sampler,
+ control=control,
steps=steps,
cfg_scale=cfg_scale,
conditioning=(uc, c, extra_conditioning_info),
@@ -753,6 +758,7 @@ class Generate:
):
inpainting_model_in_use = self.sampler.uses_inpainting_model()
+
if hires_fix:
return self._make_txt2img2img()
@@ -817,6 +823,9 @@ class Generate:
def _make_base(self):
return self._load_generator('','Generator')
+ def _make_controlnet(self):
+ return self._load_generator('.controlnet','ControlNet')
+
def _make_txt2img(self):
return self._load_generator('.txt2img','Txt2Img')
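
Note on the new `control` argument threaded through generate.py above: judging by how controller() in diffusers_pipeline.py later consumes it (`[c for c, v in control_dict.items() if v]`), it appears to be a mapping from hint names to enable flags rather than a single string. A minimal sketch of that inferred convention, with illustrative hint names only:

    # control maps hint names to booleans; only truthy entries are run through a ControlNet.
    control = {"canny": True, "depth": False, "openpose": True}
    enabled_hints = [name for name, on in control.items() if on]
    print(enabled_hints)  # ['canny', 'openpose']
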
diff --git a/backend/ldm/invoke/generator/base.py b/backend/ldm/invoke/generator/base.py
index 467cbe38..254e65bb 100644
--- a/backend/ldm/invoke/generator/base.py
+++ b/backend/ldm/invoke/generator/base.py
@@ -61,7 +61,7 @@ class Generator:
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
- safety_checker:dict=None, orig_prompt=None,
+ safety_checker:dict=None, orig_prompt=None, control=None,
free_gpu_mem: bool=False,
**kwargs):
scope = nullcontext
@@ -79,6 +79,7 @@ class Generator:
threshold = threshold,
perlin = perlin,
attention_maps_callback = attention_maps_callback,
+ control = control,
**kwargs
)
results = []
diff --git a/backend/ldm/invoke/generator/diffusers_pipeline.py b/backend/ldm/invoke/generator/diffusers_pipeline.py
index 69412057..82f87ed3 100644
--- a/backend/ldm/invoke/generator/diffusers_pipeline.py
+++ b/backend/ldm/invoke/generator/diffusers_pipeline.py
@@ -6,7 +6,7 @@ import secrets
import sys
import warnings
from dataclasses import dataclass, field
-from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any
+from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, Dict
if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
@@ -14,11 +14,13 @@ else:
from typing import ParamSpec
import PIL.Image
+from PIL import Image
import einops
import torch
import torchvision.transforms as T
from diffusers.models import attention
from diffusers.utils.import_utils import is_xformers_available
+import numpy as np
from ...models.diffusion import cross_attention_control
from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
@@ -28,10 +30,12 @@ from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToE
# this is to make prompt2prompt and (future) attention maps work
attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.models import AutoencoderKL, UNet2DConditionModel, ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
+from diffusers import StableDiffusionControlNetPipeline, StableDiffusionInpaintPipeline
+from diffusers.utils import load_image
from .safety_checker import StableDiffusionSafetyChecker
#from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -44,6 +48,21 @@ from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
+from web_pdb import set_trace
+
+
+import controlnet_hinter
+#controlnet_hinter.scribble = controlnet_hinter.fake_scribble
+CONTROLNETS = [
+ 'canny',
+ 'depth',
+ 'scribble',
+ 'hed',
+ 'mlsd',
+ 'normal',
+ 'openpose',
+ 'seg',
+];
@@ -233,14 +252,15 @@ class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
Output class for InvokeAI's Stable Diffusion pipeline.
+
Args:
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
after generation completes. Optional.
"""
attention_map_saver: Optional[AttentionMapSaver]
-
-class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
+class StableDiffusionGeneratorPipeline(StableDiffusionControlNetPipeline):
+#class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -280,13 +300,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
+ controlnet: ControlNetModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
requires_safety_checker: bool = False,
precision: str = 'float32',
):
- super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
+ #print(vae)
+ super().__init__(vae, text_encoder, tokenizer, unet, controlnet, scheduler,
safety_checker, feature_extractor, requires_safety_checker)
self.register_modules(
@@ -297,6 +319,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
+ controlnet=controlnet,
+ #**controlnets,
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
use_full_precision = (precision == 'float32' or precision == 'autocast')
@@ -316,13 +340,90 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.enable_vae_slicing()
self.enable_attention_slicing()
+ #self.controlnet_cond = None
+ #self.control = None
+ #self.controller = None
+ #self.controlnets = self.controlnet_models_dict()
+ self.control_forward = None
+
+ def controlnet_model(self, controlnet_name: str) -> ControlNetModel:
+ #print(self.unet.device)
+ return ControlNetModel.from_pretrained(
+ #f'takuma104/control_sd15_{controlnet_name}',
+ f'/data/models/control_models/control_nedream_{controlnet_name}',
+ subfolder='controlnet',
+ torch_dtype=self.unet.dtype
+ )
+
+ def controlnet_models_dict(self):
+ return {f"controlnet_{cn}": self.controlnet_model(cn) for cn in CONTROLNETS}
+
+ def controlnet_models(self):
+ return [self.controlnet_model(cn) for cn in CONTROLNETS]
+
+
+ def control_to_model(self, control):
+ if 'scribble' in control:
+ key = 'scribble'
+ elif 'hough' in control:
+ key = 'mlsd'
+ elif 'segmentation' in control:
+ key = 'seg'
+ else:
+ key = control
+
+ return self.controlnet_model(key)
+ #return self.controlnets[key]
+
+
+ def controller(self, control_dict, image):
+
+
+ controls = [c for c, v in control_dict.items() if v]
+ w, h = trim_to_multiple_of(*image.size)
+ ctrl = {}
+ for control in controls:
+ hinter = getattr(controlnet_hinter, f'hint_{control}')
+ control_image = hinter(image)
+ controlnet_cond = self.preprocess(control_image, w, h)
+ control_model = self.control_to_model(control)
+ ctrl[control] = controlnet_cond, control_model
+
+
+ def controller_forward(latents, t, encoder_hidden_states):
+ down_res, mid_res = None, None
+ for control in controls:
+ controlnet_cond, control_model = ctrl[control]
+ model = control_model.to(self._execution_device)
+
+ dr, mr = model(
+ latents,
+ t,
+ encoder_hidden_states=encoder_hidden_states,
+ controlnet_cond=controlnet_cond.to(self._execution_device, dtype=self.controlnet.dtype),
+ return_dict=False,
+ )
+ down_res = dr if down_res is None else [sum(z) for z in zip(down_res, dr)]
+ mid_res = mr if mid_res is None else sum([mid_res, mr])
+
+ unet_inputs = {
+ 'down_block_additional_residuals': down_res,
+ 'mid_block_additional_residual': mid_res,
+ }
+
+ return unet_inputs
+
+ return controller_forward
+
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
conditioning_data: ConditioningData,
*,
noise: torch.Tensor,
+ noise_func=None,
callback: Callable[[PipelineIntermediateState], None]=None,
- run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
+ control=None,
+ run_id=None, init_image=None) -> InvokeAIStableDiffusionPipelineOutput:
r"""
Function invoked when calling the pipeline for generation.
@@ -335,6 +436,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
:param callback:
:param run_id:
"""
+
+ if init_image and control:
+ # This is for the case where someone is in img2img mode but only wants the init image to be used
+ # as a control hint, not as a base image. In the UI this is represented as turning up "strength" to 1.0
+ self.control_forward = self.controller(control, init_image)
+
+ if isinstance(init_image, PIL.Image.Image):
+ init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
+
+ if init_image.dim() == 3:
+ init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
+
+ device = self.unet.device
+ latents_dtype = self.unet.dtype
+ latents = torch.zeros_like(self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype))
+ noise = noise_func(latents)
+
+
+
result_latents, result_attention_map_saver = self.latents_from_embeddings(
latents, num_inference_steps,
conditioning_data,
@@ -347,6 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
+ self.control_forward = None
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
@@ -458,9 +579,29 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return step_output
+ def preprocess(self, image, width, height):
+ #height = height or self.unet.config.sample_size * self.vae_scale_factor
+ #width = width or self.unet.config.sample_size * self.vae_scale_factor
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ image = [np.array(i.convert('RGB').resize((width, height), resample=PIL.Image.Resampling.LANCZOS))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[:, :, :, ::-1] # RGB -> BGR
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image.copy()) # copy: ::-1 workaround
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
def _unet_forward(self, latents, t, text_embeddings):
latents = latents.to(self.unet.device, dtype=self.unet.dtype)
"""predict the noise residual"""
+
if is_inpainting_model(self.unet) and latents.size(1) == 4:
# Pad out normal non-inpainting inputs for an inpainting model.
# FIXME: There are too many layers of functions and we have too many different ways of
@@ -472,7 +613,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
).add_mask_channels(latents)
- return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
+
+ if self.control_forward:
+ unet_inputs = self.control_forward(latents, t, text_embeddings)
+ else:
+ unet_inputs = {}
+ #if self.controlnet_cond is not None:
+ # #controlnet_cond = self.preprocess(self.control_image).to(device=self._execution_device, dtype=self.controlnet.dtype)
+ # controlnet_cond = self.controlnet_cond.to(self._execution_device, dtype=self.controlnet.dtype)
+
+ # down_res, mid_res = self.controller(
+ # latents,
+ # t,
+ # encoder_hidden_states=text_embeddings,
+ # controlnet_cond=controlnet_cond,
+ # return_dict=False,
+ # )
+
+ # unet_inputs = {
+ # 'down_block_additional_residuals': down_res,
+ # 'mid_block_additional_residual': mid_res,
+ # }
+ #else:
+ # unet_inputs = {}
+
+ return self.unet(
+ latents, t,
+ encoder_hidden_states=text_embeddings,
+ **unet_inputs,
+ ).sample
def img2img_from_embeddings(self,
init_image: Union[torch.FloatTensor, PIL.Image.Image],
@@ -481,34 +650,57 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data: ConditioningData,
*, callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
- noise_func=None
+ noise_func=None,
+ control=None
) -> InvokeAIStableDiffusionPipelineOutput:
+
+
+ if control:
+ self.control_forward = self.controller(control, init_image)
+
+
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
if init_image.dim() == 3:
init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
+
+
+
# 6. Prepare latent variables
device = self.unet.device
latents_dtype = self.unet.dtype
initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
noise = noise_func(initial_latents)
- return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
+ result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
conditioning_data,
strength,
noise, run_id, callback)
+ self.control_forward = None
+ return result
+
+ def get_timesteps(self, scheduler, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = scheduler.timesteps[t_start:]
+
+ return timesteps, num_inference_steps - t_start
+
def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
conditioning_data: ConditioningData,
strength,
noise: torch.Tensor, run_id=None, callback=None
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
- img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
+ #img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
+ img2img_pipeline = StableDiffusionControlNetPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
+ timesteps, _ = self.get_timesteps(img2img_pipeline.scheduler, num_inference_steps, strength, device=device)
result_latents, result_attention_maps = self.latents_from_embeddings(
initial_latents, num_inference_steps, conditioning_data,
@@ -535,10 +727,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*, callback: Callable[[PipelineIntermediateState], None] = None,
run_id=None,
noise_func=None,
+ control=None,
+ pil_init_image=None
) -> InvokeAIStableDiffusionPipelineOutput:
device = self.unet.device
latents_dtype = self.unet.dtype
+ if control:
+ self.control_forward = self.controller(control, pil_init_image)
+
+
if isinstance(init_image, PIL.Image.Image):
init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
@@ -547,9 +745,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if init_image.dim() == 3:
init_image = init_image.unsqueeze(0)
- img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
+ #img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
+ img2img_pipeline = StableDiffusionControlNetPipeline(**self.components)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
+ timesteps, _ = self.get_timesteps(img2img_pipeline.scheduler, num_inference_steps, strength, device=device)
assert img2img_pipeline.scheduler is self.scheduler
@@ -588,6 +787,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
with torch.inference_mode():
image = self.decode_latents(result_latents)
output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
+ self.control_forward = None
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def non_noised_latents_from_image(self, init_image, *, device, dtype):
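
The core of this file's changes is the controller() closure: for each enabled hint it builds a conditioning image with controlnet_hinter, loads the matching ControlNetModel, and returns a controller_forward() that sums the per-hint down-/mid-block residuals, which _unet_forward() then splats into the UNet call. Below is a standalone sketch of that residual-merging idea against the public diffusers API (assuming a diffusers release that ships ControlNetModel, roughly 0.14+); model loading and hint preprocessing are omitted, and the function name here is illustrative, not part of the patch:

    import torch
    from diffusers import ControlNetModel

    def merged_controlnet_residuals(
        controlnets: list[ControlNetModel],
        hint_images: list[torch.Tensor],
        latents: torch.Tensor,
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
    ) -> dict:
        """Run each ControlNet on its own hint image and sum the residuals, mirroring controller_forward()."""
        down_res, mid_res = None, None
        for controlnet, cond in zip(controlnets, hint_images):
            dr, mr = controlnet(
                latents, t,
                encoder_hidden_states=text_embeddings,
                controlnet_cond=cond,
                return_dict=False,
            )
            down_res = dr if down_res is None else [a + b for a, b in zip(down_res, dr)]
            mid_res = mr if mid_res is None else mid_res + mr
        return {
            "down_block_additional_residuals": down_res,
            "mid_block_additional_residual": mid_res,
        }

    # The UNet step then consumes the merged residuals:
    # noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings,
    #                   **merged_controlnet_residuals(...)).sample
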
diff --git a/backend/ldm/invoke/generator/img2img.py b/backend/ldm/invoke/generator/img2img.py
index fedf6d3a..ee6d1379 100644
--- a/backend/ldm/invoke/generator/img2img.py
+++ b/backend/ldm/invoke/generator/img2img.py
@@ -8,6 +8,7 @@ from diffusers import logging
from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from diffusers.utils import load_image
class Img2Img(Generator):
@@ -17,7 +18,7 @@ class Img2Img(Generator):
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0,
- attention_maps_callback=None,
+ attention_maps_callback=None, control=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -28,6 +29,7 @@ class Img2Img(Generator):
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
+ pipeline.control_forward = None
uc, c, extra_conditioning_info = conditioning
conditioning_data = (
@@ -42,11 +44,24 @@ class Img2Img(Generator):
# We're not at the moment because the pipeline automatically resizes init_image if
# necessary, which the x_T input might not match.
logging.set_verbosity_error() # quench safety check warnings
- pipeline_output = pipeline.img2img_from_embeddings(
- init_image, strength, steps, conditioning_data,
- noise_func=self.get_noise_like,
- callback=step_callback
- )
+ if strength == 1.0:
+ pipeline_output = pipeline.image_from_embeddings(
+ latents=torch.zeros_like(x_T, dtype=self.torch_dtype()),
+ noise=x_T,
+ noise_func=self.get_noise_like,
+ num_inference_steps=steps,
+ conditioning_data=conditioning_data,
+ callback=step_callback,
+ init_image=init_image,
+ control=control,
+ )
+ else:
+ pipeline_output = pipeline.img2img_from_embeddings(
+ init_image, strength, steps, conditioning_data,
+ noise_func=self.get_noise_like,
+ callback=step_callback,
+ control=control
+ )
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
attention_maps_callback(pipeline_output.attention_map_saver)
return pipeline.numpy_to_pil(pipeline_output.images)[0]
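
The new branch in Img2Img.get_make_image treats strength == 1.0 as "control hint only": the init image never enters the starting latents, it only feeds the ControlNet hint, and denoising starts from pure noise via image_from_embeddings. A simplified restatement of that dispatch (not a drop-in replacement; noise_func and other argument plumbing are folded into **common):

    import torch

    def dispatch_img2img(pipeline, init_image, x_T, strength, control, **common):
        if strength == 1.0:
            # init image only supplies the ControlNet hint; latents start as pure noise
            return pipeline.image_from_embeddings(
                latents=torch.zeros_like(x_T), noise=x_T,
                init_image=init_image, control=control, **common)
        # otherwise: ordinary img2img blending, with optional control applied on top
        return pipeline.img2img_from_embeddings(
            init_image, strength, control=control, **common)
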
diff --git a/backend/ldm/invoke/generator/inpaint.py b/backend/ldm/invoke/generator/inpaint.py
index dd43d3cd..f0db3e26 100644
--- a/backend/ldm/invoke/generator/inpaint.py
+++ b/backend/ldm/invoke/generator/inpaint.py
@@ -11,6 +11,7 @@ import numpy as np
import torch
from PIL import Image, ImageFilter, ImageOps, ImageChops
+from diffusers import StableDiffusionInpaintPipeline
from ldm.invoke.generator.diffusers_pipeline import image_resized_to_grid_as_tensor, StableDiffusionGeneratorPipeline, \
ConditioningData
from ldm.invoke.generator.img2img import Img2Img
@@ -183,6 +184,7 @@ class Inpaint(Img2Img):
inpaint_width=None,
inpaint_height=None,
attention_maps_callback=None,
+ control=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and
@@ -244,6 +246,9 @@ class Inpaint(Img2Img):
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
+ pipeline.control_forward = None
+
+ #pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", torch_dtype=pipeline.unet.dtype)
# todo: support cross-attention control
uc, c, _ = conditioning
@@ -252,14 +257,20 @@ class Inpaint(Img2Img):
def make_image(x_T):
+ #orig_unet = pipeline.unet
+ #pipeline.unet = pipe_inpaint.unet.to(pipeline.device)
+ #pipeline.unet.in_channels = 4
+
pipeline_output = pipeline.inpaint_from_embeddings(
init_image=init_image,
+ pil_init_image=self.pil_image,
mask=1 - mask, # expects white means "paint here."
strength=strength,
num_inference_steps=steps,
conditioning_data=conditioning_data,
noise_func=self.get_noise_like,
callback=step_callback,
+ control=control,
)
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
@@ -289,6 +300,7 @@ class Inpaint(Img2Img):
infill_method = infill_method,
**kwargs)
+ #pipeline.unet = orig_unet
return result
return make_image
diff --git a/backend/ldm/invoke/generator/txt2img.py b/backend/ldm/invoke/generator/txt2img.py
index 38b4415c..d3f4045f 100644
--- a/backend/ldm/invoke/generator/txt2img.py
+++ b/backend/ldm/invoke/generator/txt2img.py
@@ -16,7 +16,7 @@ class Txt2Img(Generator):
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0,
- attention_maps_callback=None,
+ attention_maps_callback=None, control=None,
**kwargs):
"""
Returns a function returning an image derived from the prompt and the initial image
@@ -28,6 +28,7 @@ class Txt2Img(Generator):
# noinspection PyTypeChecker
pipeline: StableDiffusionGeneratorPipeline = self.model
pipeline.scheduler = sampler
+ pipeline.control_forward = None
uc, c, extra_conditioning_info = conditioning
conditioning_data = (
diff --git a/backend/ldm/invoke/model_manager.py b/backend/ldm/invoke/model_manager.py
index 2f68df73..9d324e4a 100644
--- a/backend/ldm/invoke/model_manager.py
+++ b/backend/ldm/invoke/model_manager.py
@@ -422,7 +422,7 @@ class ModelManager(object):
return model, width, height, model_hash
- def _load_diffusers_model(self, mconfig):
+ def _load_diffusers_model(self, mconfig, control=True):
name_or_path = self.model_name_or_path(mconfig)
using_fp16 = self.precision == 'float16'
@@ -437,6 +437,10 @@ class ModelManager(object):
safety_checker=None,
local_files_only=not Globals.internet_available
)
+ #if control:
+ # controlnets = StableDiffusionGeneratorPipeline.controlnet_models()
+ # pipeline_args.update(controlnets=controlnets)
+
if 'vae' in mconfig and mconfig['vae'] is not None:
vae = self._load_vae(mconfig['vae'])
pipeline_args.update(vae=vae)
diff --git a/backend/ldm/models/diffusion/ddpm.py b/backend/ldm/models/diffusion/ddpm.py
index 7c7ba9f5..422f61a7 100644
--- a/backend/ldm/models/diffusion/ddpm.py
+++ b/backend/ldm/models/diffusion/ddpm.py
@@ -1013,7 +1013,7 @@ class LatentDiffusion(DDPM):
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
- # import pudb; pudb.set_trace()
+ # import pudb; pudb.set_trace() #
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
--
2.39.1