Last active
June 14, 2024 02:36
-
-
Save tori29umai0123/f2c08d5c0a1dffb1b38cce8185651d2b to your computer and use it in GitHub Desktop.
ContentSafetyAnalyzer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import os | |
from pathlib import Path | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import onnxruntime as ort | |
from huggingface_hub import hf_hub_download | |
# Side length, in pixels, of the square array produced by preprocess_image.
IMAGE_SIZE = 448


def preprocess_image(image):
    """Pad an image to a white square and resize it to IMAGE_SIZE x IMAGE_SIZE.

    Args:
        image: Anything ``np.array`` converts to a 3-D array whose last axis
            holds the colour channels. The caller in this file passes an
            RGB ``PIL.Image``.

    Returns:
        The resized image as a ``float32`` array. The channel axis is
        reversed relative to the input and pixel values are not rescaled.
    """
    pixels = np.array(image)
    # Reverse the channel axis: an RGB input comes out as BGR.
    pixels = pixels[:, :, ::-1]

    height, width = pixels.shape[0:2]
    side = max(height, width)
    extra_rows = side - height
    extra_cols = side - width
    top = extra_rows // 2
    left = extra_cols // 2
    # Centre the picture on a white (255) square canvas; when the padding is
    # odd, the spare row/column goes to the bottom/right.
    pixels = np.pad(
        pixels,
        ((top, extra_rows - top), (left, extra_cols - left), (0, 0)),
        mode="constant",
        constant_values=255,
    )

    # INTER_AREA when shrinking, INTER_LANCZOS4 otherwise.
    if side > IMAGE_SIZE:
        interp = cv2.INTER_AREA
    else:
        interp = cv2.INTER_LANCZOS4
    resized = cv2.resize(pixels, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
    return resized.astype(np.float32)
def process_image(image_path, input_name, ort_sess, rating_tags, character_tags, general_tags, thresh):
    """Run the tagger on one image and print a rating verdict and likely characters.

    The score vector returned by the model is assumed to be laid out as
    rating scores first, then general-tag scores, then character-tag scores,
    each group in the order of the corresponding list.
    NOTE(review): this layout is implied by how the lists are used here;
    confirm it against the model's ``selected_tags.csv``.

    Args:
        image_path: Path of the image file to analyse.
        input_name: Name of the ONNX model input the image batch is fed to.
        ort_sess: Object whose ``run(None, {input_name: batch})`` returns the
            model outputs (an ONNX Runtime session in this file).
        rating_tags: Names of the rating tags.
        character_tags: Names of the character tags.
        general_tags: Names of the general tags; only the count is used, to
            locate where the character scores start.
        thresh: Minimum score for a character tag to be reported.

    Returns:
        None. Results are printed. If the image cannot be opened or
        preprocessed, an error message is printed and the function returns
        without running the model.
    """
    try:
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = preprocess_image(image)
    except Exception as e:
        # Best effort: report the unreadable image and skip it.
        print(f"画像を読み込めません: {image_path}, エラー: {e}")
        return

    batch = np.array([image])  # add the leading batch axis
    # First output of the model, first (only) item of the batch.
    prob = ort_sess.run(None, {input_name: batch})[0][0]

    # NSFW/SFW verdict: compare the stronger of "questionable"/"explicit"
    # against "general". Missing ratings count as 0.
    tag_confidences = dict(zip(rating_tags, prob))
    max_nsfw_score = max(
        tag_confidences.get("questionable", 0),
        tag_confidences.get("explicit", 0),
    )
    max_sfw_score = tag_confidences.get("general", 0)
    if max_nsfw_score > max_sfw_score:
        print("NSFWの可能性が高いです")
    else:
        print("SFWの可能性が高いです")

    # Character scores follow the rating and general scores. The offset is
    # derived from the tag lists instead of a hard-coded rating count.
    char_start = len(rating_tags) + len(general_tags)
    char_probs = prob[char_start:char_start + len(character_tags)]
    character_tags_with_probs = [
        (tag_name, f"{round(p * 100, 2)}%")  # score shown as a percentage
        for tag_name, p in zip(character_tags, char_probs)
        if p >= thresh
    ]

    if character_tags_with_probs:
        print(f"版権キャラクター: {character_tags_with_probs}の可能性があります")
    else:
        print("版権キャラクターの可能性が低いと思われます")
def main(MODEL_ID, image_path, thresh):
    """Download the tagger model and its tag list, then analyse one image.

    Args:
        MODEL_ID: Hugging Face repository id that provides ``model.onnx`` and
            ``selected_tags.csv``.
        image_path: Path of the image file to analyse.
        thresh: Minimum score for a character tag to be reported.

    Raises:
        ValueError: If the header of ``selected_tags.csv`` is not
            ``tag_id, name, category, count`` (including an empty file).
    """
    print("Hugging Faceからモデルをダウンロード中")
    onnx_path = hf_hub_download(MODEL_ID, "model.onnx")
    csv_path = hf_hub_download(MODEL_ID, "selected_tags.csv")

    print("ONNXモデルを実行中")
    print(f"ONNXモデルのパス: {onnx_path}")
    ort_sess = ort.InferenceSession(onnx_path)

    # newline="" lets the csv module handle line endings itself.
    with open(csv_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)  # None for an empty file
        rows = list(reader)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if header != ["tag_id", "name", "category", "count"]:
        raise ValueError(f"CSVフォーマットが期待と異なります: {header}")

    # Column 2 is the category code, column 1 the tag name. The header has
    # already been consumed, so every remaining row is a data row.
    rating_tags = [row[1] for row in rows if row[2] == "9"]
    character_tags = [row[1] for row in rows if row[2] == "4"]
    general_tags = [row[1] for row in rows if row[2] == "0"]

    process_image(
        image_path,
        ort_sess.get_inputs()[0].name,
        ort_sess,
        rating_tags,
        character_tags,
        general_tags,
        thresh,
    )
    print("処理完了!")
if __name__ == "__main__":
    # Example settings used when the file is run as a script.
    thresh = 0.35  # minimum score for a character tag to be reported
    image_path = "E:/desktop/test.jpg"  # image to analyse
    MODEL_ID = "SmilingWolf/wd-swinv2-tagger-v3"  # Hugging Face repository id
    main(MODEL_ID=MODEL_ID, image_path=image_path, thresh=thresh)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment