Created
November 3, 2022 20:09
-
-
Save aredden/f5230a490d7ee7d4c5b00746d10dc4f1 to your computer and use it in GitHub Desktop.
CLI tool for easily cropping faces to 512x512 for ML dataset generation (textual inversion / DreamBooth).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from typing import List, Union | |
import cv2, os, pathlib | |
from PIL import Image | |
from loguru import logger | |
from more_itertools import flatten | |
from mediapipe.python.solutions.face_detection import FaceDetection | |
import numpy as np | |
from traceback import format_exception | |
def format_error(e: Exception, limit: int = 3) -> str:
    """Format an exception, including its traceback, as a single string.

    Args:
        e (Exception): Raised exception object.
        limit (int): Maximum number of call-stack entries to include in the
            formatted string.

    Returns:
        str: The formatted traceback followed by the exception message.
    """
    # Every chunk returned by format_exception already ends with a newline,
    # so the chunks are concatenated as-is; joining them with an extra
    # newline would insert a blank line between each of them.
    e_fmt = format_exception(type(e), e, e.__traceback__, limit=limit)
    return "".join(e_fmt)
## Known issues and how each one is handled
# - Image too small:
#     solution: drop the image.
# - No detections:
#     solution: drop the image.
# - Image too large:
#     solution: downscale, then process.
# - Image has many faces:
#     solution (tentative): create a separate crop box for each face?
def process_image_at_path(image_path):
    """Load the image at ``image_path`` and return its face crops.

    The image is read with OpenCV, converted from BGR to RGB, shrunk in
    steps of 0.7x while its shortest side exceeds 768 px (only when that
    side started above 1024 px), and then handed to ``process_image``.

    Args:
        image_path: Path of the image file to load.

    Returns:
        list: The crops returned by ``process_image`` (one per detected
        face). The list is empty when the file does not exist, cannot be
        decoded, has a shortest side below 512 px, fails during
        preprocessing, or contains no detectable face.
    """
    if not os.path.exists(image_path):
        logger.warning(f"Image at path: {image_path} does not exist.")
        return []

    try:
        img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports an unreadable or corrupt file by returning
            # None instead of raising, so check for it explicitly.
            logger.warning(f"Image at path: {image_path} could not be read.")
            return []
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        shortest_side = min(img.shape[:2])
        if shortest_side < 512:
            # A 512x512 crop cannot fit, so the image is dropped rather than
            # passed on to produce an undersized or wrapped-around crop.
            logger.warning(f"Image is too small to be processed ({img.shape}): {image_path}")
            return []
        if shortest_side > 1024:
            logger.warning(
                f"Image is very large, so we are shrinking it before processing it ({img.shape}): {image_path}"
            )
            while min(img.shape[:2]) > 768:
                img = cv2.resize(img, (0, 0), fx=0.7, fy=0.7, interpolation=cv2.INTER_LANCZOS4)
                logger.info(f"Image shape is now {img.shape}")
            # Internal invariant: a side above 768 px scaled by 0.7 stays above 512 px.
            assert min(img.shape[:2]) >= 512, f"Accidentally shrunk image too much ({img.shape})"
    except Exception as e:
        logger.error(f"Issue: {format_error(e)}")
        return []

    # Detection runs outside the try block, so its errors propagate to the caller.
    result = process_image(img, path=image_path)
    if not result:
        logger.warning(f"No faces found for image {image_path}")
        return []
    if len(result) > 1:
        logger.warning(f"Multiple faces detected in the image: {image_path}")
    else:
        logger.success(f"Successfully processed image at path: {image_path}")
    return result
def process_image(img, path: Union[str, None] = None, crop_size: int = 512) -> List[np.ndarray]:
    """Detect faces in ``img`` and return one square crop centred on each.

    Faces are located with MediaPipe's ``FaceDetection`` using a minimum
    detection confidence of 0.1. For every detection a
    ``crop_size`` x ``crop_size`` window is centred on the bounding box and
    then shifted as needed so that it stays inside the image.

    Args:
        img: Image array whose first two dimensions are (height, width).
        path: Path the image came from; used only in the log message.
        crop_size: Side length, in pixels, of each crop.

    Returns:
        List[np.ndarray]: One crop per detection, or an empty list when no
        face is found. Along any axis where the image is smaller than
        ``crop_size`` the crop spans the whole axis instead.
    """
    h, w = img.shape[:2]
    with FaceDetection(min_detection_confidence=0.1) as fd:
        outputs = fd.process(img).detections
    if not outputs:
        logger.warning(f"No faces found for this image: {path}")
        return []

    half = crop_size // 2
    # Largest top/left offsets that still leave room for a full-size window.
    max_top = max(h - crop_size, 0)
    max_left = max(w - crop_size, 0)

    founds = []
    for det in outputs:
        bbox = det.location_data.relative_bounding_box
        # Scale the relative box to pixels in local variables so the
        # detection message itself is left unmodified.
        x_center = int(bbox.xmin * w + (bbox.width * w) / 2)
        y_center = int(bbox.ymin * h + (bbox.height * h) / 2)

        # Clamp the window's top-left corner into the valid range; this is
        # the closed form of nudging the window one pixel at a time.
        top = min(max(y_center - half, 0), max_top)
        left = min(max(x_center - half, 0), max_left)

        founds.append(img[top:top + crop_size, left:left + crop_size])
    return founds
def read_images_from_dir(dir, recursive=True):
    """Yield the paths of the image files found in ``dir``.

    A file counts as an image when its suffix is ``.jpg``, ``.jpeg`` or
    ``.png``, compared case-insensitively so that e.g. ``photo.JPG`` is
    included. Directories and symlinks are skipped.

    Args:
        dir: Directory to scan, as a ``str`` or ``pathlib.Path``.
        recursive: When true, sub-directories are scanned as well.

    Yields:
        str: Path of each matching image file.
    """
    root = dir if isinstance(dir, pathlib.Path) else pathlib.Path(dir)
    image_suffixes = (".jpg", ".jpeg", ".png")
    candidates = root.glob("**/*") if recursive else root.glob("*")
    for candidate in candidates:
        if candidate.is_dir() or candidate.is_symlink():
            continue
        if candidate.suffix.lower() in image_suffixes:
            yield str(candidate)
def process_dir(dir, recursive=True):
    """Lazily yield every face crop produced from the images in ``dir``.

    Each image path reported by ``read_images_from_dir`` is passed to
    ``process_image_at_path``, and the crops it returns are yielded one at
    a time before the next image is touched.
    """
    for image_path in read_images_from_dir(dir, recursive=recursive):
        for crop in process_image_at_path(image_path):
            yield crop
# Command-line entry point: crop every face found in the images under --path
# and write the crops to --output as numbered image files.
argp = argparse.ArgumentParser(description="Face detection and cropping utility for ML dataset curation.")
# --path is mandatory; without it there is nothing to scan.
argp.add_argument("-p", "--path", required=True, help="image directory")
argp.add_argument("-r", "--recursive", default=False, action="store_true", help="Whether to recurse through sub-folders.")
argp.add_argument("-o", "--output", default="output", help="Path to output directory, is created if it does not exist.")
argp.add_argument("-os", "--output_suffix", default=".jpg", help="Output file suffix, (.jpg, .png, .jpeg)")
args = argp.parse_args()

if not os.path.exists(args.path):
    logger.error(f"Path: {args.path} does not exist!")
    raise SystemExit(1)
logger.info(f"Path: {args.path} exists!")

if not os.path.exists(args.output):
    logger.info("Output directory does not exist, making it now.")
    os.makedirs(args.output)
elif not os.path.isdir(args.output):
    logger.error(f"Output path exists but is not a directory! (Path: {args.output})")
    raise SystemExit(1)
else:
    logger.info(f"Output path: {args.output} exists, will write images there.")

# Accept the suffix with or without its leading dot so the file name never
# ends up with a doubled dot such as "image-0..jpg".
output_suffix = args.output_suffix if args.output_suffix.startswith(".") else f".{args.output_suffix}"
output_dir = pathlib.Path(args.output)

images = process_dir(args.path, recursive=args.recursive)
for i, img in enumerate(images):
    if img is not None:
        im = Image.fromarray(img)
        im.save(output_dir / f"image-{i}{output_suffix}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment