#
# For licensing see accompanying LICENSE.md file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
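"""Generate Stable Diffusion images with Core ML, reading prompts and
variation counts from `inputs.csv` (one `prompt,variations` row per line).

Example invocation (script and directory names are illustrative):

    python pipeline.py -i ./mlpackages -o ./outputs --compute-unit ALL
"""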
|
|
|
import argparse
import csv
import gc
import inspect
import logging
import os
import random
from typing import List, Optional, Union

import numpy as np
import torch  # Only used for `torch.from_numpy` in `pipe.scheduler.step()`
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.schedulers.scheduling_utils import SchedulerMixin

from python_coreml_stable_diffusion.coreml_model import (
    CoreMLModel,
    _load_mlpackage,
    get_available_compute_units,
)

logging.basicConfig()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
class GenerationModel:
    """ Carries the prompt and generation settings for a single CSV row """

    def __init__(self):
        self.prompt = ""
        self.variations = 1
        self.seed = 0
        self.o = None  # output directory
        self.compute_unit = None
        self.model_version = None
        self.scheduler = None
        self.num_inference_steps = 0
|
|
|
class CoreMLStableDiffusionPipeline(DiffusionPipeline):
    """ Core ML version of
    `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline`
    """

    def __init__(
        self,
        text_encoder: CoreMLModel,
        unet: CoreMLModel,
        vae_decoder: CoreMLModel,
        feature_extractor: CLIPFeatureExtractor,
        safety_checker: Optional[CoreMLModel],
        scheduler: Union[DDIMScheduler,
                         DPMSolverMultistepScheduler,
                         EulerAncestralDiscreteScheduler,
                         EulerDiscreteScheduler,
                         LMSDiscreteScheduler,
                         PNDMScheduler],
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        # Register non-Core ML components of the pipeline similar to the original pipeline
        self.register_modules(
            tokenizer=tokenizer,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        if safety_checker is None:
            # Reproduce original warning:
            # https://github.com/huggingface/diffusers/blob/v0.9.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L119
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Register Core ML components of the pipeline
        # Note: the safety checker is kept as an attribute even though this
        # modified pipeline skips it in `__call__` (see `run_safety_checker`)
        self.safety_checker = safety_checker
        self.text_encoder = text_encoder
        self.unet = unet
        self.unet.in_channels = self.unet.expected_inputs["sample"]["shape"][1]

        self.vae_decoder = vae_decoder

        VAE_DECODER_UPSAMPLE_FACTOR = 8

        # In PyTorch, users can determine the tensor shapes dynamically by default.
        # In Core ML, tensors have static shapes unless flexible shapes were used during export.
        # See https://coremltools.readme.io/docs/flexible-inputs
        latent_h, latent_w = self.unet.expected_inputs["sample"]["shape"][2:]
        self.height = latent_h * VAE_DECODER_UPSAMPLE_FACTOR
        self.width = latent_w * VAE_DECODER_UPSAMPLE_FACTOR

        logger.info(
            f"Stable Diffusion configured to generate {self.height}x{self.width} images"
        )
|
|
|
    def _encode_prompt(self, prompt, num_images_per_prompt,
                       do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="np",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(
                text_input_ids[:, self.tokenizer.model_max_length:])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}")
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        text_embeddings = self.text_encoder(
            input_ids=text_input_ids.astype(np.float32))["last_hidden_state"]

        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}.")
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt] * batch_size
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`.")
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="np",
            )

            uncond_embeddings = self.text_encoder(
                input_ids=uncond_input.input_ids.astype(
                    np.float32))["last_hidden_state"]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = np.concatenate(
                [uncond_embeddings, text_embeddings])

        # (batch, seq, hidden) -> (batch, hidden, 1, seq) as expected by the Core ML unet
        text_embeddings = text_embeddings.transpose(0, 2, 1)[:, :, None, :]

        return text_embeddings
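
    # Note: with classifier-free guidance enabled, `_encode_prompt` returns a
    # batch of two embeddings (unconditional + text). For a Stable Diffusion
    # v1.x model this would be shape (2, 768, 1, 77) -- illustrative values,
    # assuming CLIP's 768-dim hidden states and 77-token context length.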
|
|
|
    def run_safety_checker(self, image):
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(
                self.numpy_to_pil(image),
                return_tensors="np",
            )

            safety_checker_outputs = self.safety_checker(
                clip_input=safety_checker_input.pixel_values.astype(
                    np.float16),
                images=image.astype(np.float16),
                adjustment=np.array([0.]).astype(
                    np.float16),  # defaults to 0 in original pipeline
            )

            # Unpack dict
            has_nsfw_concept = safety_checker_outputs["has_nsfw_concepts"]
            image = safety_checker_outputs["filtered_images"]
            concept_scores = safety_checker_outputs["concept_scores"]

            logger.info(
                f"Generated image has nsfw concept={has_nsfw_concept.any()}")
        else:
            has_nsfw_concept = None

        return image, has_nsfw_concept
|
|
|
    def decode_latents(self, latents):
        # 0.18215 is the latent scaling factor the VAE was trained with
        latents = 1 / 0.18215 * latents
        image = self.vae_decoder(z=latents.astype(np.float16))["image"]
        # Map from [-1, 1] to [0, 1] and move channels last: (B, C, H, W) -> (B, H, W, C)
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))

        return image
|
|
|
    def prepare_latents(self,
                        batch_size,
                        num_channels_latents,
                        height,
                        width,
                        latents=None):
        # The Core ML models were exported with static shapes, so the latent
        # resolution is derived from the pipeline's fixed output size rather
        # than the `height` and `width` arguments (kept for API parity)
        latents_shape = (batch_size, num_channels_latents, self.height // 8,
                         self.width // 8)
        if latents is None:
            latents = np.random.randn(*latents_shape).astype(np.float16)
        elif latents.shape != latents_shape:
            raise ValueError(
                f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}"
            )

        # Scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        return latents
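
    # Note: for a 512x512 output with the standard 4 latent channels, the
    # initial latents above have shape (1, 4, 64, 64) -- illustrative values.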
|
|
|
    def check_inputs(self, prompt, height, width, callback_steps):
        if height != self.height or width != self.width:
            logger.warning(
                "`height` and `width` dimensions (of the output image tensor) are fixed when exporting the Core ML models "
                "unless flexible shapes are used during export (https://coremltools.readme.io/docs/flexible-inputs). "
                f"This pipeline was provided with Core ML models that generate {self.height}x{self.width} images (user requested {height}x{width})"
            )

        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}.")
|
|
|
    def prepare_extra_step_kwargs(self, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be in [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        return extra_step_kwargs
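
    # e.g. with DDIMScheduler, prepare_extra_step_kwargs(0.0) returns {"eta": 0.0};
    # with PNDMScheduler, whose `step` takes no `eta` parameter, it returns {}.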
|
|
|
    def __call__(
        self,
        prompt,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        negative_prompt=None,
        num_images_per_prompt=1,
        eta=0.0,
        latents=None,
        output_type="pil",
        return_dict=True,
        callback=None,
        callback_steps=1,
        **kwargs,
    ):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        if batch_size > 1 or num_images_per_prompt > 1:
            raise NotImplementedError(
                "For batched generation of multiple images and/or multiple prompts, please refer to the Swift package."
            )

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            latents,
        )

        # 6. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = np.concatenate(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # predict the noise residual (one timestep entry per batch element)
            noise_pred = self.unet(
                sample=latent_model_input.astype(np.float16),
                timestep=np.array([t] * latent_model_input.shape[0],
                                  np.float16),
                encoder_hidden_states=text_embeddings.astype(np.float16),
            )["noise_pred"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                torch.from_numpy(noise_pred),
                t,
                torch.from_numpy(latents),
                **extra_step_kwargs,
            ).prev_sample.numpy()

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Safety checking is skipped in this modified pipeline
        has_nsfw_concept = False
        # image, has_nsfw_concept = self.run_safety_checker(image)

        # 10. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept)
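
# Minimal usage sketch (illustrative; assumes `pipe` was built with
# `get_coreml_pipe` below):
#
#   output = pipe("an astronaut riding a horse",
#                 num_inference_steps=50, guidance_scale=7.5)
#   output.images[0].save("astronaut.png")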
|
|
|
|
|
def get_available_schedulers():
    """ Map short scheduler names (e.g. "DDIM") to their diffusers classes """
    schedulers = {}
    for scheduler in [DDIMScheduler,
                      DPMSolverMultistepScheduler,
                      EulerAncestralDiscreteScheduler,
                      EulerDiscreteScheduler,
                      LMSDiscreteScheduler,
                      PNDMScheduler]:
        # Use the class name directly instead of instantiating the scheduler
        schedulers[scheduler.__name__.replace("Scheduler", "")] = scheduler
    return schedulers


SCHEDULER_MAP = get_available_schedulers()
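# e.g. SCHEDULER_MAP["DPMSolverMultistep"] -> DPMSolverMultistepScheduler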
|
|
|
def get_coreml_pipe(pytorch_pipe,
                    mlpackages_dir,
                    model_version,
                    compute_unit,
                    delete_original_pipe=True,
                    scheduler_override=None):
    """ Initializes and returns a `CoreMLStableDiffusionPipeline` from an original
    diffusers PyTorch pipeline
    """
    # Ensure `scheduler_override` object is of correct type if specified
    if scheduler_override is not None:
        assert isinstance(scheduler_override, SchedulerMixin)
        logger.warning(
            "Overriding scheduler in pipeline: "
            f"Default={pytorch_pipe.scheduler}, Override={scheduler_override}")

    # Gather configured tokenizer and scheduler attributes from the original pipe
    coreml_pipe_kwargs = {
        "tokenizer": pytorch_pipe.tokenizer,
        "scheduler": pytorch_pipe.scheduler if scheduler_override is None else scheduler_override,
        "feature_extractor": pytorch_pipe.feature_extractor,
    }

    model_names_to_load = ["text_encoder", "unet", "vae_decoder"]
    if getattr(pytorch_pipe, "safety_checker", None) is not None:
        model_names_to_load.append("safety_checker")
    else:
        logger.warning(
            f"Original diffusers pipeline for {model_version} does not have a safety_checker; "
            "the Core ML pipeline will mirror this behavior.")
        coreml_pipe_kwargs["safety_checker"] = None

    if delete_original_pipe:
        del pytorch_pipe
        gc.collect()
        logger.info("Removed PyTorch pipe to reduce peak memory consumption")

    # Load Core ML models
    logger.info(f"Loading Core ML models in memory from {mlpackages_dir}")
    coreml_pipe_kwargs.update({
        model_name: _load_mlpackage(
            model_name,
            mlpackages_dir,
            model_version,
            compute_unit,
        )
        for model_name in model_names_to_load
    })
    logger.info("Done.")

    logger.info("Initializing Core ML pipe for image generation")
    coreml_pipe = CoreMLStableDiffusionPipeline(**coreml_pipe_kwargs)
    logger.info("Done.")

    return coreml_pipe
|
|
|
|
|
def get_image_path(args, **override_kwargs):
    """ mkdir the output folder and encode generation metadata in the filename
    """
    out_folder = os.path.join(
        args.o, "_".join(args.prompt.replace("/", "_").rsplit(" ")))
    os.makedirs(out_folder, exist_ok=True)

    out_fname = f"randomSeed_{override_kwargs.get('seed', None) or args.seed}"
    out_fname += f"_computeUnit_{override_kwargs.get('compute_unit', None) or args.compute_unit}"
    out_fname += f"_modelVersion_{override_kwargs.get('model_version', None) or args.model_version.replace('/', '_')}"

    if args.scheduler is not None:
        out_fname += f"_customScheduler_{override_kwargs.get('scheduler', None) or args.scheduler}"
    out_fname += f"_numInferenceSteps{override_kwargs.get('num_inference_steps', None) or args.num_inference_steps}"

    return os.path.join(out_folder, out_fname + ".png")
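
# Illustrative result: for prompt "astronaut riding a horse", seed 1234,
# compute unit ALL and the default model version, get_image_path produces
#   <args.o>/astronaut_riding_a_horse/randomSeed_1234_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4_numInferenceSteps50.png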
|
|
|
def getPromptGenerationModels(inputsFilename):
    """ Read the prompts CSV and return one GenerationModel per row """
    genModels = []
    with open(inputsFilename, 'r') as file:
        reader = csv.reader(file)
        for row in reader:
            genModel = GenerationModel()
            genModel.prompt = row[0]
            variations = int(row[1])
            if variations > 0:
                genModel.variations = variations
            genModels.append(genModel)

    return genModels
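
# A matching inputs.csv might look like this (illustrative):
#   an astronaut riding a horse,4
#   a watercolor painting of a fox,2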
|
|
|
def main(args):
    logger.info("Initializing PyTorch pipe for reference configuration")
    from diffusers import StableDiffusionPipeline
    pytorch_pipe = StableDiffusionPipeline.from_pretrained(args.model_version,
                                                           use_auth_token=True)

    user_specified_scheduler = None
    if args.scheduler is not None:
        user_specified_scheduler = SCHEDULER_MAP[
            args.scheduler].from_config(pytorch_pipe.scheduler.config)

    # Loading the models into memory takes the most time, so it is done only
    # once for all images
    coreml_pipe = get_coreml_pipe(pytorch_pipe=pytorch_pipe,
                                  mlpackages_dir=args.i,
                                  model_version=args.model_version,
                                  compute_unit=args.compute_unit,
                                  scheduler_override=user_specified_scheduler)

    # Prompts are read from a hardcoded CSV file in the working directory
    genModels = getPromptGenerationModels("inputs.csv")

    for genModel in genModels:
        logger.info(f"Beginning image generation for prompt: {genModel.prompt}")
        for variation in range(genModel.variations):
            logger.info(f"Variation #{variation}")

            # Draw a fresh random seed for every variation; the seed is
            # recorded in the output filename (the --seed CLI argument is
            # not used here)
            randomSeed = random.randint(0, 1000000000)
            np.random.seed(randomSeed)

            # Record the generation settings so get_image_path can encode
            # them in the output filename
            genModel.seed = randomSeed
            genModel.o = args.o
            genModel.compute_unit = args.compute_unit
            genModel.model_version = args.model_version
            genModel.scheduler = args.scheduler
            genModel.num_inference_steps = args.num_inference_steps

            image = coreml_pipe(
                prompt=genModel.prompt,
                height=coreml_pipe.height,
                width=coreml_pipe.width,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
            )

            out_path = get_image_path(genModel)
            logger.info(f"Saving generated image to {out_path}")
            image["images"][0].save(out_path)
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        required=False,
        help=("The text prompt to be used for text-to-image generation. "
              "Note: this script reads its prompts from inputs.csv instead."))
    parser.add_argument(
        "-i",
        required=True,
        help=("Path to input directory with the .mlpackage files generated by "
              "python_coreml_stable_diffusion.torch2coreml"))
    parser.add_argument(
        "-o",
        required=True,
        help="Path to the output directory for generated images")
    parser.add_argument("--seed",
                        "-s",
                        default=93,
                        type=int,
                        help="Random seed to be able to reproduce results")
    parser.add_argument(
        "--model-version",
        default="CompVis/stable-diffusion-v1-4",
        help=("The pre-trained model checkpoint and configuration to restore. "
              "For available versions: https://huggingface.co/models?search=stable-diffusion"
              ))
    parser.add_argument(
        "--compute-unit",
        choices=get_available_compute_units(),
        default="ALL",
        help=("The compute units to be used when executing Core ML models. "
              f"Options: {get_available_compute_units()}"))
    parser.add_argument(
        "--scheduler",
        choices=tuple(SCHEDULER_MAP.keys()),
        default=None,
        help=("The scheduler to use for running the reverse diffusion process. "
              "If not specified, the default scheduler from the diffusers pipeline is used"))
    parser.add_argument(
        "--num-inference-steps",
        default=50,
        type=int,
        help="The number of iterations the unet model will be executed throughout the reverse diffusion process")
    parser.add_argument(
        "--guidance-scale",
        default=7.5,
        type=float,
        help="Controls the influence of the text prompt on the sampling process (0 = random images)")

    args = parser.parse_args()
    main(args)