Skip to content

Instantly share code, notes, and snippets.

@tcapelle
Created May 6, 2022 17:35
Show Gist options
  • Save tcapelle/28ee549b119455f365312504fab45c42 to your computer and use it in GitHub Desktop.
Save tcapelle/28ee549b119455f365312504fab45c42 to your computer and use it in GitHub Desktop.
import wandb
import timm
import argparse
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback
from torchvision import models
def parse_args(argv=None):
    """Parse command-line options for the fine-tuning experiment.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Handy for tests and notebooks; ``None`` keeps
            the standard argparse behaviour.

    Returns:
        argparse.Namespace with attributes ``batch_size``, ``epochs``,
        ``num_experiments``, ``learning_rate``, ``img_size``, ``model_name``,
        ``seed``, ``mixup``, ``force_torchvision`` and ``wandb_project``.
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune a timm or torchvision image model with fastai "
                    "and log the runs to Weights & Biases.",
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--batch_size', type=int, default=64,
                        help="batch size")
    parser.add_argument('--epochs', type=int, default=5,
                        help="number of fine-tuning epochs")
    parser.add_argument('--num_experiments', type=int, default=1,
                        help="number of runs to execute (one W&B run each)")
    parser.add_argument('--learning_rate', type=float, default=0.002,
                        help="base learning rate passed to fine_tune")
    parser.add_argument('--img_size', type=int, default=224,
                        help="input image size")
    parser.add_argument('--model_name', type=str, default='resnet18',
                        help="timm model name, or a torchvision.models "
                             "attribute when --force_torchvision is set")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed")
    parser.add_argument('--mixup', action='store_true',
                        help="enable MixUp augmentation")
    parser.add_argument('--force_torchvision', action='store_true',
                        help="build the model from torchvision.models "
                             "instead of timm")
    parser.add_argument('--wandb_project', type=str, default='fine_tune_timm',
                        help="Weights & Biases project name")
    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)

    # Download (cached by untar_data) and index the Pets dataset (URLs.PETS)
    # once; the file list is identical for every experiment below.
    dataset_path = untar_data(URLs.PETS)
    files = get_image_files(dataset_path / "images")

    # Each label is the leading alphabetic words of the file name.
    # NOTE(review): this pattern captures at most two underscore-separated
    # words (e.g. "a_b_c_1.jpg" -> "a_b"), so longer names are truncated.
    # It is kept as-is so results stay comparable with earlier runs.
    label_pat = r'(^[a-zA-Z]+_*[a-zA-Z]+)'

    for _ in range(args.num_experiments):
        group = "torchvision" if args.force_torchvision else "timm"
        with wandb.init(project=args.wandb_project, group=group, config=args):
            dls = ImageDataLoaders.from_name_re(
                dataset_path,
                files,
                label_pat,
                valid_pct=0.2,
                seed=42,  # fixed split seed: every run validates on the same images
                item_tfms=Resize(args.img_size),
                bs=args.batch_size,
            )

            cbs = [MixedPrecision(), WandbCallback(log_preds=False)]
            if args.mixup:
                cbs.append(MixUp())

            # With --force_torchvision the architecture is the constructor looked
            # up on torchvision.models. Otherwise the model-name string is passed
            # straight to vision_learner, and the run is grouped as "timm".
            if args.force_torchvision:
                arch = getattr(models, args.model_name)
            else:
                arch = args.model_name

            learn = vision_learner(
                dls,
                arch,
                metrics=[accuracy],
                cbs=cbs,
                pretrained=True,
            )
            learn.fine_tune(args.epochs, args.learning_rate)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment