Last active
May 28, 2024 23:20
-
-
Save tori29umai0123/b710efabf3781f137359fa1616da85f4 to your computer and use it in GitHub Desktop.
HF_upload_sdxl_gen_img.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import getpass
import os
import re
import shutil
import sys
import threading
import zipfile
from time import sleep

import wget
from huggingface_hub import HfApi, Repository, upload_file

import gen_img
# --- Run configuration ----------------------------------------------------
# Folder gen_img writes images into; the zip/upload worker drains it.
outdir: str = "output_img"
# Base SDXL checkpoint and LoRA weights (downloaded by the __main__ section if missing).
ckpt: str = "animagine-xl-3.1.safetensors"
LoRA: str = "lcm-animaginexl-3_1.safetensors"
# Full prompt list; a suffix of it (by line) is handed to gen_img on resume.
prompt_file: str = "dart_prompts.txt"

# --- Worker coordination --------------------------------------------------
# Set by main() once image generation has finished.
stop_event = threading.Event()
# Cleared before each upload and set once an upload call has returned.
upload_event = threading.Event()

# Images per zip archive, and the zip worker's polling interval in seconds.
batch_size: int = 5000
sleep_time: int = 5
def download_file(url, filename):
    """Fetch ``url`` into ``filename`` unless that file is already on disk.

    This is best-effort: any failure is reported on stdout and otherwise
    ignored, so the caller keeps running even if the file could not be fetched.

    Args:
        url: Source URL handed to ``wget.download``.
        filename: Local destination path; if it already exists nothing is downloaded.
    """
    try:
        if os.path.exists(filename):
            print(f"{filename} は既に存在します。")
            return
        print(f"{filename} をダウンロード中...")
        wget.download(url, filename)
    except Exception as err:
        print(f"ファイルのダウンロード中にエラーが発生しました: {err}")
def check_repository_access(repo_name, token):
    """Report whether ``repo_name`` can be queried as a Hugging Face dataset repo.

    ``HfApi.repo_info`` is used purely as an access probe; its result is
    discarded. Any exception raised while probing is printed and turned into
    ``False``.

    Returns:
        True if the probe succeeded, False otherwise.
    """
    probe = HfApi()
    try:
        probe.repo_info(repo_name, token=token, repo_type='dataset')
        print("リポジトリアクセスを確認しました。")
    except Exception as exc:
        print(f"リポジトリへのアクセスに失敗しました: {exc}")
        return False
    return True
def upload_to_hf(zip_file, repo_name, token):
    """Upload one zip archive to the Hugging Face dataset repository ``repo_name``.

    The repository is cloned into ``<cwd>/<repo basename>`` if that folder
    does not exist yet. The archive is copied into that folder and uploaded
    from the copy. The original ``zip_file`` is then deleted, while the copy
    stays in the local clone, where ``main`` later scans ``*.zip`` names.
    ``upload_event`` is set after the upload call returns.

    Args:
        zip_file: Path of the archive to upload.
        repo_name: Dataset repository id (``"user/name"``).
        token: Hugging Face API token.
    """
    clone_dir = os.path.join(os.getcwd(), repo_name.split('/')[-1])
    if not os.path.exists(clone_dir):
        Repository(clone_dir, clone_from=f"https://huggingface.co/datasets/{repo_name}", use_auth_token=token)

    archive_name = os.path.basename(zip_file)
    shutil.copy(zip_file, clone_dir)
    staged_copy = os.path.join(clone_dir, archive_name)
    upload_file(
        path_or_fileobj=staged_copy,
        path_in_repo=archive_name,
        repo_id=repo_name,
        token=token,
        repo_type='dataset',
    )

    os.remove(zip_file)
    print(f"{zip_file}をHugging Faceのデータセットリポジトリ{repo_name}にアップロードしました。")
    upload_event.set()
def zip_images(directory, batch_size, repo_name, token, zip_counter_start):
    """Poll ``directory`` and ship generated images to Hugging Face in zip batches.

    Runs until ``stop_event`` is set *and* no image files remain in
    ``directory``. While generation is still running, only full batches of
    ``batch_size`` images are archived. Once ``stop_event`` is set, whatever
    is left is flushed as a final, possibly smaller, batch.

    Args:
        directory: Folder the image generator writes into.
        batch_size: Number of images per zip while generation is running.
        repo_name: Target dataset repository (``"user/name"``).
        token: Hugging Face API token.
        zip_counter_start: Number of images already uploaded. Zip names are
            derived from this running total.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    zip_counter = zip_counter_start
    while True:
        # Read the flag *before* listing: once it is set, generation has
        # finished, so the listing below is guaranteed to be complete.
        finishing = stop_event.is_set()
        files = [f for f in os.listdir(directory) if f.endswith(image_exts)]
        if finishing and not files:
            break
        if len(files) >= batch_size:
            # NOTE(review): while generation runs, the newest file in a full
            # batch could still be mid-write by gen_img — confirm its save pattern.
            zip_and_upload(files, directory, batch_size, repo_name, token, zip_counter)
            zip_counter += batch_size
        elif finishing:
            # files is non-empty here; the empty case broke out above.
            zip_and_upload(files, directory, len(files), repo_name, token, zip_counter)
            zip_counter += len(files)
        if not finishing:
            # No new images can appear after stop_event is set, so only wait
            # between polls while generation is still running.
            sleep(sleep_time)
def zip_and_upload(files, directory, count, repo_name, token, zip_counter):
    """Archive the first ``count`` images (in name order) and upload the zip.

    The zip is written into ``directory`` and named after the running image
    total ``zip_counter + count``, zero-padded to at least 5 digits. The
    archived images are deleted before the upload starts. The zip is then
    handed to ``upload_to_hf``.

    Args:
        files: Image file names relative to ``directory``. Left unmodified.
        directory: Folder holding the images; the zip is created here too.
        count: How many of the name-sorted files to archive.
        repo_name: Target dataset repository (``"user/name"``).
        token: Hugging Face API token.
        zip_counter: Number of images uploaded before this batch.
    """
    # sorted() rather than files.sort(): do not reorder the caller's list.
    batch = sorted(files)[:count]
    zip_filename = os.path.join(directory, f'{zip_counter + count:05d}.zip')
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for name in batch:
            zipf.write(os.path.join(directory, name), arcname=name)
    print(f'Created {zip_filename}')
    # Only delete the images after the archive has been fully written.
    for name in batch:
        os.remove(os.path.join(directory, name))
    upload_event.clear()
    upload_to_hf(zip_filename, repo_name, token)
def create_partial_prompt_file(original_file, new_file, start_line):
    """Write the lines of ``original_file`` from ``start_line`` (1-based) onward to ``new_file``.

    ``new_file`` is always rewritten, so a file left over from a previous run
    is never reused. When ``start_line`` is past the end of ``original_file``,
    a notice is printed and ``new_file`` is truncated to empty.

    Args:
        original_file: UTF-8 text file to read.
        new_file: Destination path; overwritten.
        start_line: 1-based line number to start from. Values below 1 are
            treated as 1.

    Returns:
        The number of lines written to ``new_file`` (0 when nothing remains).
    """
    with open(original_file, 'r', encoding='utf-8') as fin:
        lines = fin.readlines()
    line_count = len(lines)
    if start_line > line_count:
        print(f"指定された開始行 {start_line} は、ファイルの行数 {line_count} を超えています。")
    # Clamp so that e.g. 0 does not turn into lines[-1:] (the last line only).
    first_index = max(start_line, 1) - 1
    remaining = lines[first_index:]
    with open(new_file, 'w', encoding='utf-8') as fout:
        fout.writelines(remaining)
    return len(remaining)
def _resume_counter_from_zips(repo_local_path):
    """Return the largest number embedded in a ``*.zip`` name in ``repo_local_path`` (0 if none)."""
    pattern = re.compile(r'(\d+)\.zip')
    numbers = []
    for name in os.listdir(repo_local_path):
        if not name.endswith('.zip'):
            continue
        match = pattern.search(name)
        if match:
            numbers.append(int(match.group(1)))
    # default=0 covers zip files whose names carry no number.
    return max(numbers, default=0)


def _gen_img_argv(prompt_path):
    """Build the gen_img command-line arguments (without the program name) for ``prompt_path``."""
    return [
        '--ckpt', ckpt, '--n_iter', '1', '--scale', '3',
        '--steps', '12', '--outdir', outdir, '--xformers', '--bf16',
        '--sampler', 'euler_a', '--batch_size', '4', '--vae_batch_size', '2',
        '--from_file', prompt_path,
        '--max_embeddings_multiples', '3', '--seed', '42',
        '--network_module', 'networks.lora', '--network_weights', LoRA,
        '--network_mul', '0.4', '--network_merge',
    ]


def main(repo_name, token):
    """Resume image generation after the last uploaded batch and upload the results.

    1. Clone the dataset repo and take the largest number found in its
       ``*.zip`` names as the count of images already uploaded.
    2. Write the prompts not yet processed to ``partial_prompts.txt``.
    3. Start the zip/upload worker thread and run ``gen_img`` on those
       prompts (skipped when none remain). Then always signal the worker
       and wait for it to flush the remaining images.

    Args:
        repo_name: Dataset repository id (``"user/name"``).
        token: Hugging Face API token.
    """
    repo_local_path = os.path.join(os.getcwd(), repo_name.split('/')[-1])
    Repository(repo_local_path, clone_from=f"https://huggingface.co/datasets/{repo_name}", use_auth_token=token)
    zip_counter_start = _resume_counter_from_zips(repo_local_path)
    # Zip names count images, and this maps them onto prompt lines.
    # NOTE(review): assumes each prompt line yields exactly one image (--n_iter 1) — confirm.
    start_line = zip_counter_start + 1

    new_prompt_file = "partial_prompts.txt"
    # Remove any file from an earlier run so it can never be mistaken for this run's prompts.
    if os.path.exists(new_prompt_file):
        os.remove(new_prompt_file)
    create_partial_prompt_file(prompt_file, new_prompt_file, start_line)
    has_prompts = os.path.exists(new_prompt_file) and os.path.getsize(new_prompt_file) > 0
    print(f"プロンプトファイルの処理は行 {start_line} から開始します。")

    zip_thread = threading.Thread(target=zip_images, args=(outdir, batch_size, repo_name, token, zip_counter_start))
    zip_thread.start()
    try:
        if has_prompts:
            parser = gen_img.setup_parser()
            sys.argv = ['script_name'] + _gen_img_argv(new_prompt_file)
            args = parser.parse_args()
            gen_img.main(args)
        else:
            print("処理するプロンプトが残っていないため、画像生成をスキップします。")
    finally:
        # Always release the worker: without this, an exception (or Ctrl-C)
        # in gen_img would leave the non-daemon zip thread polling forever.
        stop_event.set()
        zip_thread.join()
    print("すべての処理が完了しました。")
if __name__ == "__main__":
    # Fetch the model files (skipped if already present) and make sure the
    # output folder exists before asking for credentials.
    ckpt_url = "https://huggingface.co/cagliostrolab/animagine-xl-3.1/resolve/main/animagine-xl-3.1.safetensors"
    lora_url = "https://huggingface.co/furusu/SD-LoRA/resolve/main/lcm-animaginexl-3_1.safetensors"
    download_file(ckpt_url, ckpt)
    download_file(lora_url, LoRA)
    os.makedirs(outdir, exist_ok=True)

    # strip() guards against stray whitespace from copy-paste.
    repo_name = input("Hugging Faceのリポジトリ名を入力してください: ").strip()
    # getpass keeps the API token from being echoed to the terminal.
    token = getpass.getpass("Hugging FaceのAPIトークンを入力してください: ").strip()
    if not check_repository_access(repo_name, token):
        print("アクセスが拒否されたか、無効なリポジトリです。実行を停止します。")
        # sys.exit rather than the site-provided exit(), which is not always available.
        sys.exit(1)
    main(repo_name, token)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment