Last active
November 8, 2023 15:41
-
-
Save amenabe22/118777b6354c4f980d5e7eab6834beb0 to your computer and use it in GitHub Desktop.
Code to analyze similarity in input images
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import numpy as np | |
from PIL import Image | |
from skimage import io | |
from brisque import BRISQUE | |
from skimage.transform import resize | |
from skimage.metrics import structural_similarity as ssim | |
from deepface_helper import * | |
# def load_images(image_folder, target_size=(256, 256)): | |
def crop_single_image(input, output, filename):
    """Detect the main face in one image and write a cropped copy.

    input: path of the source image; output: destination directory;
    filename: name the cropped file gets inside *output*.
    """
    # Face detection + alignment via DeepFace (backend list comes from
    # deepface_helper); backends[1] selects the configured detector.
    detected_faces = DeepFace.extract_faces(
        img_path=input, detector_backend=backends[1])
    # Pick the dominant face and crop around its bounding box with a
    # 0.7 margin factor.
    primary = main_face_index(detected_faces)
    region = detected_faces[primary]['facial_area']
    crop_image(input, os.path.join(output, filename), region, 0.7)
def crop_out_images(image_folder, output):
    """Crop the main face out of every image in *image_folder*.

    Each cropped file is written into the *output* directory under its
    original filename.  Returns *output* so the caller can chain it
    into load_images().
    """
    # A plain for-loop, not a list comprehension: cropping is done for
    # its side effect, so building a throwaway list of None was wasteful.
    for image_file in os.listdir(image_folder):
        crop_single_image(os.path.join(image_folder, image_file),
                          output, image_file)
    return output
def load_images(image_folder, target_size=(256, 256)):
    """Read every file in *image_folder* and resize it to *target_size*.

    Returns (images, image_files): the resized arrays and the matching
    filenames, both in os.listdir() order so indices line up.
    """
    image_files = os.listdir(image_folder)
    images = []
    for name in image_files:
        raw = io.imread(os.path.join(image_folder, name))
        # skimage.transform.resize also rescales pixel values to [0, 1].
        images.append(resize(raw, target_size))
    return images, image_files
def calculate_similarity(images, win_size=7, channel_axis=2):
    """Return an n x n matrix of pairwise SSIM scores between images.

    Assumes all images are same-shaped float arrays scaled to [0, 1]
    (hence data_range=1.0) with color channels on *channel_axis*.
    """
    n_images = len(images)
    similarity_matrix = np.zeros((n_images, n_images))
    # Fix: the original passed multichannel=True alongside channel_axis.
    # `multichannel` was deprecated in scikit-image 0.19 and removed in
    # 0.23 (it raises TypeError there); channel_axis fully replaces it.
    # SSIM is symmetric, so compute only the upper triangle and mirror.
    for i in range(n_images):
        for j in range(i, n_images):
            score = ssim(images[i], images[j], win_size=win_size,
                         data_range=1.0, channel_axis=channel_axis)
            similarity_matrix[i, j] = score
            similarity_matrix[j, i] = score
    return similarity_matrix
def group_images_by_similarity(similarity_matrix, image_files, similarity_threshold):
    """Greedily cluster images by similarity.

    Each not-yet-grouped image seeds a new group and pulls in every
    later image whose similarity to the seed exceeds
    *similarity_threshold*.  Returns a list of groups; each member is a
    dict with "filename" and "similarity_score" (the score against the
    group's seed; the seed carries its own self-similarity).
    """
    n_images = len(image_files)
    image_groups = []
    grouped = set()  # indices already assigned to some group
    for i in range(n_images):
        if i in grouped:
            continue
        # Bug fix: the seed's score was hard-coded as
        # similarity_matrix[1, 1]; it must be the seed's own diagonal
        # entry similarity_matrix[i, i].
        group = [{"filename": image_files[i],
                  "similarity_score": similarity_matrix[i, i]}]
        for j in range(i + 1, n_images):
            if similarity_matrix[i, j] > similarity_threshold:
                group.append({"filename": image_files[j],
                              "similarity_score": similarity_matrix[i, j]})
                grouped.add(j)  # mark as grouped
        image_groups.append(group)
    return image_groups
# get blind quality score of image
def get_brisque_score(input_path):
    """Return the BRISQUE no-reference quality score for the image file
    at *input_path* (image is loaded and converted to RGB first)."""
    rgb_image = Image.open(input_path).convert("RGB")
    return BRISQUE().score(rgb_image)
def main():
    """Crop faces, cluster near-duplicate images by SSIM, score quality
    with BRISQUE, and print the clusters with bad/keep labels."""
    similarity_threshold = 0.4
    # "training_images_uploads" under the CWD holds the raw input images.
    image_folder = os.path.join(os.getcwd(), "training_images_uploads")
    cropped_images = crop_out_images(image_folder, "cropped")
    images, image_files = load_images(cropped_images)
    similarity_matrix = calculate_similarity(images)
    image_groups = group_images_by_similarity(
        similarity_matrix, image_files, similarity_threshold)
    low_quality_threshold = 10.0
    bad_images = []
    # Attach a BRISQUE quality score to every image record.
    # NOTE(review): scores are computed from the ORIGINAL files in
    # image_folder, while the clusters were built from the cropped
    # copies — the filenames match, but confirm this is intended.
    for group in image_groups:
        if len(group) == 1:
            img_path = os.path.join(image_folder, group[0]["filename"])
            score = get_brisque_score(img_path)
            group[0]["quality_score"] = score
            if score < low_quality_threshold:  # adjust the threshold as needed
                bad_images.append(group[0])
        elif len(group) > 1:
            # NOTE(review): every member of a multi-image cluster is
            # appended to bad_images (not just the duplicates group[1:],
            # as the commented-out extend suggested) — confirm intent.
            for img_info in group:
                img_path = os.path.join(image_folder, img_info["filename"])
                img_info["quality_score"] = get_brisque_score(img_path)
                # Idiom fix: `x not in y` instead of `not x in y`.
                if img_info not in bad_images:
                    bad_images.append(img_info)
    # Combine all clusters; within each cluster the first (seed) image
    # is kept and the remaining, most-similar images are labeled bad.
    combined = []
    for cluster_id, group in enumerate(image_groups):
        group[0]["bad"] = False
        for img_info in group[1:]:
            img_info["bad"] = True
        combined += group
        print("Clustered")
        print("-" * 30)
        print(f"Cluster {cluster_id + 1}: {group}")
        print("-" * 30)
    print("Combined")
    print(combined)
if __name__ == "__main__":
    # Run the full crop -> cluster -> quality-score pipeline as a script.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment