Created
August 25, 2021 08:17
-
-
Save adimyth/e6e8cfbca8bd77ad933cf210d675b036 to your computer and use it in GitHub Desktop.
Transfer Learning with Poisson Loss (PyTorch Lightning)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard library
import os
from glob import glob

# Third-party: scientific stack
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

# Third-party: PyTorch / Lightning / experiment tracking
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision.models as models
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy, auroc
from torchmetrics.functional import mean_absolute_error, mean_squared_error
from torchvision import transforms
from torchvision.datasets import ImageFolder
# Kaggle specific way to use Wandb
# (kaggle_secrets only exists inside Kaggle notebooks/kernels)
from kaggle_secrets import UserSecretsClient

# Pull the W&B API key from Kaggle's secret store so it is never hard-coded.
user_secrets = UserSecretsClient()
wandb_api = user_secrets.get_secret("wandb-key")
wandb.login(key=wandb_api)
# NOTE(review): two different project names are used below —
# "count-the-green-boxes" for wandb.init vs "count-green-boxes-lightning"
# for the Lightning logger. Confirm this split is intended.
wandb.init(project="count-the-green-boxes")
wandb_logger = WandbLogger(project="count-green-boxes-lightning", job_type="train")
# Constants
RANDOM_STATE = 42  # global seed so runs are reproducible
NUM_CLASSES = 98  # number of count classes — presumably one folder per count; TODO confirm against dataset
# Seed Everything (python, numpy, torch, cuda)
pl.seed_everything(RANDOM_STATE)
# DataModule | |
class DataModule(pl.LightningDataModule): | |
def __init__(self, batch_size: int = 64, data_dir: str = ""): | |
super().__init__() | |
self.data_dir = data_dir | |
self.batch_size = batch_size | |
self.transform = transforms.Compose( | |
[ | |
transforms.Resize(size=256), | |
transforms.CenterCrop(size=224), | |
transforms.ToTensor(), | |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | |
] | |
) | |
def setup(self, stage=None): | |
dataset = ImageFolder(self.data_dir) | |
num_train = int(0.8 * len(dataset)) | |
num_valid = len(dataset) - num_train | |
self.train, self.val = random_split(dataset, [num_train, num_valid]) | |
self.train.dataset.transform = self.transform | |
self.val.dataset.transform = self.transform | |
def train_dataloader(self): | |
return DataLoader( | |
self.train, batch_size=self.batch_size, shuffle=True, num_workers=2 | |
) | |
def val_dataloader(self): | |
return DataLoader(self.val, batch_size=self.batch_size, num_workers=2) | |
# Model | |
class CountModel(pl.LightningModule): | |
def __init__( | |
self, input_shape, num_classes: int = 100, learning_rate: float = 2e-4 | |
): | |
super().__init__() | |
# log hyperparameters | |
self.save_hyperparameters() | |
self.learning_rate = learning_rate | |
self.dim = input_shape | |
self.num_classes = num_classes | |
self.feature_extractor = models.resnet18(pretrained=True) | |
self.feature_extractor.eval() | |
for param in self.feature_extractor.parameters(): | |
param.requires_grad = False | |
n_sizes = self._get_conv_output(input_shape) | |
self.classifier = nn.Linear(n_sizes, num_classes) | |
# returns the size of the output tensor going into Linear layer from the conv block. | |
def _get_conv_output(self, shape): | |
batch_size = 1 | |
input = torch.autograd.Variable(torch.rand(batch_size, *shape)) | |
output_feat = self._forward_features(input) | |
n_size = output_feat.data.view(batch_size, -1).size(1) | |
return n_size | |
# returns the feature tensor from the conv block | |
def _forward_features(self, x): | |
x = self.feature_extractor(x) | |
return x | |
# will be used during inference | |
def forward(self, x): | |
x = self._forward_features(x) | |
x = x.view(x.size(0), -1) | |
x = F.log_softmax(self.classifier(x), dim=1) | |
return x | |
# logic for a single training step | |
def training_step(self, batch, batch_idx): | |
x, y = batch | |
logits = self(x) | |
preds = torch.argmax(logits, dim=1) | |
loss = F.poisson_nll_loss(preds, y, log_input=False) | |
loss.requires_grad = True | |
# training metrics | |
acc = accuracy(preds, y) | |
mae = mean_absolute_error(preds, y) | |
mse = mean_squared_error(preds, y) | |
self.log("train_loss", loss, on_step=True, on_epoch=True, logger=True) | |
self.log("train_acc", acc, on_step=True, on_epoch=True, logger=True) | |
self.log("train_mae", mae, on_step=True, on_epoch=True, logger=True) | |
self.log("train_mse", mse, on_step=True, on_epoch=True, logger=True) | |
return loss | |
# logic for a single validation step | |
def validation_step(self, batch, batch_idx): | |
x, y = batch | |
logits = self(x) | |
preds = torch.argmax(logits, dim=1) | |
loss = F.poisson_nll_loss(preds, y, log_input=False) | |
# validation metrics | |
acc = accuracy(preds, y) | |
mae = mean_absolute_error(preds, y) | |
mse = mean_squared_error(preds, y) | |
self.log("val_loss", loss, on_step=True, on_epoch=True, logger=True) | |
self.log("val_acc", acc, on_step=True, on_epoch=True, logger=True) | |
self.log("val_mae", mae, on_step=True, on_epoch=True, logger=True) | |
self.log("val_mse", mse, on_step=True, on_epoch=True, logger=True) | |
return loss | |
def configure_optimizers(self): | |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) | |
return optimizer | |
# Test Dataset - because ImageFolder doesn't work with missing folder label | |
class TestDataset(torch.utils.data.Dataset): | |
def __init__(self, main_dir: str = ""): | |
self.main_dir = main_dir | |
self.transform = transforms.Compose( | |
[ | |
transforms.Resize(size=256), | |
transforms.CenterCrop(size=224), | |
transforms.ToTensor(), | |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), | |
] | |
) | |
self.total_imgs = sorted(glob(f"{main_dir}/*.png")) | |
def __len__(self): | |
return len(self.total_imgs) | |
def __getitem__(self, idx): | |
img_loc = self.total_imgs[idx] | |
image = Image.open(img_loc).convert("RGB") | |
tensor_image = self.transform(image) | |
return tensor_image | |
if __name__ == "__main__":
    # ---- Data ----
    datamodule = DataModule(
        batch_size=1024, data_dir="../input/count-the-blue-boxes/train/train/"
    )
    datamodule.setup()
    # ---- Callbacks ----
    early_stop_callback = EarlyStopping(
        monitor="val_loss", patience=3, verbose=False, mode="min"
    )
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        filename="model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
    )
    # ---- Training ----
    model = CountModel((3, 64, 64), NUM_CLASSES)
    trainer = pl.Trainer(
        max_epochs=1,
        progress_bar_refresh_rate=5,
        gpus=1,
        callbacks=[early_stop_callback, checkpoint_callback],
    )
    trainer.fit(model, datamodule)
    # ---- Inference ----
    # The checkpoint callback already tracks the checkpoint with the lowest
    # val_loss, so there is no need to re-parse loss values out of filenames.
    # (The original did that with `re`, which was never imported.)
    best_model = checkpoint_callback.best_model_path
    print(f"Best Model: {best_model}")
    inference_model = CountModel.load_from_checkpoint(best_model)
    inference_model.eval()  # disable dropout/batchnorm updates for inference
    test_dataset = TestDataset("../input/count-the-blue-boxes/test/test/")
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1024, num_workers=2
    )
    print(f"Test Dataset: {len(test_dataset)}\tTest DataLoader: {len(test_dataloader)}")
    y_pred = []
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        for imgs in test_dataloader:
            logits = inference_model(imgs)
            preds = torch.argmax(logits, dim=1)
            y_pred.extend(preds.cpu().numpy())
    # Take image names from the dataset itself, in the exact order it served
    # them, so submission rows stay aligned with the predictions.  (The
    # original used `natsort.natsorted(os.listdir(main_dir))` — `natsort`
    # and `main_dir` were undefined, and natural-sort order would not have
    # matched the dataset's lexicographic order anyway.)
    all_imgs = [os.path.basename(p) for p in test_dataset.total_imgs]
    submission = pd.DataFrame.from_dict({"images": all_imgs, "labels": y_pred})
    submission.to_csv("submission.csv", index=False)
    wandb.finish()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment