from observations import ptb
from time import sleep, time
from torch.nn import (CrossEntropyLoss, Dropout, Embedding,
                      Linear, LSTM, Module)
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch_xla.core.xla_model import (get_ordinal,
                                      is_master_ordinal,
                                      master_print,
                                      optimizer_step,
                                      rendezvous,
                                      xla_device,
                                      xrt_world_size)
from torch_xla.distributed.parallel_loader import ParallelLoader
from torch_xla.distributed.xla_multiprocessing import spawn
import torch
class RNN(Module):
    """Character-level LSTM language model."""
    def __init__(self, vocab_size, embed_size, hidden_size,
                 n_layers, emb_dropout):
        super(RNN, self).__init__()
        self.encoder = Embedding(vocab_size, embed_size)
        self.lstm = LSTM(embed_size, hidden_size,
                         n_layers, batch_first = True)
        self.linear = Linear(hidden_size, vocab_size)
        self.drop = Dropout(emb_dropout)

    def forward(self, x, state):
        # Embed the characters, apply dropout, and run the LSTM.
        x2 = self.drop(self.encoder(x))
        out, state = self.lstm(x2, state)
        # Flatten the batch and time dimensions so the linear decoder
        # produces one logit vector per character position.
        out = out.reshape(out.size(0) * out.size(1), out.size(2))
        out = self.linear(out)
        return out, state

    def init_state(self, batch_size, device):
        # Zero-initialized (hidden, cell) state on the target device.
        num_layers = self.lstm.num_layers
        hidden_size = self.lstm.hidden_size
        hs = torch.zeros(num_layers, batch_size, hidden_size)
        cs = torch.zeros(num_layers, batch_size, hidden_size)
        return hs.to(device), cs.to(device)
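# Shape sketch: for x of shape (batch_size, seq_len) and a state from
# init_state, forward returns logits of shape
# (batch_size * seq_len, vocab_size), matching the flattened targets
# that CrossEntropyLoss expects in the training loop below.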
def text_to_tensor(text):
    # Map each distinct character to an integer index and encode the
    # text as a LongTensor of indices.
    ix2ch = sorted(set(text))
    ch2ix = {c : i for i, c in enumerate(ix2ch)}
    seq = torch.LongTensor([ch2ix[c] for c in text])
    return ix2ch, ch2ix, seq
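# For example, text_to_tensor('aba') yields ix2ch = ['a', 'b'],
# ch2ix = {'a': 0, 'b': 1} and seq = tensor([0, 1, 0]).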
def batchify(tensor, batch_size):
    # Trim the sequence so it divides evenly into batch_size rows,
    # then lay it out as one long character stream per row.
    n_batches = tensor.size(0) // batch_size
    tensor = tensor[:n_batches * batch_size]
    return tensor.view(batch_size, -1)

def successor_samples(batched_tensor, seq_len):
    # Slide over the stream in steps of seq_len, yielding (x, y) pairs
    # where y is x shifted one character to the right.
    for i in range(0, batched_tensor.size(1) - seq_len, seq_len):
        x = batched_tensor[:, i:i + seq_len]
        y = batched_tensor[:, (i + 1):(i + 1) + seq_len]
        yield x, y
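# A worked example of the two helpers: batchify(torch.arange(10), 2)
# keeps the first 5 * 2 = 10 elements and returns
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]. With seq_len = 2,
# successor_samples then yields x = [[0, 1], [5, 6]] with
# y = [[1, 2], [6, 7]], followed by x = [[2, 3], [7, 8]] with
# y = [[3, 4], [8, 9]].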
def load_data(ptb_path, batch_size, seq_len):
    # ptb() returns the train, valid and test texts, downloading them
    # on first use. Build the vocabulary from the training split and
    # reuse it for every split so that indices stay consistent.
    texts = ptb(ptb_path)
    ix2ch, ch2ix, _ = text_to_tensor(texts[0])
    seqs = [torch.LongTensor([ch2ix[c] for c in text]) for text in texts]
    tensors = [batchify(seq, batch_size) for seq in seqs]
    data = [list(successor_samples(t, seq_len)) for t in tensors]
    return ix2ch, ch2ix, data
def fn(ix, flags):
    batch_size = flags['batch_size']
    seq_len = flags['seq_len']

    # Standard "download once" barrier: the non-master processes wait
    # here while the master downloads the dataset, then read it from
    # the local cache themselves.
    if not is_master_ordinal():
        rendezvous('download_once')
    ix2ch, ch2ix, data = load_data('./data', batch_size, seq_len)
    train_ds, valid_ds, test_ds = data
    # The samples are already batched, so batch_size = None disables
    # the DataLoader's own batching. Each replica draws a disjoint
    # shard of the samples through the DistributedSampler.
    train_loader = DataLoader(
        train_ds,
        batch_size = None,
        sampler = DistributedSampler(
            train_ds,
            num_replicas = xrt_world_size(),
            rank = get_ordinal(),
            shuffle = True),
        shuffle = False,
        num_workers = 8)
    if is_master_ordinal():
        rendezvous('download_once')

    dev = xla_device()
    model = RNN(len(ix2ch), 100, 512, 1, 0.1).to(dev)
    crit = CrossEntropyLoss()
    opt = SGD(model.parameters(), lr = 4)
    for i in range(3):
        start = time()
        # ParallelLoader moves the batches onto the TPU device in the
        # background.
        loader = ParallelLoader(train_loader, [dev]) \
            .per_device_loader(dev)
        state = model.init_state(batch_size, dev)
        model.train()
        for x, y in loader:
            opt.zero_grad()
            # Truncated BPTT: detach the recurrent state so gradients
            # do not flow across batch boundaries. The LSTM expects a
            # tuple, not a list.
            state = tuple(s.detach() for s in state)
            y_hat, state = model(x, state)
            loss = crit(y_hat, y.reshape(-1))
            loss.backward()
            # optimizer_step all-reduces the gradients across the TPU
            # cores and runs the optimizer.
            optimizer_step(opt)
        elapsed = time() - start
        master_print('%.2f seconds for %d batches.'
                     % (elapsed, len(train_ds)))
    rendezvous('done')
    # Give the other processes a moment to exit cleanly.
    if ix == 0:
        sleep(0.5)
def main():
    flags = {'batch_size' : 32, 'seq_len' : 320}
    start = time()
    # One process per TPU core.
    spawn(fn, args = (flags,), nprocs = 8, start_method = 'fork')
    elapsed = time() - start
    print('Took %.2f seconds.' % elapsed)

if __name__ == '__main__':
    main()
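# Usage sketch (assumptions: a TPU runtime with torch_xla installed
# and the `observations` package available for the PTB download; the
# filename train_ptb.py is hypothetical):
#
#   $ pip install observations
#   $ python3 train_ptb.py
#
# Each of the 8 forked processes drives one TPU core and trains on its
# own shard of the pre-batched training data.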