Skip to content

Instantly share code, notes, and snippets.

@BlackHC
Created August 27, 2024 17:23
Show Gist options
  • Save BlackHC/81b30e2dd7bc1a4decf3f1fd2858f5b4 to your computer and use it in GitHub Desktop.
Save BlackHC/81b30e2dd7bc1a4decf3f1fd2858f5b4 to your computer and use it in GitHub Desktop.
Verify OAI fine-tuning JSONL
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "tiktoken",
# "typer",
# "numpy",
# ]
# ///
"""
Verify and analyze a JSONL dataset for fine-tuning with OpenAI models.
Extends https://cookbook.openai.com/examples/chat_finetuning_data_prep
Run with e.g. `uv run --no-project verify_oai_fine_tuning_jsonl.py --help`
"""
import json
from collections import defaultdict
import numpy as np
import tiktoken
import typer
def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1):
    """Estimate the token count of a chat-completion message list.

    Mirrors the OpenAI cookbook heuristic for cl100k_base models: each
    message carries a fixed overhead, every field value is encoded, and a
    "name" field adds a small surcharge.

    Args:
        messages: list of dicts with string-convertible field values.
        tokens_per_message: fixed per-message overhead.
        tokens_per_name: extra tokens charged when a "name" key is present.

    Returns:
        Estimated total token count, including the 3-token reply primer.
    """
    encoding = tiktoken.get_encoding("cl100k_base")
    total = sum(
        tokens_per_message
        + sum(
            len(encoding.encode(str(value)))
            + (tokens_per_name if key == "name" else 0)
            for key, value in message.items()
        )
        for message in messages
    )
    # Every reply is primed with <|start|>assistant<|message|>.
    return total + 3
def verify_and_analyze_dataset(data_path: str, model: str = "gpt-4o-mini-2024-07-18"):
    """Verify an OpenAI fine-tuning JSONL file and print dataset statistics.

    Runs the chat-format error checks from the OpenAI fine-tuning data-prep
    cookbook, then reports message/token distributions, the number of
    over-limit examples, and a rough training-cost estimate.

    Args:
        data_path: path to the JSONL file (one chat example per line).
        model: target fine-tuning model; selects the per-example token limit.
    """
    # Load the dataset: one JSON object per line.
    with open(data_path, "r", encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f]
    print(f"Num examples: {len(dataset)}")

    # Format error checks.
    format_errors = defaultdict(int)
    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue
        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue
        for message in messages:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1
            if message.get("role", None) not in (
                "system",
                "user",
                "assistant",
                "function",
            ):
                format_errors["unrecognized_role"] += 1
            # An assistant message may carry a function_call instead of content.
            if not message.get("content") and not message.get("function_call"):
                format_errors["missing_content"] += 1
        if not any(message.get("role", None) == "assistant" for message in messages):
            format_errors["example_missing_assistant_message"] += 1
    if format_errors:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No format errors found")

    # Token counting and statistics. Skip malformed examples here so an
    # error reported above doesn't crash the stats pass with a KeyError.
    n_messages = []
    convo_lens = []
    for ex in dataset:
        if not isinstance(ex, dict):
            continue
        messages = ex.get("messages")
        if not messages:
            continue
        n_messages.append(len(messages))
        convo_lens.append(num_tokens_from_messages(messages))
    if not convo_lens:
        # Nothing well-formed to analyze; min()/max() below would raise.
        print("\nNo well-formed examples to analyze")
        return
    print("\nDistribution of num_messages_per_example:")
    print(f"min / max: {min(n_messages)}, {max(n_messages)}")
    print(f"mean / median: {np.mean(n_messages):.1f}, {np.median(n_messages):.1f}")
    print("\nDistribution of num_total_tokens_per_example:")
    print(f"min / max: {min(convo_lens)}, {max(convo_lens)}")
    print(f"mean / median: {np.mean(convo_lens):.1f}, {np.median(convo_lens):.1f}")

    # Per-example token limit depends on the target model family.
    if model.startswith("gpt-4o"):
        token_limit = 65536
    else:
        token_limit = 16385
    n_too_long = sum(l > token_limit for l in convo_lens)
    print(
        f"\n{n_too_long} examples may be over the {token_limit} token limit for {model}, they will be truncated during fine-tuning"
    )

    # Cost estimation: tokens beyond the limit are truncated, hence min().
    n_epochs = 3
    n_billing_tokens_in_dataset = sum(min(token_limit, length) for length in convo_lens)
    print(
        f"\nDataset has ~{n_billing_tokens_in_dataset} tokens that will be charged for during training"
    )
    print(f"By default, you'll train for {n_epochs} epochs on this dataset")
    print(
        f"By default, you'll be charged for ~{n_epochs * n_billing_tokens_in_dataset} tokens"
    )
# Expose the verifier as a CLI; typer maps function parameters to options.
if __name__ == "__main__":
    typer.run(verify_and_analyze_dataset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment