Created
February 15, 2021 10:43
-
-
Save allanbatista/a5d902017827081f53b5e1e63dcc25de to your computer and use it in GitHub Desktop.
Pytorch Utils
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import s3fs | |
import math | |
import torch | |
import datetime | |
from sklearn.metrics import f1_score | |
import numpy as np | |
import tensorflow as tf | |
from collections import OrderedDict | |
s3 = s3fs.S3FileSystem(anon=False) | |
def log(*args):
    """Print ``args`` to stdout behind a bracketed ISO-8601 timestamp.

    The timestamp comes from ``datetime.datetime.now()`` at call time and the
    remaining arguments are forwarded to ``print`` unchanged.
    """
    timestamp = datetime.datetime.now().isoformat()
    prefix = f"[{timestamp}]"
    print(prefix, *args)
def create_class_weight(labels_dict, mu=0.15):
    """Derive a log-scaled weight for every label from its sample count.

    Args:
        labels_dict: mapping of label -> number of samples with that label.
        mu: scaling factor applied to the total sample count before the
            ratio ``total / count`` is taken.

    Returns:
        dict mapping each label to ``log(mu * total / count)``, floored at
        1.0 so that no label is ever weighted below 1.
    """
    total = np.sum(list(labels_dict.values()))
    return {
        key: max(1.0, math.log(mu * total / float(count)))
        for key, count in labels_dict.items()
    }
def compute_weights(labels_index, labels_count, label):
    """Return the class weights for ``label`` as a list ordered by class index.

    Args:
        labels_index: ``labels_index[label]`` maps each class name to its index.
        labels_count: ``labels_count[label]`` maps each class name to the
            number of training items carrying that class.
        label: key selecting which entry of the two mappings to use.

    Returns:
        List of weights (see ``create_class_weight``) sorted by ascending
        class index.
    """
    name_to_index = labels_index[label]
    weight_by_name = create_class_weight(labels_count[label])
    # Order the class names by their index, then read the weights in that order.
    ordered_names = sorted(weight_by_name, key=lambda name: name_to_index[name])
    return [weight_by_name[name] for name in ordered_names]
def open_io(path: str, mode: str = 'r'):
    """Open ``path`` with the backend selected by its URL scheme.

    * ``gs://`` paths are opened with ``tf.io.gfile.GFile``.
    * ``s3://`` paths are opened through the module-level ``s3`` filesystem,
      which receives the path with the scheme removed.
    * Anything else is opened with the builtin ``open``.

    Returns:
        The file-like object produced by the selected backend.
    """
    if path.startswith('gs://'):
        return tf.io.gfile.GFile(path, mode)
    if path.startswith('s3://'):
        bucket_and_key = path.split("//", 1)[1]
        return s3.open(bucket_and_key, mode)
    return open(path, mode)
def open_fn(path: str):
    """Open ``path`` for binary reading via ``open_io``."""
    binary_read_mode = "rb"
    return open_io(path, binary_read_mode)
def exists(path):
    """Tell whether ``path`` exists, dispatching on its URL scheme.

    ``s3://`` paths are checked through the module-level ``s3`` filesystem
    (scheme removed), ``gs://`` paths through ``tf.io.gfile.exists`` and every
    other path through ``os.path.exists``.
    """
    if path.startswith('s3://'):
        bucket_and_key = path.split("//", 1)[1]
        return s3.exists(bucket_and_key)
    is_gcs = path.startswith('gs://')
    return tf.io.gfile.exists(path) if is_gcs else os.path.exists(path)
def mkdirs(*paths):
    """Create every local directory in ``paths``, including missing parents.

    Paths starting with ``s3://`` or ``gs://`` are skipped. Directories that
    already exist are left alone (``exist_ok=True``).
    """
    print(*paths)
    remote_schemes = ('s3://', 'gs://')
    for path in paths:
        if path.startswith(remote_schemes):
            continue
        log("creating", path)
        os.makedirs(path, exist_ok=True)
def ls(path: str):
    """List the entries found under ``path``.

    * ``s3://``: returns whatever ``s3.ls(path)`` yields, unchanged.
    * ``gs://``: joins every name from ``tf.io.gfile.listdir(path)`` onto
      ``path``.
    * local: walks ``path`` recursively and returns the path of every file.

    NOTE(review): only the local branch recurses explicitly; whether the
    remote listings recurse, include directories, or carry the URL scheme
    depends on s3fs / tf.io.gfile -- confirm before relying on it.
    """
    if path.startswith('s3://'):
        return s3.ls(path)
    if path.startswith('gs://'):
        names = tf.io.gfile.listdir(path)
        return [os.path.join(path, name) for name in names]
    return [
        os.path.join(root, filename)
        for root, _dirs, files in os.walk(path)
        for filename in files
    ]
def copy_all(origin_dir, destiny_dir: str):
    """Copy every entry listed under ``origin_dir`` into ``destiny_dir``.

    Entries come from ``ls(origin_dir)``; each one is re-created under
    ``destiny_dir`` at the same path relative to ``origin_dir``. Either side
    may be local, ``s3://`` or ``gs://`` (everything goes through ``open_io``).

    Args:
        origin_dir: directory / prefix to copy from. A trailing slash is
            tolerated.
        destiny_dir: directory / prefix to copy into. Local parent
            directories are created as needed.
    """
    import shutil  # local import: only needed by this function

    log("coping", origin_dir, "to", destiny_dir)

    remote_schemes = ("s3://", "gs://")
    origin_is_remote = origin_dir.startswith(remote_schemes)
    destiny_is_remote = destiny_dir.startswith(remote_schemes)
    # Normalised prefix, so "dir" and "dir/" are treated the same. Without
    # this, an unmatched prefix left the (possibly absolute) source path
    # intact and the copy could end up overwriting the source file itself.
    origin_prefix = origin_dir.rstrip("/") + "/"

    for origin_path in ls(origin_dir):
        # s3fs listings appear to come back without the "s3://" scheme;
        # restore it so open_io routes the read to S3 instead of local disk.
        if origin_dir.startswith("s3://") and not origin_path.startswith("s3://"):
            origin_path = "s3://" + origin_path

        if not origin_is_remote:
            relative_path = os.path.relpath(origin_path, origin_dir)
        elif origin_path.startswith(origin_prefix):
            # Strip only the leading prefix, never later occurrences.
            relative_path = origin_path[len(origin_prefix):]
        else:
            # Unexpected listing format: never join a full URL onto the
            # destination, fall back to the bare file name.
            relative_path = os.path.basename(origin_path)

        destiny_path = os.path.join(destiny_dir, relative_path)

        with open_io(origin_path, "rb") as source:
            if not destiny_is_remote:
                mkdirs(os.path.dirname(destiny_path))
            with open_io(destiny_path, "wb") as target:
                # Stream in chunks instead of loading the whole file in memory.
                shutil.copyfileobj(source, target)
def create_checkpoint(model, optimizer, epoch, path, losses):
    """Save the training state to ``path`` as the "last" (and maybe "best") checkpoint.

    ``checkpoint-last.pth.tar`` is always written. ``checkpoint-best.pth.tar``
    is additionally written when the most recent loss is less than or equal to
    every earlier loss (or when it is the only loss so far).

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        epoch: epoch number stored in the checkpoint.
        path: directory receiving the checkpoint files (created when local).
        losses: sequence of losses so far, most recent last. When empty, only
            the "last" checkpoint is written.

    NOTE(review): ``restore_checkpoint`` reads from ``<path>/checkpoints/``;
    callers presumably pass that sub-directory here -- confirm.
    """
    log("start create checkpoint")
    mkdirs(path)

    best_path = os.path.join(path, 'checkpoint-best.pth.tar')
    last_path = os.path.join(path, 'checkpoint-last.pth.tar')

    _checkpoint = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'losses': losses,
    }

    with open_io(last_path, "wb") as f:
        log("creating last checkpoint", last_path)
        torch.save(_checkpoint, f)

    # Without any recorded loss there is nothing to rank; previously this
    # case raised IndexError on losses[-1] after the last checkpoint was saved.
    if len(losses) == 0:
        return

    if len(losses) == 1 or losses[-1] <= min(losses[:-1]):
        log("creating best checkpoint", best_path)
        with open_io(best_path, "wb") as f:
            torch.save(_checkpoint, f)
def restore_checkpoint(model, optimizer, path, name='checkpoint-last.pth.tar', restore=True,
                       map_location=None):
    """Load ``<path>/checkpoints/<name>`` and optionally restore model/optimizer state.

    Args:
        model: object exposing ``load_state_dict``; restored when the
            checkpoint has a non-None ``'state_dict'`` entry.
        optimizer: object exposing ``load_state_dict`` or None; restored when
            it is not None and the checkpoint has a non-None ``'optimizer'``.
        path: base directory; the file is looked up in its ``checkpoints``
            sub-directory.
        name: checkpoint file name.
        restore: when False the checkpoint is only read, nothing is loaded
            into ``model`` / ``optimizer``.
        map_location: forwarded to ``torch.load`` so a checkpoint saved on one
            device can be loaded on another (e.g. ``"cpu"``). The default
            ``None`` keeps ``torch.load``'s own default behaviour.

    Returns:
        ``(next_epoch, losses)``: the stored epoch plus one and the stored
        losses (``[]`` if absent), or ``(0, [])`` when the file does not exist.

    Raises:
        KeyError: if the checkpoint has no ``'epoch'`` entry.
    """
    log("start restoring checkpoint")

    last_path = os.path.join(path, "checkpoints", name)
    if not exists(last_path):
        return 0, []

    log("=> loading checkpoint '{}'".format(last_path))
    with open_io(last_path, "rb") as f:
        # SECURITY: torch.load relies on pickle; only load checkpoints that
        # come from a trusted location.
        checkpoint = torch.load(f, map_location=map_location)

    if restore:
        state_dict = checkpoint.get('state_dict')
        if state_dict is not None:
            log("Restoring model")
            # Keys are loaded exactly as stored; a "module." key prefix is
            # not stripped here.
            model.load_state_dict(state_dict)

        optimizer_state = checkpoint.get('optimizer')
        if optimizer_state is not None and optimizer is not None:
            log("Restoring Optimizer")
            optimizer.load_state_dict(optimizer_state)

    return checkpoint['epoch'] + 1, checkpoint.get('losses', [])
class Mean:
    """Collects values and reports their arithmetic mean."""

    def __init__(self):
        # Every value passed to add(), in insertion order.
        self.values = list()

    def add(self, x):
        """Record one more value."""
        self.values.append(x)

    def avg(self):
        """Return the mean of all recorded values.

        Raises:
            ZeroDivisionError: if no value has been added yet.
        """
        count = len(self.values)
        return sum(self.values) / count
class Metric:
    """Accumulates true/predicted labels over batches and computes metrics."""

    def __init__(self):
        # Flattened history of every batch passed to add().
        self.y_true = []
        self.y_pred = []
        # Number of samples seen so far (sum of len(y_true) per batch).
        self.total = 0

    def add(self, y_true, y_pred):
        """Append one batch; both arguments must expose ``tolist()``."""
        self.total += len(y_true)
        self.y_true += y_true.tolist()
        self.y_pred += y_pred.tolist()

    def acc(self):
        """Return the fraction of positions where prediction equals truth.

        Raises:
            ZeroDivisionError: if no sample has been added yet.
        """
        if self.total == 0:
            raise ZeroDivisionError("acc() called before any sample was added")
        matches = np.array(self.y_true) == np.array(self.y_pred)
        # Vectorised reduction over the first axis; the builtin sum() walked
        # the array element by element in Python.
        return np.sum(matches.astype('int'), axis=0) / self.total

    def f1(self, average='macro'):
        """Return ``f1_score`` over everything accumulated so far."""
        return f1_score(np.array(self.y_true), np.array(self.y_pred), average=average)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment