Last active
September 8, 2023 19:57
-
-
Save fauxneticien/d57dbe5fc9d7ec38a8e35920d03cdb92 to your computer and use it in GitHub Desktop.
Lhotse DDP test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader | |
import os | |
import torchaudio | |
from lhotse import CutSet, Fbank, FbankConfig | |
from lhotse.dataset import IterableDatasetWrapper, SpeechSynthesisDataset, DynamicBucketingSampler, OnTheFlyFeatures, make_worker_init_fn | |
from lhotse.recipes import download_librispeech, prepare_librispeech | |
from tqdm import tqdm | |
# Download and set up data on first run; later runs reuse the corpus and the
# manifest cached under LibriSpeech/ (the guard below skips everything).
if not os.path.exists("LibriSpeech/dev-clean-2"):
    download_librispeech(dataset_parts="dev-clean-2")
    # prepare_librispeech returns per-split manifests; presumably keyed by
    # split name — verify against the lhotse recipe docs.
    libri = prepare_librispeech(corpus_dir="LibriSpeech")
    # Persist the dev-clean-2 cuts so main() can load them lazily by path.
    CutSet.from_manifests(**libri['dev-clean-2']).to_jsonl("LibriSpeech/dev-clean-2.jsonl")
import torch.multiprocessing as mp | |
from torch.utils.data.distributed import DistributedSampler | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
from torch.distributed import init_process_group, destroy_process_group | |
import os | |
def ddp_setup(rank, world_size, master_addr="localhost", master_port="12355"):
    """
    Initialize the default process group for single-node DDP training and
    pin this process to its GPU.

    Args:
        rank: Unique identifier of each process (also used as the CUDA
            device index).
        world_size: Total number of processes.
        master_addr: Rendezvous address for the process group. The default
            preserves the original hard-coded behavior.
        master_port: Rendezvous port. Parameterized so concurrent jobs on
            one host can avoid port collisions; default unchanged.
    """
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port
    # NCCL backend: GPU collectives; requires CUDA devices to be visible.
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Bind this process to its own GPU so .to(rank)/cuda() default correctly.
    torch.cuda.set_device(rank)
class Trainer:
    """Minimal per-process DDP training-loop driver for the Lhotse demo."""

    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        loss
    ) -> None:
        """
        Args:
            model: The (un-wrapped) model to train.
            train_data: DataLoader whose iterator yields
                (worker_info, cut_ids, batch) triples (see ASRDataset).
            optimizer: Optimizer already bound to model.parameters().
            gpu_id: This process's CUDA device index (== DDP rank here).
            loss: Callable loss, invoked CTC-style with
                (log_probs, targets, input_lengths, target_lengths).
        """
        self.gpu_id = gpu_id
        self.train_data = train_data
        self.optimizer = optimizer
        self.loss = loss
        # Move the model to this process's device, then wrap it in DDP so
        # gradients are all-reduced across processes. (The original assigned
        # self.model twice; a single assignment is equivalent and clearer.)
        self.model = DDP(model.to(gpu_id), device_ids=[gpu_id])

    def train(self, max_iters: int):
        """Run max_iters optimizer steps, restarting the data iterator (and
        re-seeding the sampler) whenever an epoch's data is exhausted."""
        epoch = 0
        iterator = iter(self.train_data)
        # Progress bar only on rank 0 to avoid interleaved output.
        for global_step in tqdm(range(max_iters), disable=self.gpu_id != 0):
            try:
                worker_info, cut_ids, batch = next(iterator)
            except StopIteration:
                # Epoch finished: bump the sampler epoch so shuffling differs
                # next pass, then start a fresh iterator.
                epoch += 1
                self.train_data.sampler.set_epoch(epoch)
                iterator = iter(self.train_data)
                worker_info, cut_ids, batch = next(iterator)
            # Uncomment to display cut_ids to make sure GPUs are getting different data across epochs/GPUs/processes/workers
            # print(f"GPU: {self.gpu_id}; Worker: {worker_info.id + 1}/{worker_info.num_workers}; Iteration: {global_step}, Cuts: {cut_ids}")
            log_probs = self.model(batch['features'].to(self.gpu_id))
            # CTCLoss expects time-major log-probs: (T, N, C).
            loss = self.loss(
                log_probs.transpose(0, 1),
                batch["tokens"].to(self.gpu_id),
                batch["features_lens"].to(self.gpu_id),
                batch["tokens_lens"].to(self.gpu_id)
            )
            # BUG FIX: the original never cleared gradients, so .grad tensors
            # accumulated across iterations and every step used stale sums.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
class ASRDataset(SpeechSynthesisDataset):
    """
    Demo dataset that hijacks SpeechSynthesisDataset (which already ships a
    token collator, etc.) instead of K2SpeechRecognitionDataset, and augments
    each batch with worker and cut identifiers so data sharding across
    GPUs/workers can be inspected.

    Note: the original defined an __init__ that only forwarded *args/**kwargs
    to super(); it was removed — Python uses the parent's __init__ directly,
    which is behaviorally identical.
    """

    def __getitem__(self, cuts: CutSet):
        """Return (worker_info, cut_ids, batch) for the given batch of cuts.

        worker_info is None when loading happens in the main process.
        """
        worker_info = torch.utils.data.get_worker_info()
        # Comma-separated cut ids, useful for verifying shard disjointness.
        cut_ids = ", ".join(c.id for c in cuts)
        batch = super().__getitem__(cuts)
        return worker_info, cut_ids, batch
def main(rank: int, world_size: int, max_iters: int):
    """
    Per-process entry point launched by mp.spawn: initialize DDP, build the
    Lhotse data pipeline, then train a DeepSpeech model with CTC loss.

    Args:
        rank: This process's rank (also its CUDA device index).
        world_size: Total number of spawned processes (one per GPU).
        max_iters: Number of optimizer steps to run.
    """
    ddp_setup(rank, world_size)
    # Lazily stream cuts from the manifest written by the first-run setup.
    cuts = CutSet.from_jsonl_lazy("LibriSpeech/dev-clean-2.jsonl")
    # Rank-aware bucketing sampler: batches are capped at 60 s of audio and
    # bucketed by duration. NOTE(review): rank/world_size here presumably give
    # each process a disjoint shard — confirm against lhotse docs.
    sampler = DynamicBucketingSampler(
        cuts,
        shuffle=True,
        max_duration=60,
        drop_last=True,
        num_buckets=10,
        rank=rank,
        world_size=world_size
    )
    # 80-dim fbank features computed on the fly inside DataLoader workers.
    dataset = ASRDataset(cuts, feature_input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))))
    # batch_size=None: the Lhotse sampler already yields whole batches of cuts.
    train_data = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=None,
        num_workers=8,
        # NOTE(review): make_worker_init_fn appears to propagate rank and
        # world_size into worker processes to keep shards distinct — verify.
        worker_init_fn=make_worker_init_fn(
            rank=rank,
            world_size=world_size
        ),
        persistent_workers=True,
        pin_memory=True
    )
    tokenizer = dataset.token_collater
    # Output layer sized to the tokenizer's vocabulary.
    model = torchaudio.models.DeepSpeech(n_feature=80, n_class=len(list(tokenizer.idx2token)))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # CTC blank index mapped to the tokenizer's '<pad>' symbol.
    loss = torch.nn.CTCLoss(blank=list(tokenizer.idx2token).index('<pad>'), reduction="mean", zero_infinity=True)
    trainer = Trainer(model, train_data, optimizer, rank, loss)
    trainer.train(max_iters)
    destroy_process_group()
if __name__ == "__main__":
    # Launch one training process per visible GPU. mp.spawn prepends each
    # process's rank to args, so each worker runs main(rank, world_size, 100).
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, 100), nprocs=world_size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment