Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Last active August 10, 2024 13:56
Show Gist options
  • Save pszemraj/aad41ff842e578baa86cd9307cb895f9 to your computer and use it in GitHub Desktop.
Save pszemraj/aad41ff842e578baa86cd9307cb895f9 to your computer and use it in GitHub Desktop.
Process log file from nanoT5
"""
parses the standard main.log from nanoT5 and makes some plots
pip install matplotlib pandas seaborn
"""
import argparse
import logging
import os
import re
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.ticker import FuncFormatter
# Set up logging: root logger at INFO with timestamped, leveled messages
# (used by parse_log_file/main below for error reporting).
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
def parse_log_file(file_path):
    """Parse a nanoT5 ``main.log`` into a DataFrame of training metrics.

    Parameters
    ----------
    file_path : str or Path
        Path to the log file produced by nanoT5.

    Returns
    -------
    pd.DataFrame
        One row per ``[train]`` step with numeric columns ``step``, ``loss``,
        ``grad_l2``, ``weights_l2``, ``lr`` and ``seconds_per_step``.
        An empty DataFrame is returned when nothing could be parsed.
    """
    # Numeric field pattern: plain decimals plus optional scientific notation
    # (e.g. "1e-05"); the old pattern `[\d.]+` silently dropped any step whose
    # learning rate (or other metric) was logged in exponent form.
    num = r"[\d.]+(?:[eE][+-]?\d+)?"
    pattern = re.compile(
        rf"\[(?P<timestamp>[\d-]+ [\d:,]+)\]\[Main\]\[INFO\] - \[train\] "
        rf"Step (?P<step>\d+) out of \d+ \| "
        rf"Loss --> (?P<loss>{num}) \| "
        rf"Grad_l2 --> (?P<grad_l2>{num}) \| "
        rf"Weights_l2 --> (?P<weights_l2>{num}) \| "
        rf"Lr --> (?P<lr>{num}) \| "
        rf"Seconds_per_step --> (?P<seconds_per_step>{num})"
    )
    extracted_data = []
    # Stream line by line instead of readlines(): logs can be large.
    with open(file_path, "r") as file:
        for line in file:
            match = pattern.search(line)
            if match:
                extracted_data.append(match.groupdict())
    if not extracted_data:
        logging.error(f"No matching log data found in {file_path}")
        return pd.DataFrame()
    df = pd.DataFrame(extracted_data)
    # Check for missing columns and handle gracefully
    required_columns = [
        "step",
        "loss",
        "grad_l2",
        "weights_l2",
        "lr",
        "seconds_per_step",
    ]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        logging.error(f"Missing columns in parsed data: {missing_columns}")
        return pd.DataFrame()
    # Convert columns to numeric; unparseable values become NaN and are dropped.
    for column in required_columns:
        df[column] = pd.to_numeric(df[column], errors="coerce")
    return df.dropna()
def set_plot_style(dark_mode=False):
    """Apply the shared matplotlib/seaborn look, in a light or dark variant."""
    # Only the base stylesheet differs between the two modes.
    style_name = "dark_background" if dark_mode else "seaborn-v0_8-whitegrid"
    plt.style.use(style_name)
    sns.set_palette("deep")
    sns.set_context("paper", font_scale=1.5)
    # Prefer Times New Roman, falling back to the default serif list.
    plt.rcParams["font.family"] = "serif"
    plt.rcParams["font.serif"] = ["Times New Roman"] + plt.rcParams["font.serif"]
def format_scientific(x, pos):
    """Tick formatter: render *x* in exponent notation with no decimals.

    The unused *pos* argument is required by matplotlib's FuncFormatter API.
    """
    return format(x, ".0e")
def plot_metrics(df, output_dir, dark_mode=False):
    """Render one line plot per training metric and save each as a PNG.

    Skips silently (with a warning) when *df* is empty. Dark-mode output
    files get a ``_dark`` suffix so both variants can coexist.
    """
    if df.empty:
        logging.warning("Empty dataframe, skipping plotting.")
        return
    set_plot_style(dark_mode)
    # (column, axis label, colormap) triples — one figure per entry.
    plot_specs = [
        ("loss", "Loss", "viridis"),
        ("lr", "Learning Rate", "plasma"),
        ("grad_l2", "Gradient L2 Norm", "cividis"),
        ("weights_l2", "Weights L2 Norm", "magma"),
        ("seconds_per_step", "Seconds per Step", "inferno"),
    ]
    for column, label, cmap_name in plot_specs:
        fig, ax = plt.subplots(figsize=(12, 8))
        line_color = plt.get_cmap(cmap_name)(0.6)
        sns.lineplot(
            x="step",
            y=column,
            data=df,
            ax=ax,
            linewidth=2,
            color=line_color,
        )
        ax.set_xlabel("Steps", fontweight="bold")
        ax.set_ylabel(label, fontweight="bold")
        ax.set_title(f"{label} over Steps", fontweight="bold", fontsize=16)
        # Everything except the learning rate reads better on a log scale.
        if column != "lr":
            ax.set_yscale("log")
            ax.yaxis.set_major_formatter(FuncFormatter(format_scientific))
        # Thousands separators on the step axis.
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ",")))
        ax.grid(True, which="both", ls="-", alpha=0.2)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        plt.tight_layout()
        suffix = "_dark" if dark_mode else ""
        plt.savefig(
            f"{output_dir}/{column}_over_steps{suffix}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close()
def main():
    """CLI entry point: parse a nanoT5 log, save metric plots and a CSV.

    Usage: ``python script.py LOG_FILE [-o OUTPUT_DIR] [--dark]``.
    Exits early (with an error log) on a missing or unparseable log file.
    """
    parser = argparse.ArgumentParser(
        description="Process log file and generate publication-quality plots."
    )
    parser.add_argument("log_file", type=str, help="Path to the log file")
    parser.add_argument(
        "-o", "--output_dir", default=None, type=str, help="Directory to save the plots"
    )
    parser.add_argument(
        "--dark", action="store_true", help="Enable dark mode for plots"
    )
    args = parser.parse_args()
    log_file = Path(args.log_file)
    if not log_file.exists():
        logging.error(f"Log file does not exist: {log_file}")
        return
    # Default to the log file's own directory when no output dir is given.
    output_dir = (
        Path(args.output_dir) if args.output_dir is not None else log_file.parent
    )
    # Fix: a user-supplied --output_dir may not exist yet; create it so that
    # savefig/to_csv below don't fail with FileNotFoundError.
    output_dir.mkdir(parents=True, exist_ok=True)
    df = parse_log_file(log_file)
    if df.empty:
        logging.error(
            f"Failed to parse log file or log file contains no data: {log_file}"
        )
        return
    plot_metrics(df, output_dir, dark_mode=args.dark)
    df.to_csv(output_dir / "training_metrics.csv", index=False)


if __name__ == "__main__":
    main()
#!/bin/bash
# Batch-run parse_nanoT5_log.py over a fixed set of nanoT5 run logs,
# mirroring progress to the console and capturing all output in one file.

# Runs to process, one main.log per training run.
log_files=(
    "./logs/2024-08-02/00-08-01-/main.log"
    "./logs/2024-08-02/00-09-06-/main.log"
    "./logs/2024-08-02/17-49-19-/main.log"
    "./logs/2024-08-02/18-03-37-/main.log"
    "./logs/2024-08-03/01-12-52-/main.log"
    "./logs/2024-08-03/05-45-49-/main.log"
    "./logs/2024-08-03/05-49-42-/main.log"
    "./logs/2024-08-03/09-39-51-/main.log"
    "./logs/2024-08-03/15-59-54-/main.log"
    "./logs/2024-08-04/02-34-31-/main.log"
    "./logs/2024-08-04/02-39-07-/main.log"
)

# All script output (stdout + stderr of every run) lands here.
log_output="process_log_output.txt"

# Start from a clean capture file.
[ -f "$log_output" ] && rm "$log_output"

# Process each run; a failing parse is reported but doesn't stop the batch.
for file in "${log_files[@]}"; do
    echo "Processing $file" | tee -a "$log_output"
    if python parse_nanoT5_log.py "$file" >> "$log_output" 2>&1; then
        echo "Successfully processed $file" | tee -a "$log_output"
    else
        echo "Failed to process $file" | tee -a "$log_output"
    fi
done
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment