Skip to content

Instantly share code, notes, and snippets.

@amenabe22
Last active November 8, 2023 15:41
Show Gist options
  • Save amenabe22/118777b6354c4f980d5e7eab6834beb0 to your computer and use it in GitHub Desktop.
Save amenabe22/118777b6354c4f980d5e7eab6834beb0 to your computer and use it in GitHub Desktop.
Code to group input face images by visual similarity (SSIM) and score their quality (BRISQUE)
import os
import numpy as np
from PIL import Image
from skimage import io
from brisque import BRISQUE
from skimage.transform import resize
from skimage.metrics import structural_similarity as ssim
from deepface_helper import *
def crop_single_image(input, output, filename):
    """Detect the main face in one image and save a crop of it.

    Args:
        input: Path of the source image.
        output: Directory the crop is written into.
        filename: Name of the crop file created inside ``output``.
    """
    # Face detection and alignment; ``backends`` comes from deepface_helper.
    detections = DeepFace.extract_faces(img_path=input, detector_backend=backends[1])
    # main_face_index picks which detection to keep (see deepface_helper).
    main_face = detections[main_face_index(detections)]
    destination = os.path.join(output, filename)
    # 0.7 is forwarded unchanged to crop_image; its meaning is defined there.
    crop_image(input, destination, main_face['facial_area'], 0.7)
def crop_out_images(image_folder, output):
    """Crop the main face out of every file in ``image_folder``.

    Each crop is written into ``output`` under the same filename as its
    source image.

    Args:
        image_folder: Directory containing the source images.
        output: Directory that receives the crops; created if missing.

    Returns:
        ``output``, so the call can be chained into ``load_images``.
    """
    # Make sure the destination exists before any crop is saved into it
    # (crop_image is not known to create it).
    os.makedirs(output, exist_ok=True)
    for image_file in sorted(os.listdir(image_folder)):
        image_path = os.path.join(image_folder, image_file)
        # Sub-directories cannot be images; handing them to face detection crashes.
        if not os.path.isfile(image_path):
            continue
        crop_single_image(image_path, output, image_file)
    return output
def load_images(image_folder, target_size=(256, 256)):
    """Load every image file in ``image_folder`` and resize it to ``target_size``.

    Files are processed in sorted name order so that downstream grouping is
    reproducible. Every image is normalised to three colour channels
    (see ``_as_rgb``) so all results share one shape for SSIM.

    Args:
        image_folder: Directory to read images from (sub-directories are skipped).
        target_size: ``(height, width)`` passed to ``skimage.transform.resize``.

    Returns:
        ``(images, image_files)``: the resized images and their file names,
        index-aligned.
    """
    image_files = sorted(
        name for name in os.listdir(image_folder)
        if os.path.isfile(os.path.join(image_folder, name))
    )
    images = []
    for image_file in image_files:
        image = io.imread(os.path.join(image_folder, image_file))
        # resize keeps the trailing channel axis when given a 2-tuple shape.
        images.append(resize(_as_rgb(image), target_size))
    return images, image_files


def _as_rgb(image):
    """Return ``image`` with exactly 3 channels: drop alpha, replicate gray."""
    if image.ndim == 2:
        image = image[..., np.newaxis]
    # LA (2 channels) and RGBA (4 channels): drop the trailing alpha channel.
    if image.shape[-1] in (2, 4):
        image = image[..., :-1]
    if image.shape[-1] == 1:
        image = np.repeat(image, 3, axis=-1)
    return image
def calculate_similarity(images, win_size=7, channel_axis=2, data_range=1.0):
    """Return the pairwise SSIM matrix of ``images``.

    Args:
        images: Equally shaped images. ``data_range=1.0`` matches the float
            output of ``load_images``, which resizes with skimage.
        win_size: SSIM sliding-window side length (must be odd).
        channel_axis: Axis holding the colour channels.
        data_range: Value range of the pixel data, forwarded to SSIM.

    Returns:
        ``(n, n)`` float array where entry ``[i, j]`` is SSIM(images[i], images[j]).
    """
    n_images = len(images)
    similarity_matrix = np.zeros((n_images, n_images))
    for i in range(n_images):
        # SSIM is symmetric, so compute the upper triangle (with diagonal) and mirror it.
        for j in range(i, n_images):
            # `multichannel` is deprecated in favour of `channel_axis`; pass only the latter.
            score = ssim(images[i], images[j], win_size=win_size,
                         data_range=data_range, channel_axis=channel_axis)
            similarity_matrix[i, j] = similarity_matrix[j, i] = score
    return similarity_matrix
def group_images_by_similarity(similarity_matrix, image_files, similarity_threshold):
    """Greedily cluster images whose similarity exceeds a threshold.

    Images are visited in order. Each image not yet assigned starts a new
    group, and every later, still-unassigned image whose similarity to it is
    strictly greater than ``similarity_threshold`` joins that group. Every
    image therefore ends up in exactly one group.

    Args:
        similarity_matrix: Square array indexable as ``[i, j]``
            (e.g. from ``calculate_similarity``).
        image_files: File names index-aligned with the matrix rows.
        similarity_threshold: Minimum (exclusive) similarity for grouping.

    Returns:
        List of groups. Each group is a list of
        ``{"filename", "similarity_score"}`` dicts, and its first entry is the
        image that started the group.
    """
    image_groups = []
    grouped = set()  # indices already placed in some group
    for i, anchor in enumerate(image_files):
        if i in grouped:
            continue
        grouped.add(i)
        # The anchor's score is its self-similarity (the matrix diagonal).
        group = [{"filename": anchor,
                  "similarity_score": float(similarity_matrix[i, i])}]
        for j in range(i + 1, len(image_files)):
            # Skip images already claimed by an earlier group so none is duplicated.
            if j not in grouped and similarity_matrix[i, j] > similarity_threshold:
                group.append({"filename": image_files[j],
                              "similarity_score": float(similarity_matrix[i, j])})
                grouped.add(j)
        image_groups.append(group)
    return image_groups
def get_brisque_score(input_path):
    """Return the BRISQUE no-reference ("blind") quality score of an image.

    The image is loaded with Pillow and converted to RGB before scoring.
    BRISQUE is conventionally lower-is-better. TODO: confirm this for the
    installed ``brisque`` package.

    Args:
        input_path: Path of the image file to score.
    """
    # Use a context manager so the file handle is released deterministically.
    # convert() returns an independent in-memory copy that remains usable afterwards.
    with Image.open(input_path) as img:
        rgb = img.convert("RGB")
    return BRISQUE().score(rgb)
def main(image_folder=None, output_folder="cropped", similarity_threshold=0.4):
    """Crop faces, cluster near-duplicate crops by SSIM, and print the result.

    Pipeline:
      1. Crop the main face out of every upload into ``output_folder``.
      2. Load and resize the crops, then compute their pairwise SSIM.
      3. Greedily group crops whose SSIM exceeds ``similarity_threshold``.
      4. Attach a BRISQUE ``quality_score`` to every image.
      5. In each group, keep the first image (``bad=False``) and mark the
         remaining near-duplicates ``bad=True``.

    Args:
        image_folder: Folder with the uploaded images. Defaults to
            ``training_images_uploads`` under the current working directory.
        output_folder: Folder that receives the face crops.
        similarity_threshold: SSIM above which two crops count as similar.
    """
    if image_folder is None:
        image_folder = os.path.join(os.getcwd(), "training_images_uploads")

    cropped_folder = crop_out_images(image_folder, output_folder)
    images, image_files = load_images(cropped_folder)
    similarity_matrix = calculate_similarity(images)
    image_groups = group_images_by_similarity(
        similarity_matrix, image_files, similarity_threshold)

    # Score the original (uncropped) upload behind each crop. The scores are
    # only reported and do not influence the keep/bad labels below.
    # NOTE(review): BRISQUE is conventionally lower-is-better, so any future
    # low-quality cut-off should flag scores *above* a threshold. This loop
    # also assumes every file in the crop folder has a same-named source in
    # image_folder; stale crops from earlier runs would break that.
    for group in image_groups:
        for img_info in group:
            img_path = os.path.join(image_folder, img_info["filename"])
            img_info["quality_score"] = get_brisque_score(img_path)

    combined = []
    print("Clustered")
    for cluster_id, group in enumerate(image_groups, start=1):
        # The first image of a cluster is kept; later members are near-duplicates.
        for position, img_info in enumerate(group):
            img_info["bad"] = position > 0
        combined.extend(group)
        print("-" * 30)
        print(f"Cluster {cluster_id}: {group}")
    print("-" * 30)
    print("Combined")
    print(combined)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment