# Arcface
# Source: GitHub gist dbpprt/bf06f54cfda9f15f11c2ce46fc0be216, created November 5, 2021 14:44.
# Note: GitHub flagged this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. To review, open the file in an
# editor that reveals hidden Unicode characters.
import argparse | |
import logging | |
import os | |
import torch | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
import torch.utils.data.distributed | |
from torch.nn.utils import clip_grad_norm_ | |
from backbones import get_model | |
import losses | |
from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX | |
from partial_fc import PartialFC | |
from utils.utils_amp import MaxClipGradScaler | |
from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint | |
from utils.utils_config import get_config | |
from utils.utils_logging import AverageMeter, init_logging | |
def main(args):
    """Run distributed ArcFace training with one process per host.

    Args:
        args: Parsed command-line namespace. This function reads
            ``args.config`` (passed to ``get_config``), ``args.hosts`` and
            ``args.current_host`` (the process rank is the index of the
            current host in the host list), ``args.backend`` (forwarded to
            ``dist.init_process_group``) and ``args.num_gpus`` (only logged).

    Raises:
        ValueError: If ``args.current_host`` is not in ``args.hosts``.
    """
    cfg = get_config(args.config)

    # One process per host: world size is the host count and the rank is this
    # host's position in the host list.
    world_size = len(args.hosts)
    os.environ["WORLD_SIZE"] = str(world_size)
    host_rank = args.hosts.index(args.current_host)
    os.environ["RANK"] = str(host_rank)
    rank = int(os.environ['RANK'])
    dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
    # The file only imports the ``logging`` module; there is no ``logger``
    # object in scope, so log through the module like the rest of this script.
    logging.info(
        "Initialized the distributed environment: '{}' backend on {} nodes. ".format(
            args.backend, dist.get_world_size()
        )
        + "Current host rank is {}. Number of gpus: {}".format(dist.get_rank(), args.num_gpus)
    )

    # The global rank is a host index, not a device index. Wrap it into the
    # range of devices visible on this host so that e.g. host 1 of a cluster
    # of single-GPU machines does not ask for the non-existent cuda:1.
    # When the rank is already a valid device index nothing changes.
    num_devices = torch.cuda.device_count()
    local_rank = rank % num_devices if num_devices else rank
    torch.cuda.set_device(local_rank)
    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    if cfg.rec == "synthetic":
        train_set = SyntheticDataset(local_rank=local_rank)
    else:
        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
    train_loader = DataLoaderX(
        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)

    if cfg.resume:
        # Best-effort resume: a missing or incompatible checkpoint falls back
        # to the freshly initialised backbone instead of aborting the run.
        try:
            backbone_pth = os.path.join(cfg.output, "backbone.pth")
            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
            if rank == 0:
                logging.info("backbone resume successfully!")
        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
            if rank == 0:
                logging.info("resume fail, backbone init successfully!")

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
    backbone.train()
    margin_softmax = losses.get_loss(cfg.loss)
    module_partial_fc = PartialFC(
        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)

    # Both optimizers share the same learning rate, scaled linearly with the
    # total batch size relative to a reference batch of 512.
    opt_backbone = torch.optim.SGD(
        params=[{'params': backbone.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)
    opt_pfc = torch.optim.SGD(
        params=[{'params': module_partial_fc.parameters()}],
        lr=cfg.lr / 512 * cfg.batch_size * world_size,
        momentum=0.9, weight_decay=cfg.weight_decay)

    num_image = len(train_set)
    total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
    cfg.total_step = num_image // total_batch_size * cfg.num_epoch
    # Decay milestones in optimizer steps; computed once here rather than on
    # every scheduler call because none of the inputs change during training.
    cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]

    def lr_step_func(current_step):
        """Return the LR multiplier: linear warm-up, then x0.1 per passed milestone."""
        if current_step < cfg.warmup_step:
            return current_step / cfg.warmup_step
        return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])

    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_backbone, lr_lambda=lr_step_func)
    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
        optimizer=opt_pfc, lr_lambda=lr_step_func)

    # Dump the effective configuration, padding keys to a fixed-width column.
    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    val_target = cfg.val_targets
    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)

    loss = AverageMeter()
    start_epoch = 0
    global_step = 0
    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
    for epoch in range(start_epoch, cfg.num_epoch):
        train_sampler.set_epoch(epoch)
        for img, label in train_loader:
            global_step += 1
            features = F.normalize(backbone(img))
            # forward_backward returns the gradient w.r.t. the features
            # together with the loss value; the backbone is then
            # back-propagated from that gradient.
            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
            if cfg.fp16:
                features.backward(grad_amp.scale(x_grad))
                # Unscale before clipping so max_norm applies to true gradients.
                grad_amp.unscale_(opt_backbone)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                grad_amp.step(opt_backbone)
                grad_amp.update()
            else:
                features.backward(x_grad)
                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
                opt_backbone.step()

            opt_pfc.step()
            module_partial_fc.update()
            opt_backbone.zero_grad()
            opt_pfc.zero_grad()
            loss.update(loss_v, 1)
            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
            callback_verification(global_step, backbone)
            scheduler_backbone.step()
            scheduler_pfc.step()
        # Checkpoint callback is invoked once per epoch.
        callback_checkpoint(global_step, backbone, module_partial_fc)
    dist.destroy_process_group()
if __name__ == "__main__":
    # ``json`` is needed to decode the host list and is not part of the
    # file's import block, so it is imported here in the script entry point.
    import json

    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
    parser.add_argument('--config', type=str, default='config/speed', help='py config file')
    parser.add_argument(
        "--backend",
        type=str,
        default=None,
        help="backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)",
    )
    # Container environment
    # Defaults come from SM_* environment variables (presumably provided by
    # the SageMaker training container -- confirm). They are looked up with
    # .get() so a missing variable no longer crashes the script with KeyError
    # before the command line is even parsed.
    env = os.environ
    # argparse applies ``type`` to string defaults, so the JSON text from the
    # environment and a JSON value given on the command line are both decoded
    # into a real list (``type=list`` would split the text into characters).
    # The two options main() cannot run without become required when their
    # environment variable is absent.
    parser.add_argument("--hosts", type=json.loads, default=env.get("SM_HOSTS"),
                        required="SM_HOSTS" not in env)
    parser.add_argument("--current-host", type=str, default=env.get("SM_CURRENT_HOST"),
                        required="SM_CURRENT_HOST" not in env)
    parser.add_argument("--model-dir", type=str, default=env.get("SM_MODEL_DIR"))
    parser.add_argument("--data-dir", type=str, default=env.get("SM_CHANNEL_TRAINING"))
    parser.add_argument("--num-gpus", type=int, default=env.get("SM_NUM_GPUS"))
    main(parser.parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment