Last active
August 21, 2023 15:22
-
-
Save unai-ndz/ca380c7aa65c9f2aa3b55df2bb0faab1 to your computer and use it in GitHub Desktop.
Wrapper around civitai-minified.py to automate the hashing and downloading of info for your models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import io | |
import json | |
import requests | |
import hashlib | |
import argparse | |
from pathlib import Path | |
import subprocess | |
import re | |
# Compute the hash of all the models in a folder | |
# Use the hash to get the civitai id | |
# Spawn minified-civitai with the ids of your models to download their info | |
# Generate a small markdown file for each model from the merged.json downloaded by minified-civitai | |
# Download the preview images of the models | |
# Both markdown and images get downloaded in the folder where the model is located | |
# If you move or rename the models the hashes will be computed again as the cache is keyed by filepath
# Command-line interface.
# BUG FIX: argparse's `type=bool` is broken for flags — bool('False') is True,
# so any value passed made the option truthy.  Use store_true actions instead.
parser = argparse.ArgumentParser(
    description='Hash local models and download their civitai info')
parser.add_argument('-p', '--path', type=str, required=True,
                    help='The path where your models are stored')
parser.add_argument('-c', '--regenerate-cache', action='store_true',
                    help='Recreate the cache from scratch')
parser.add_argument('-d', '--force-redownload', action='store_true',
                    help='Force minified-civitai to update all the models instead of using its cache')
args = parser.parse_args()

# Root directory that scan_models() walks.
model_dir_path = Path(args.path)

CACHE_FILE = 'cache.json'  # The name of the JSON cache file
def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE):
    """Yield successive reads of *size* bytes from a file-like object until EOF."""
    while chunk := file.read(size):
        yield chunk
def gen_file_sha256(filname):
    """Return the hex SHA256 digest of the file at *filname*.

    Reads in 1 MiB blocks so arbitrarily large model files can be hashed
    without loading them fully into memory.  (The original also kept an
    unused running `length` total; removed.)
    """
    blocksize = 1 << 20  # 1 MiB
    h = hashlib.sha256()
    with open(filname, 'rb') as f:
        while True:
            block = f.read(blocksize)
            if not block:
                break
            h.update(block)
    return h.hexdigest()
# NOTE(review): a browser User-Agent is sent instead of the default
# python-requests one — presumably to avoid bot filtering; unverified.
def_headers = {'User-Agent': 'Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148'}
# Endpoint resolving a model file's SHA256 to its civitai model-version info.
hash_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
# Equivalent curl invocation:
# curl https://civitai.com/api/v1/model-versions/by-hash/$HASH \
#     -H "Content-Type: application/json" \
#     -X GET
# use this sha256 to get model info from civitai
# return: model info dict
def get_model_info_by_hash(hash: str):
    """Look up a model version on civitai by its file SHA256.

    Returns:
        dict  -- parsed model-version info on success
        False -- civitai answered 404 (not a civitai model)
        None  -- empty hash, network failure, non-404 HTTP error, or
                 unparseable JSON (callers treat None as "retry later")
    """
    print('Request model info from civitai')
    if not hash:
        print('hash is empty')
        return
    # BUG FIX: a connection error used to propagate and abort the whole
    # scan; treat it like any other transient failure instead.
    try:
        r = requests.get(hash_url + hash, headers=def_headers)
    except requests.RequestException as e:
        print('Request to civitai failed: ' + str(e))
        return
    if not r.ok:
        if r.status_code == 404:
            # this is not a civitai model
            print('Civitai does not have this model')
            return False
        else:
            print('Get error code: ' + str(r.status_code))
            print(r.text)
            return
    # try to get content
    content = None
    try:
        content = r.json()
    except Exception as e:
        print('Parse response json failed')
        print(str(e))
        print('response:')
        print(r.text)
        return
    if not content:
        print('error, content from civitai is None')
        return
    return content
# Maps file path -> {'hash': ..., 'id': ..., 'source': ...}; persisted to CACHE_FILE.
CACHE = {}  # A dictionary to hold the cached file data
if not args.regenerate_cache:
    # Load the cache file
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'r') as f:
            CACHE = json.load(f)
# Model file extensions considered during the scan; everything else is ignored.
exts = ('.bin', '.pt', '.safetensors', '.ckpt')
# Files whose stem ends in '.vae' (e.g. "model.vae.pt") are skipped.
vae_suffix = '.vae'
# scan model to generate SHA256, then use this SHA256 to get model info from civitai | |
# scan model to generate SHA256, then use this SHA256 to get model info from civitai
def scan_models():
    """Walk model_dir_path, hash every model file and resolve civitai ids.

    Hashes are stored in CACHE keyed by file path and the cache is written
    back to CACHE_FILE at the end.  After several consecutive civitai
    failures the scan keeps hashing but stops querying the API.

    Returns: list of civitai model ids as strings.
    """
    print('Scan models')
    model_ids = []
    model_count = 0
    new_model_count = 0
    skipped = 0
    for root, dirs, files in os.walk(model_dir_path, followlinks=True):
        for filename in files:
            # check ext
            item = os.path.join(root, filename)
            base, ext = os.path.splitext(item)
            if ext not in exts:
                continue
            # ignore vae file
            if len(base) > 4 and base[-4:] == vae_suffix:
                # print('This is a vae file: ' + filename)
                continue
            model_count = model_count + 1
            print(filename)
            file_path = os.path.join(root, filename)
            # If the file is not already in the cache, compute its hash and add it
            if file_path not in CACHE:
                file_hash = gen_file_sha256(file_path)
                if not file_hash:
                    print('Failed generating SHA256 for model:' + filename)
                    continue
                CACHE[file_path] = {'hash': file_hash}
                new_model_count = new_model_count + 1
            # Get it from cache
            file_hash = CACHE[file_path]['hash']
            if 'id' in CACHE[file_path]:
                model_ids.append(str(CACHE[file_path]['id']))
                print(CACHE[file_path]['id'])
                continue
            # Models already known not to come from civitai are not re-queried.
            if 'source' in CACHE[file_path] and CACHE[file_path]['source'] != 'civitai':
                continue
            if skipped >= 4:
                # BUG FIX: the original fell through past this point and
                # re-used `model_info` from a previous iteration (or hit a
                # NameError on the first failure past the threshold).
                print('Civitai failed too many times, keep caching hashes stop requesting civitai info')
                skipped = skipped + 1
                continue
            model_info = get_model_info_by_hash(file_hash)
            if model_info is None:
                # None means a transient/API failure, not a definitive miss.
                print(f'(unknown): Connect to Civitai API service failed. Wait a while and try again')
                skipped = skipped + 1
                continue
            elif model_info:
                CACHE[file_path]['source'] = 'civitai'
                CACHE[file_path]['id'] = model_info['modelId']
                model_ids.append(str(CACHE[file_path]['id']))
            else:
                # False: civitai returned 404 — remember so we don't retry.
                CACHE[file_path]['source'] = 'NA'
    # Update the cache file
    with open(CACHE_FILE, 'w') as f:
        json.dump(CACHE, f)
    print(f'Scanned {model_count} total models, {new_model_count} new models')
    if skipped > 0:
        print(f'Skipped models because of too many civitai errors: {skipped}')
    return model_ids
def sanitize_filename(filename):
    """Normalize a model name into a safe, compact CamelCase filename.

    Apostrophes are dropped ("Bob's" -> "Bobs"), any other non-alphanumeric
    run becomes a word break, words are Title-cased and joined without
    separators, and the result is capped at 60 characters.  Falls back to
    'default' when nothing usable remains.
    """
    # Remove single quotes so they don't become word breaks
    filename = filename.replace("'", "")
    # Replace non-alphanumerics with spaces, trim, Title-case each word
    filename = re.sub(r'[^a-zA-Z0-9]', ' ', filename).strip().title()
    # Join the words back together with no separators
    filename = re.sub(r' +', '', filename)
    # (The original also checked fullmatch(' +') here — dead code, since all
    # spaces were just removed.)
    if not filename:
        filename = 'default'
    # Truncate to the maximum length
    return filename[:60]
def get_folder(type):
    """Map a civitai model type (case-insensitive) to its destination folder."""
    folders = {
        'textualinversion': '00.final/embeddings',
        'hypernetwork': '00.final/hypernetwork',
        'checkpoint': '00.final/models',
        'lora': '00.final/lora',
        'locon': '00.final/locon',
    }
    return folders.get(type.lower(), '00.final/unknownCivitai')
def get_path_for_model(id: int):
    """Return the cached file path whose civitai model id equals *id*, or None.

    BUG FIX: CACHE values are dicts ({'hash': ..., 'id': ...}); the original
    compared the whole dict to the id, so it always returned None.
    """
    for path, info in CACHE.items():
        if info.get('id') == id:
            return path
    return None
def generate_markdown():
    """Write a markdown summary and preview images next to each cached model.

    Reads merged.json (produced by minified-civitai), matches each entry to a
    local file via the id stored in CACHE, downloads missing preview images
    and writes '<model>.md' into the model's folder.
    """
    # Load the JSON file
    with open('merged.json', 'r') as f:
        models = json.load(f)

    # Invert the cache: civitai model id -> local file path
    model_path_by_id = {}
    for path, model in CACHE.items():
        if 'id' in model:
            model_path_by_id[model['id']] = path

    # Loop through the models and write a markdown file for each one
    for _, model_json in models.items():
        # Get the name and description values
        data = model_json['pageData']['props']['pageProps']['trpcState']['json']['queries'][0]['state']['data']
        id = str(data['id'])
        name = sanitize_filename(data['name'])
        description = data['description']
        type = data['type']
        url = 'https://civitai.com/models/' + id
        if int(id) not in model_path_by_id:
            # merged.json can contain models that are no longer present locally
            print(f'No local file for civitai model {id}, skipping')
            continue
        model_path = Path(model_path_by_id[int(id)])
        # Get the directory where the model is
        dir = model_path.parent
        basename = model_path.stem

        # For the purpose of SD each version of a model should get its own markdown file and images
        # Each one can have different trigger words, images, etc. (The only thing shared is the name, url and description?)
        versions = []
        for i, v in enumerate(data['modelVersions']):
            versions.append({
                'version': v['name'] or '',
                'description': v['description'] or '',
                'trainedWords': v['trainedWords'] or '',
                'baseModel': v['baseModel'] or '',
                'epochs': v['epochs'] or '',
                'images': v['images'],
            })
        first_version = versions[0]  # The one with the more recent release

        for i, img_data in enumerate(first_version['images']):
            if i == 0:
                i = 'preview'
            img_file = os.path.join(f'{dir}', f'{basename}.{str(i)}.png')
            if not os.path.exists(img_file):
                with open(img_file, 'wb') as f:
                    # BUG FIX: use a separate variable here — the original
                    # reassigned `url`, so the markdown footer link below
                    # pointed at the last downloaded image instead of the
                    # model's page.
                    img_url = f'https://imagecache.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{img_data["url"]}/width=400'
                    f.write(requests.get(img_url).content)

        # Write the markdown file
        with open(f'{dir}/{basename}.md', 'w') as f:
            # Write title with extra markdown to convert into a link
            f.write(f'# [{name}][1]\n\n')
            if first_version['trainedWords']:
                print(first_version['trainedWords'])
                trigger_words = ', '.join(map(str, first_version["trainedWords"]))
                f.write(f'Trigger Words: {trigger_words}\n\n')
            # Write description
            f.write(description + '\n\n')
            f.write('Type: ' + type + '\n\n')
            # Add links
            f.write(f'[1]: <{url}/> "Go to the model\'s page"\n\n')
# Scan directory for models and download info
model_ids = scan_models()

# Build the minified-civitai invocation; without --force-redownload the
# '-o' flag lets get.py reuse its own cache.
cmd = ['python', 'get.py']
if not args.force_redownload:
    cmd.append('-o')
cmd.extend(['-p', '--no-base64-images', '--json'])
subprocess.run(cmd + model_ids)

generate_markdown()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment