Skip to content

Instantly share code, notes, and snippets.

@aredden
Created November 3, 2022 20:09
Show Gist options
  • Save aredden/f5230a490d7ee7d4c5b00746d10dc4f1 to your computer and use it in GitHub Desktop.
Save aredden/f5230a490d7ee7d4c5b00746d10dc4f1 to your computer and use it in GitHub Desktop.
CLI tool for cropping faces to 512x512 images for ML dataset generation (Textual Inversion / DreamBooth)
import argparse
from typing import List, Union
import cv2, os, pathlib
from PIL import Image
from loguru import logger
from more_itertools import flatten
from mediapipe.python.solutions.face_detection import FaceDetection
import numpy as np
from traceback import format_exception
def format_error(e: Exception, limit: int = 3) -> str:
    """Format an exception and its traceback as a single string.

    Args:
        e (Exception): Raised exception object.
        limit (int): Maximum number of stack-trace entries to include.

    Returns:
        str: The formatted traceback, in the same layout the interpreter prints.
    """
    # Each string returned by format_exception already ends in a newline (and
    # some contain internal newlines), so join with "" rather than "\n" to avoid
    # inserting blank lines between traceback entries.
    return "".join(format_exception(type(e), e, tb=e.__traceback__, limit=limit))
## Known input issues and how they are handled
# Image too small (shorter side < 512 px)
#   solution: drop the image
# No face detections
#   solution: drop the image
# Image too large (shorter side > 1024 px)
#   solution: downscale, then process
# Image has many faces
#   solution: emit a separate crop for each detected face
def process_image_at_path(image_path):
    """Load the image at ``image_path`` and return the face crops found in it.

    Images whose shorter side is below 512 px are dropped. If the shorter side
    exceeds 1024 px, the image is repeatedly downscaled by a factor of 0.7
    until that side is at most 768 px, and then detection runs.

    Args:
        image_path (str): Path to the image file.

    Returns:
        list: Face crops produced by ``process_image``. The list is empty when
        the file is missing, unreadable, too small, or contains no detected faces.
    """
    if not os.path.exists(image_path):
        logger.warning(f"Image at path: {image_path} does not exist.")
        return []
    try:
        img = cv2.imread(image_path, flags=cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None (instead of raising) for unreadable files.
            raise ValueError(f"cv2 could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if min(img.shape[:2]) < 512:
            # A 512x512 crop cannot be cut from this image, so drop it instead
            # of handing it to process_image.
            logger.warning(f"Image is too small to be processed ({img.shape}): {image_path}")
            return []
        if min(img.shape[:2]) > 1024:
            logger.warning(
                f"Image is very large, so we are shrinking it before processing it ({img.shape}): {image_path}"
            )
            while min(img.shape[:2]) > 768:
                img = cv2.resize(img, (0, 0), fx=0.7, fy=0.7, interpolation=cv2.INTER_LANCZOS4)
            logger.info(f"Image shape is now {img.shape}")
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if min(img.shape[:2]) < 512:
                raise ValueError(f"Accidentally shrunk image too much ({img.shape})")
    except Exception as e:
        # Best-effort per image: log and skip it so one bad file does not abort the run.
        logger.error(f"Issue: {format_error(e)}")
        return []

    result = process_image(img, path=image_path)
    if not result:
        logger.warning(f"No faces found for image {image_path}")
        return []
    if len(result) > 1:
        logger.warning(f"Multiple faces detected in the image: {image_path}")
    else:
        logger.success(f"Successfully processed image at path: {image_path}")
    return result
def _crop_start(center: int, length: int, size: int) -> int:
    """Return the start index of a ``size``-long window centred on ``center`` and kept within ``[0, length)``."""
    # Slide the window back inside the image. If the image is shorter than the
    # window, pin the start at 0 so slicing never receives a negative index
    # (a negative index would wrap around and produce a wrong crop).
    return max(0, min(center - size // 2, length - size))


def process_image(img, path: str = None, size: int = 512) -> Union[None, List[np.ndarray]]:
    """Detect faces in ``img`` and return one square crop per detected face.

    Each crop is centred on its face's bounding box. Where the centred window
    would run off the image, the window is shifted back inside the image.

    Args:
        img: Image array indexed as ``img[row, col, channel]``. It is presumably
            RGB, since the CLI caller converts from BGR first. TODO confirm for other callers.
        path: Optional source path, used only in log messages.
        size: Side length in pixels of each square crop (default 512).

    Returns:
        A list of crops, one per detection, or an empty list if no face was found.
        Along any dimension where the image is smaller than ``size``, the crop
        covers the whole image extent instead.
    """
    h, w = img.shape[:2]
    founds = []
    with FaceDetection(min_detection_confidence=0.1) as fd:
        detections = fd.process(img).detections
        if not detections:
            logger.warning(f"No faces found for this image: {path}")
            return []
        for det in detections:
            # Relative (0..1) box -> absolute pixel centre, without mutating the
            # detector's result object.
            bbox = det.location_data.relative_bounding_box
            x_center = int(bbox.xmin * w + bbox.width * w / 2)
            y_center = int(bbox.ymin * h + bbox.height * h / 2)
            top = _crop_start(y_center, h, size)
            left = _crop_start(x_center, w, size)
            founds.append(img[top:top + size, left:left + size, :])
    return founds
def read_images_from_dir(dir, recursive=True, extensions=(".jpg", ".jpeg", ".png")):
    """Yield the paths (as strings) of the image files found in ``dir``.

    Args:
        dir: Directory to scan, as a ``str`` or ``pathlib.Path``.
        recursive: If True, also scan every sub-directory.
        extensions: Accepted file suffixes, including the leading dot. They are
            compared case-insensitively, so ``photo.JPG`` matches ``.jpg``.

    Yields:
        str: The path of each matching non-directory, non-symlink entry. Paths
        are yielded in sorted order so that downstream numbering is reproducible.
    """
    root = pathlib.Path(dir)
    pattern = "**/*" if recursive else "*"
    wanted = {ext.lower() for ext in extensions}
    # glob() order depends on the filesystem; sort for deterministic output.
    for candidate in sorted(root.glob(pattern)):
        if candidate.is_dir() or candidate.is_symlink():
            continue
        if candidate.suffix.lower() in wanted:
            yield str(candidate)
def process_dir(dir, recursive=True):
    """Lazily yield every face crop from every image found under ``dir``.

    Images are handled one at a time, in the order ``read_images_from_dir``
    yields them. The crops of each image are yielded before the next image is read.
    """
    for image_path in read_images_from_dir(dir, recursive=recursive):
        yield from process_image_at_path(image_path)
argp = argparse.ArgumentParser(description="Face detection and cropping utility for ML dataset curation.")
argp.add_argument("-p", "--path", required=True, help="image directory")
argp.add_argument("-r", "--recursive", default=False, action="store_true", help="Whether to recurse through sub-folders.")
argp.add_argument("-o", "--output", default="output", help="Path to output directory, is created if it does not exist.")
argp.add_argument("-os", "--output_suffix", default=".jpg", help="Output file suffix, (.jpg, .png, .jpeg)")


def main(argv=None):
    """Crop every face found under ``--path`` and save the crops into ``--output``.

    Args:
        argv: Optional argument list. If None, argparse reads the process's
            command line.
    """
    args = argp.parse_args(argv)
    if not os.path.exists(args.path):
        logger.error(f"Path: {args.path} does not exist!")
        raise SystemExit(1)
    logger.info(f"Path: {args.path} exists!")

    output_dir = pathlib.Path(args.output)
    if not output_dir.exists():
        logger.info("Output directory does not exist, making it now.")
        output_dir.mkdir(parents=True, exist_ok=True)
    elif not output_dir.is_dir():
        logger.error(f"Output path exists but is not a directory! (Path: {args.output})")
        raise SystemExit(1)
    else:
        logger.info(f"Output path: {args.output} exists, will write images there.")

    # Accept the suffix with or without its leading dot ("jpg" or ".jpg").
    # Previously the filename template added its own dot, so the default
    # ".jpg" produced names like "image-0..jpg".
    suffix = "." + args.output_suffix.lstrip(".")
    for i, img in enumerate(process_dir(args.path, recursive=args.recursive)):
        if img is not None:
            Image.fromarray(img).save(output_dir / f"image-{i}{suffix}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment