Last active
June 5, 2024 01:47
-
-
Save tori29umai0123/a0776b5c40d43b901f51aa1b17957d48 to your computer and use it in GitHub Desktop.
nsfw_filter_with_tagger.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import glob | |
import os | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image | |
from tqdm import tqdm | |
import onnx | |
import onnxruntime as ort | |
from huggingface_hub import hf_hub_download | |
import shutil | |
# Edge length, in pixels, of the square array produced by preprocess_image.
IMAGE_SIZE = 448


def preprocess_image(image):
    """Convert an image into a square ``IMAGE_SIZE`` x ``IMAGE_SIZE`` float32 array.

    The input is turned into a NumPy array, its channel axis is reversed, the
    result is centred on a white square canvas, and the canvas is resized.

    Args:
        image: Anything ``np.array`` converts to a height x width x channels
            array. ``main`` passes an RGB ``PIL.Image``.

    Returns:
        The resized pixels as ``float32``. Values are only cast, not rescaled.
    """
    # Reverse the channel axis; for the RGB input that main() supplies this
    # produces BGR ordering.
    pixels = np.array(image)[:, :, ::-1]

    # Pad the shorter dimension with white (255) so the content stays centred
    # on a square canvas; any odd leftover pixel goes to the bottom/right.
    height, width = pixels.shape[:2]
    side = max(height, width)
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    pixels = np.pad(
        pixels,
        ((top, extra_rows - top), (left, extra_cols - left), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # Area interpolation when shrinking, Lanczos otherwise.
    if side > IMAGE_SIZE:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    resized = cv2.resize(pixels, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interpolation)
    return resized.astype(np.float32)
def check_if_exists(image_path, directories):
    """Report whether a file named like ``image_path`` is in any directory.

    Only the base name of ``image_path`` is compared; its parent directory is
    ignored.

    Args:
        image_path: Path whose final component is the file name to look for.
        directories: Iterable of directory paths to search.

    Returns:
        ``True`` as soon as one directory contains an entry with that name,
        ``False`` if none does (including when ``directories`` is empty).
    """
    wanted = os.path.basename(image_path)
    return any(
        os.path.exists(os.path.join(folder, wanted)) for folder in directories
    )
def run_batch(path_imgs, input_name, ort_sess, rating_tags, general_tags, thresh, nsfw_dir, sfw_dir):
    """Tag one batch of images and sort each into the NSFW or SFW folder.

    For every image the model output row is read in two parts: the first
    ``len(rating_tags)`` columns are rating scores, and the columns that
    follow are general-tag scores. An image goes to ``nsfw_dir`` when the
    larger of its "questionable" / "explicit" scores exceeds its "general"
    score, otherwise to ``sfw_dir``. A ``.txt`` file holding the
    comma-separated general tags that reach ``thresh`` is written next to the
    copied image.

    NOTE(review): this assumes the model's output columns follow the tag list
    order with all rating rows first and the general rows immediately after.
    The rating lookup already relies on ratings starting at column 0; confirm
    the rest against the model's selected_tags.csv.

    Args:
        path_imgs: List of ``(image_path, preprocessed_array)`` pairs.
        input_name: Name of the ONNX input the stacked arrays are fed to.
        ort_sess: Object with an ONNX Runtime style ``run(None, feed)`` method.
        rating_tags: Rating tag names, in model output order.
        general_tags: General tag names, in model output order.
        thresh: Minimum score for a general tag to be written out.
        nsfw_dir: Destination directory for images rated NSFW.
        sfw_dir: Destination directory for all other images.

    Returns:
        None. Copy failures are printed and do not stop the batch.
    """
    if not path_imgs:
        # Nothing to stack; avoid feeding an empty array to the model.
        return

    imgs = np.array([im for _, im in path_imgs])
    probs = ort_sess.run(None, {input_name: imgs})[0]  # ONNX output
    probs = probs[: len(path_imgs)]

    # General-tag scores start right after the rating columns. Indexing them
    # from column 0 would compare general tags against rating scores.
    general_offset = len(rating_tags)

    for (image_path, _), prob in zip(path_imgs, probs):
        rating_scores = dict(zip(rating_tags, prob))
        max_nsfw_score = max(rating_scores.get("questionable", 0), rating_scores.get("explicit", 0))
        max_sfw_score = rating_scores.get("general", 0)
        destination = nsfw_dir if max_nsfw_score > max_sfw_score else sfw_dir

        file_name = os.path.basename(image_path)
        tag_file_path = os.path.join(destination, os.path.splitext(file_name)[0] + ".txt")

        # Save tags in a single line. zip() stops at the shorter sequence, so
        # a model row shorter than the tag list cannot raise IndexError.
        tag_list = [
            tag
            for tag, score in zip(general_tags, prob[general_offset:])
            if score >= thresh
        ]
        with open(tag_file_path, "w", encoding="utf-8") as f:
            f.write(", ".join(tag_list))

        # Copy image to the appropriate folder
        try:
            shutil.copy(image_path, os.path.join(destination, file_name))
            print(f"{image_path} copied to {destination}.")
        except Exception as e:
            # Best effort: report the failure and keep processing the batch.
            print(f"Failed to copy {image_path} to {destination}. Error: {e}")
def main():
    """Download the tagger, then tag and sort every file in ``input_dir``.

    Reads the module-level settings ``MODEL_ID``, ``input_dir``, ``sfw_dir``,
    ``nsfw_dir``, ``batch_size`` and ``thresh``. Files whose name already
    exists in ``sfw_dir`` or ``nsfw_dir`` are skipped. Files that fail to
    load or preprocess are reported and skipped.

    Raises:
        ValueError: If the downloaded tag CSV does not have the expected
            ``tag_id, name, category, count`` header.
    """
    print("Loading wd14 tagger from Hugging Face")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("Running wd14 tagger ONNX")
    print(f"Loading ONNX model: {onnx_path}")
    ort_sess = ort.InferenceSession(onnx_path)
    # The input name never changes, so look it up once instead of per batch.
    input_name = ort_sess.get_inputs()[0].name

    with open(csv_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        header = next(reader)  # Read header row
        rows = list(reader)
    # Raise explicitly: an assert would be stripped when running with -O.
    if header != ["tag_id", "name", "category", "count"]:
        raise ValueError(f"Unexpected CSV format: {header}")

    # Category "9" rows are used as rating tags, category "0" rows as general tags.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    general_tags = [row[1] for row in rows if row[2] == "0"]

    image_paths = glob.glob(os.path.join(input_dir, "*.*"))
    output_dirs = [sfw_dir, nsfw_dir]

    b_imgs = []
    for image_path in tqdm(image_paths, smoothing=0.0):
        if check_if_exists(image_path, output_dirs):
            # Already sorted on a previous run.
            continue

        try:
            # The context manager releases the file handle once the pixels
            # have been copied into the preprocessed array.
            with Image.open(image_path) as opened:
                rgb = opened if opened.mode == "RGB" else opened.convert("RGB")
                b_imgs.append((image_path, preprocess_image(rgb)))
        except Exception as e:
            # The glob also matches non-image files; report and move on.
            print(f"Failed to load image: {image_path}, Error: {e}")
            continue

        if len(b_imgs) >= batch_size:
            run_batch(b_imgs, input_name, ort_sess, rating_tags, general_tags, thresh, nsfw_dir, sfw_dir)
            b_imgs = []

    # Flush the final, partially filled batch.
    if b_imgs:
        run_batch(b_imgs, input_name, ort_sess, rating_tags, general_tags, thresh, nsfw_dir, sfw_dir)

    print("Processing complete!")
if __name__ == "__main__":
    # Hugging Face repository that main() downloads model.onnx and
    # selected_tags.csv from.
    MODEL_ID = "SmilingWolf/wd-swinv2-tagger-v3"

    # Source folder and the two destination folders.
    input_dir = "E:/desktop/dart"
    sfw_dir = "E:/desktop/sfw"
    nsfw_dir = "E:/desktop/nsfw"

    # Images per inference call, and the minimum score for a tag to be saved.
    batch_size = 16
    thresh = 0.35

    # Create whichever destination folders are missing before processing.
    for output_dir in (sfw_dir, nsfw_dir):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment