Multiprocess seq2seq reader using PyTorch DataLoader and Dataset.
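The reader shuffles the raw (source, target) pairs with a per-epoch seed, wraps them in a torch Dataset, and feeds them through a DataLoader so tokenization runs in parallel worker processes. Each input line holds a source and a target sequence separated by the configured delimiter (tab by default), e.g. a made-up line: "see you tomorrow\tà demain".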
import csv
import logging
from collections import Counter
from typing import Dict, Optional

import numpy as np
import torch
from overrides import overrides
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from allennlp.common.file_utils import cached_path
from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Token, Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer

logger = logging.getLogger(__name__)


@DatasetReader.register("seq2seq")
class Seq2SeqDatasetReader(DatasetReader):
    """
    Read a tsv file containing paired sequences, and create a dataset suitable for a
    ``ComposedSeq2Seq`` model, or any model with a matching API.

    Expected format for each input line: <source_sequence_string>\t<target_sequence_string>

    The output of ``read`` is a list of ``Instance`` s with the fields:
        source_tokens : ``TextField``
        target_tokens : ``TextField``

    `START_SYMBOL` and `END_SYMBOL` tokens are added to the source and target sequences.

    Parameters
    ----------
    source_tokenizer : ``Tokenizer``, optional
        Tokenizer to use to split the input sequences into words or other kinds of tokens.
        Defaults to ``WordTokenizer()``.
    target_tokenizer : ``Tokenizer``, optional
        Tokenizer to use to split the output sequences (during training) into words or other
        kinds of tokens. Defaults to ``source_tokenizer``.
    source_token_indexers : ``Dict[str, TokenIndexer]``, optional
        Indexers used to define input (source side) token representations. Defaults to
        ``{"tokens": SingleIdTokenIndexer()}``.
    target_token_indexers : ``Dict[str, TokenIndexer]``, optional
        Indexers used to define output (target side) token representations. Defaults to
        ``source_token_indexers``.
    source_add_start_token : bool, (optional, default=True)
        Whether or not to add `START_SYMBOL` to the beginning of the source sequence.
    delimiter : str, (optional, default="\t")
        Set delimiter for tsv/csv file.
    source_max_tokens : ``Optional[int]``, (optional, default=None)
        If set, source sequences longer than this are truncated to this length.
    target_max_tokens : ``Optional[int]``, (optional, default=None)
        If set, target sequences longer than this are truncated to this length.
    lazy : bool, (optional, default=False)
        Whether to read instances lazily; passed through to ``DatasetReader``.
    """

    def __init__(
        self,
        source_tokenizer: Tokenizer = None,
        target_tokenizer: Tokenizer = None,
        source_token_indexers: Dict[str, TokenIndexer] = None,
        target_token_indexers: Dict[str, TokenIndexer] = None,
        source_add_start_token: bool = True,
        delimiter: str = "\t",
        source_max_tokens: Optional[int] = None,
        target_max_tokens: Optional[int] = None,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._source_tokenizer = source_tokenizer or WordTokenizer()
        self._target_tokenizer = target_tokenizer or self._source_tokenizer
        self._source_token_indexers = source_token_indexers or {"tokens": SingleIdTokenIndexer()}
        self._target_token_indexers = target_token_indexers or self._source_token_indexers
        self._source_add_start_token = source_add_start_token
        self._delimiter = delimiter
        self._source_max_tokens = source_max_tokens
        self._target_max_tokens = target_max_tokens
        # Counters for how many source/target sequences were truncated to the max-token limits.
        self._source_max_exceeded = 0
        self._target_max_exceeded = 0
        # Per-file epoch counter, used to reshuffle the raw data deterministically on every epoch.
        self._epoch_counter = Counter()
        self._initial_seed = 1337

    def _raw_dataset(self, file_path):
        paired_sequences = []
        with open(cached_path(file_path), "r") as data_file:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for row in csv.reader(data_file, delimiter=self._delimiter):
                if len(row) != 2:
                    # Skip malformed lines that do not contain exactly a source and a target.
                    continue
                paired_sequences.append(row)
        # Shuffle deterministically, with a different seed for each epoch over this file.
        np.random.RandomState(self._initial_seed + self._epoch_counter[file_path]).shuffle(paired_sequences)
        return paired_sequences

    def _to_instance(self, raw_data_item):
        return self.text_to_instance(*raw_data_item)

    @overrides
    def _read(self, file_path):
        raw_dataset = self._raw_dataset(file_path)
        dataset = _DatasetWrapper(raw_dataset, instancizer=self)
        # Only use the real distributed sampler when a process group has actually been
        # initialized; ``torch.distributed.is_available()`` alone is not enough, since
        # constructing ``DistributedSampler`` without an initialized group raises an error.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            sampler = DistributedSampler(dataset)
        else:
            sampler = DistributedSampler(dataset, num_replicas=1, rank=0)
        # ``collate_fn=identity`` keeps each worker's output as a plain list of ``Instance`` s;
        # AllenNLP does its own batching and tensorization later.
        loader = DataLoader(dataset, batch_size=100, num_workers=2, sampler=sampler, collate_fn=identity)
        for instances in loader:
            for instance in instances:
                # The instances were built in worker processes, so their fields hold pickled
                # copies of the token indexers; point them back at this reader's own indexers.
                instance["source_tokens"]._token_indexers = self._source_token_indexers
                instance["target_tokens"]._token_indexers = self._target_token_indexers
                yield instance
        # Advance the per-file epoch counter so the next pass gets a different shuffle.
        self._epoch_counter[file_path] += 1

    @overrides
    def text_to_instance(
        self, source_string: str, target_string: str = None
    ) -> Instance:  # type: ignore
        tokenized_source = self._source_tokenizer.tokenize(source_string)
        if self._source_max_tokens and len(tokenized_source) > self._source_max_tokens:
            self._source_max_exceeded += 1
            tokenized_source = tokenized_source[: self._source_max_tokens]
        if self._source_add_start_token:
            tokenized_source.insert(0, Token(START_SYMBOL))
        tokenized_source.append(Token(END_SYMBOL))
        source_field = TextField(tokenized_source, self._source_token_indexers)
        if target_string is not None:
            tokenized_target = self._target_tokenizer.tokenize(target_string)
            if self._target_max_tokens and len(tokenized_target) > self._target_max_tokens:
                self._target_max_exceeded += 1
                tokenized_target = tokenized_target[: self._target_max_tokens]
            tokenized_target.insert(0, Token(START_SYMBOL))
            tokenized_target.append(Token(END_SYMBOL))
            target_field = TextField(tokenized_target, self._target_token_indexers)
            return Instance({"source_tokens": source_field, "target_tokens": target_field})
        else:
            return Instance({"source_tokens": source_field})


class _DatasetWrapper(Dataset):
    """Wraps the raw (source, target) pairs so DataLoader workers can build Instances on demand."""

    def __init__(self, raw_dataset, instancizer):
        self._raw_dataset = raw_dataset
        self.instancizer = instancizer

    def __getitem__(self, index):
        # Tokenization and field construction happen here, inside the worker process.
        return self.instancizer._to_instance(self._raw_dataset[index])

    def __len__(self):
        return len(self._raw_dataset)


def identity(x):
    # No-op collate: return the worker's list of instances unchanged (module-level so it pickles for workers).
    return x
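
A minimal usage sketch (not part of the gist): the file path and reader settings below are assumptions, and the snippet only iterates over the instances the reader yields.

if __name__ == "__main__":
    # Hypothetical example; "data/train.tsv" and these settings are assumptions.
    # The __main__ guard matters because the reader starts DataLoader worker processes.
    reader = Seq2SeqDatasetReader(
        source_max_tokens=100,
        target_max_tokens=100,
        lazy=True,  # stream instances instead of materializing the whole dataset at once
    )
    for instance in reader.read("data/train.tsv"):
        print(instance.fields.keys())     # dict_keys(['source_tokens', 'target_tokens'])
        print(instance["source_tokens"])  # TextField over the tokenized source sequence
        break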