Skip to content

Instantly share code, notes, and snippets.

@yasufumy
Last active August 22, 2019 09:06
Show Gist options
  • Save yasufumy/ba73b587bd3c516b66fb94b3a90bac71 to your computer and use it in GitHub Desktop.
Save yasufumy/ba73b587bd3c516b66fb94b3a90bac71 to your computer and use it in GitHub Desktop.
import os
import torch
import lineflow as lf
class Dictionary(object):
    """Bidirectional word <-> integer-id vocabulary.

    Ids are assigned in first-seen order, starting at 0, and are never
    reassigned once a word has been added.
    """

    def __init__(self) -> None:
        # word -> id lookup
        self.word2idx = {}
        # id -> word lookup; the position in the list is the id
        self.idx2word = []

    def add_word(self, word: str) -> int:
        """Register *word* if it is new and return its id.

        Adding a word that is already present leaves the vocabulary
        unchanged and returns the id it was given the first time.
        """
        if word not in self.word2idx:
            self.idx2word.append(word)
            # The new word sits at the end of the list, so its id is the
            # last valid index.
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self) -> int:
        """Return the number of distinct words registered so far."""
        return len(self.idx2word)
class Corpus(object):
    """Train/valid/test token-id streams sharing a single vocabulary.

    On construction, ``train.txt``, ``valid.txt`` and ``test.txt`` under
    *path* are tokenized in that order. All three splits feed the same
    ``Dictionary``, so ids are consistent across splits.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenize the text file at *path* into a tensor of word ids.

        Each line is split on whitespace and terminated with an ``'<eos>'``
        token. Every token is registered in ``self.dictionary`` (new words
        get fresh ids) and the ids of all tokens, in file order, are
        returned as a ``torch.LongTensor``.

        Raises:
            AssertionError: if *path* does not exist.
        """
        # Kept as an assert so the exception type seen by callers is
        # unchanged. NOTE: asserts are skipped when Python runs with -O.
        assert os.path.exists(path), 'corpus file not found: {}'.format(path)

        dataset = lf.TextDataset(path, encoding='utf-8').map(
            lambda x: x.split() + ['<eos>'])

        # `add_word` returns the id of the word whether or not it is new,
        # so a single pass over the tokens both grows the vocabulary and
        # yields the id sequence -- no second read of the file is needed.
        add_word = self.dictionary.add_word
        ids = [add_word(token) for token in dataset.flat_map(lambda x: x)]
        return torch.LongTensor(ids)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment