Skip to content

Instantly share code, notes, and snippets.

@SteelPh0enix
Created September 24, 2024 20:13
Show Gist options
  • Save SteelPh0enix/b1165093a6d08da39055f358d74914b8 to your computer and use it in GitHub Desktop.
Save SteelPh0enix/b1165093a6d08da39055f358d74914b8 to your computer and use it in GitHub Desktop.
Clone a Hugging Face repository and pull its Git LFS files via HTTP
import argparse
import fnmatch
import functools
import os
import shutil
import subprocess
import urllib.parse
from pathlib import Path
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Clone a HuggingFace repository and download Git LFS files via HTTP, because fuck Git LFS."
)
parser.add_argument(
"repo_url",
type=str,
help="URL to the HuggingFace repository",
)
parser.add_argument(
"--destination_dir",
type=str,
default=None,
help="Destination directory for cloned repo, cwd/name-of-repository if not provided",
)
parser.add_argument(
"--exclude",
type=str,
default=None,
help="Glob pattern for files to be excluded from downloading via HTTP",
)
parser.add_argument(
"--download-rate-limit",
type=str,
default=None,
help="Download rate limit for cURL, for example 100K or 10M",
)
return parser.parse_args()
def get_hf_repo_path(url: str) -> str:
    """Return the repository path of a HuggingFace URL (the part after the host).

    The ``http://huggingface.co/`` and ``https://huggingface.co/`` prefixes
    are removed when present; any other input is passed through unchanged
    apart from the trailing-slash cleanup.

    Args:
        url: Repository URL, e.g. ``https://huggingface.co/org/repo``.

    Returns:
        The path without the host prefix and without trailing slashes,
        e.g. ``org/repo``.
    """
    repo_path = url.removeprefix("http://huggingface.co/").removeprefix(
        "https://huggingface.co/"
    )
    # A trailing slash would otherwise end up inside URLs that are built by
    # appending to this path (".../repo//resolve/...").
    return repo_path.rstrip("/")
def convert_http_to_git_hf_url(url: str) -> str:
    """Turn a HuggingFace HTTP(S) repository URL into its ``git@hf.co:`` form.

    Args:
        url: Repository URL as accepted by ``get_hf_repo_path``.

    Returns:
        The string ``git@hf.co:`` followed by the repository path.
    """
    repo_path = get_hf_repo_path(url)
    return "git@hf.co:" + repo_path
def require_git(func):
    """Decorate ``func`` so that it receives the path of the ``git`` executable.

    The returned wrapper looks ``git`` up on ``PATH`` on every call and passes
    the result to ``func`` as the ``git_exec`` keyword argument.

    Args:
        func: Callable accepting a ``git_exec`` keyword argument.

    Returns:
        The wrapping callable; it raises ``RuntimeError`` when ``git`` cannot
        be found.
    """

    # wraps() keeps the decorated function's name and docstring intact.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        git_exec = shutil.which("git")
        if not git_exec:
            raise RuntimeError(
                "Git not found on the system! Install and configure Git first!"
            )
        # setdefault honors a git_exec passed explicitly by keyword instead of
        # failing with "got multiple values for keyword argument".
        kwargs.setdefault("git_exec", git_exec)
        return func(*args, **kwargs)

    return wrapper
def require_git_no_lfs(func):
    """Decorate ``func`` to run with ``git`` located and LFS smudging disabled.

    On every call the wrapper looks ``git`` up on ``PATH``, sets the
    ``GIT_LFS_SKIP_SMUDGE`` environment variable of this process to ``"1"``
    and passes the executable path to ``func`` as ``git_exec``. The variable
    is not restored afterwards.

    Args:
        func: Callable accepting a ``git_exec`` keyword argument.

    Returns:
        The wrapping callable; it raises ``RuntimeError`` (without touching
        the environment) when ``git`` cannot be found.
    """

    # wraps() keeps the decorated function's name and docstring intact.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        git_exec = shutil.which("git")
        if not git_exec:
            raise RuntimeError(
                "Git not found on the system! Install and configure Git first!"
            )
        # Unconditional assignment: the result is the same as first checking
        # whether the variable is missing or different from "1".
        os.environ["GIT_LFS_SKIP_SMUDGE"] = "1"
        # setdefault honors a git_exec passed explicitly by keyword instead of
        # failing with "got multiple values for keyword argument".
        kwargs.setdefault("git_exec", git_exec)
        return func(*args, **kwargs)

    return wrapper
def require_curl(func):
    """Decorate ``func`` so that it receives the path of the ``curl`` executable.

    The returned wrapper looks ``curl`` up on ``PATH`` on every call and
    passes the result to ``func`` as the ``curl_exec`` keyword argument.

    Args:
        func: Callable accepting a ``curl_exec`` keyword argument.

    Returns:
        The wrapping callable; it raises ``RuntimeError`` when ``curl`` cannot
        be found.
    """

    # wraps() keeps the decorated function's name and docstring intact.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        curl_exec = shutil.which("curl")
        if not curl_exec:
            raise RuntimeError(
                "cURL not found on the system! Install cURL executable first!"
            )
        # setdefault honors a curl_exec passed explicitly by keyword instead
        # of failing with "got multiple values for keyword argument".
        kwargs.setdefault("curl_exec", curl_exec)
        return func(*args, **kwargs)

    return wrapper
@require_git_no_lfs
def clone_git_repo_without_lfs(url: str, destination_path: Path, git_exec: str = "git"):
    """Clone the repository behind ``url`` into ``destination_path``.

    The HTTP(S) URL is converted with ``convert_http_to_git_hf_url`` before
    being handed to ``git clone``.

    Args:
        url: HuggingFace repository URL.
        destination_path: Directory the repository is cloned into.
        git_exec: Git executable to run (supplied by the decorator).

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
    """
    print(f"Pulling {url} to {destination_path} using {git_exec}")
    remote = convert_http_to_git_hf_url(url)
    clone_command = [git_exec, "clone", remote, str(destination_path)]
    subprocess.run(clone_command, check=True)
def parse_git_lfs_files_result(result: str) -> dict[str, bool]:
    """Parse ``git lfs ls-files`` output into a per-file status mapping.

    Every non-blank line must consist of three whitespace-separated fields:
    object id, status marker and file name. Only the first two separators
    are split on, so file names containing spaces stay intact.

    Args:
        result: Raw text output of ``git lfs ls-files``.

    Returns:
        Mapping of file name to ``True`` when the line's status marker is
        ``"*"`` and ``False`` for any other marker.

    Raises:
        ValueError: If a non-blank line has fewer than three fields.
    """
    status: dict[str, bool] = {}
    for line in result.splitlines():
        # Blank lines (e.g. trailing newlines) carry no file information.
        if not line.strip():
            continue
        fields = line.split(maxsplit=2)
        if len(fields) != 3:
            raise ValueError(f"Unexpected `git lfs ls-files` line: {line!r}")
        _object_id, marker, filename = fields
        status[filename] = marker == "*"
    return status
@require_git
def get_git_lfs_files_status(repo_path: Path, git_exec: str = "git") -> dict[str, bool]:
    """List the Git LFS files of a checkout together with their status.

    Runs ``git lfs ls-files`` inside ``repo_path`` and feeds its standard
    output to ``parse_git_lfs_files_result``.

    Args:
        repo_path: Working directory for the git command.
        git_exec: Git executable to run (supplied by the decorator).

    Returns:
        The mapping produced by ``parse_git_lfs_files_result``.

    Raises:
        subprocess.CalledProcessError: If the git command exits with a
            non-zero status.
    """
    print(f"Checking Git LFS files in {repo_path} using {git_exec}")
    listing_command = [git_exec, "lfs", "ls-files"]
    completed = subprocess.run(
        listing_command,
        capture_output=True,
        text=True,
        check=True,
        cwd=repo_path,
    )
    listing = completed.stdout
    return parse_git_lfs_files_result(listing)
def get_hf_file_base_url(repo_url: str, filename: str, revision: str = "main") -> str:
    """Build the HuggingFace ``resolve`` URL for one file of a repository.

    Args:
        repo_url: Repository URL as accepted by ``get_hf_repo_path``.
        filename: Path of the file inside the repository.
        revision: Revision segment of the URL; defaults to ``main``, the
            value that used to be hard-coded.

    Returns:
        ``https://huggingface.co/<repo>/resolve/<revision>/<filename>?download=true``
        with the revision and file name percent-encoded.
    """
    repo_path = get_hf_repo_path(repo_url)
    # Percent-encode so that spaces, '#', '?' and similar characters in a file
    # name cannot corrupt the URL; '/' stays literal to keep sub-directories.
    encoded_name = urllib.parse.quote(filename)
    # The revision is a single path segment, so '/' is encoded here as well.
    encoded_revision = urllib.parse.quote(revision, safe="")
    return (
        f"https://huggingface.co/{repo_path}/resolve/"
        f"{encoded_revision}/{encoded_name}?download=true"
    )
@require_curl
def download_file(
    url: str, target_dir: Path, curl_exec: str = "curl", rate_limit: str | None = None
):
    """Download ``url`` with cURL, running the command inside ``target_dir``.

    Args:
        url: Address of the file to fetch.
        target_dir: Working directory for the cURL process.
        curl_exec: cURL executable to run (supplied by the decorator).
        rate_limit: Value for cURL's ``--limit-rate`` option; the option is
            omitted when this is ``None``.

    Raises:
        subprocess.CalledProcessError: If cURL exits with a non-zero status.
    """
    # NOTE(review): --remote-name derives the local file name from the URL;
    # confirm that this matches the repository file name for redirected URLs.
    command = [
        curl_exec,
        "--compressed",
        "--remote-name",
        url,
        "--remote-time",
        "--styled-output",
        "--tcp-fastopen",
    ]
    if rate_limit is not None:
        command.extend(("--limit-rate", rate_limit))
    subprocess.run(command, cwd=target_dir, check=True)
@require_curl
def get_hf_file_true_url(base_url: str, curl_exec: str = "curl") -> str | None:
    """Ask the server where ``base_url`` redirects to.

    cURL is run against ``base_url`` and the response body is inspected: when
    it starts with ``"Found. Redirecting to "`` the remainder of the body is
    taken as the target URL.

    Args:
        base_url: URL to query.
        curl_exec: cURL executable to run (supplied by the decorator).

    Returns:
        The text following the redirect prefix, or ``None`` when the body
        does not start with that prefix.

    Raises:
        subprocess.CalledProcessError: If cURL exits with a non-zero status.
    """
    completed = subprocess.run(
        [curl_exec, "--get", base_url],
        capture_output=True,
        text=True,
        check=True,
    )
    body = completed.stdout
    redirect_marker = "Found. Redirecting to "
    if not body.startswith(redirect_marker):
        return None
    return body[len(redirect_marker):]
def main():
    """Clone the requested repository and fetch its missing LFS files via HTTP.

    Steps:
      1. Parse the command line.
      2. Clone the repository unless the destination directory already exists.
      3. List the LFS files of the checkout and, for each one not reported as
         valid, resolve its download URL and download it with cURL. Files
         matching the ``--exclude`` glob pattern are skipped.
    """
    args = parse_arguments()
    # rstrip so that a URL ending in "/" still yields the repository name
    # instead of an empty string.
    repo_name = args.repo_url.rstrip("/").split("/")[-1]
    destination_dir = Path(
        args.destination_dir if args.destination_dir is not None else f"./{repo_name}"
    ).absolute()
    if not destination_dir.exists():
        clone_git_repo_without_lfs(args.repo_url, destination_dir)
    repo_files_status = get_git_lfs_files_status(destination_dir)
    for filename, is_valid in repo_files_status.items():
        if is_valid:
            print(f"{filename} is valid")
            continue
        # Apply the --exclude option, which was previously parsed but ignored.
        if args.exclude is not None and fnmatch.fnmatch(filename, args.exclude):
            print(f"{filename} is NOT valid, but matches --exclude and will be skipped")
            continue
        print(f"{filename} is NOT valid and will be downloaded")
        base_url = get_hf_file_base_url(args.repo_url, filename)
        if file_url := get_hf_file_true_url(base_url):
            download_file(
                file_url,
                destination_dir,
                rate_limit=args.download_rate_limit,
            )
        else:
            print(
                "Couldn't fetch URL for this file, you'll have to download it manually!"
            )
# Run the command-line workflow only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment