Created
May 27, 2019 00:33
-
-
Save emrul/74486783e9d750f2cb08695bf26719da to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from flair.data import Sentence, Dictionary | |
from flair.data_fetcher import NLPTaskDataFetcher, NLPTask | |
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, CharacterEmbeddings, \ | |
PooledFlairEmbeddings, FlairEmbeddings | |
from flair.visual.training_curves import Plotter | |
from flair.trainers import ModelTrainer | |
from flair.models import SequenceTagger | |
from flair.datasets import ColumnCorpus | |
# 1. get the corpus | |
def train(data_folder='training/sequences',
          model_dir='resources/taggers/v003',
          hidden_size=128,
          max_epochs=150,
          mini_batch_size=64):
    """Train a flair SequenceTagger for NER on a column-format corpus.

    Args:
        data_folder: Directory containing CoNLL-style column data
            (column 0 = token text, column 1 = NER tag).
        model_dir: Output directory for checkpoints and the final model.
        hidden_size: Hidden size of the tagger's LSTM.
        max_epochs: Maximum number of training epochs.
        mini_batch_size: Training mini-batch size.
    """
    # 1. get the corpus
    columns = {0: 'text', 1: 'ner'}
    # in_memory=False streams sentences from disk instead of loading the
    # whole corpus into RAM — keeps memory bounded for large datasets
    corpus = ColumnCorpus(data_folder, columns, in_memory=False)
    print(corpus)

    # 2. what tag do we want to predict?
    tag_type = 'ner'

    # 3. make the tag dictionary from the corpus
    tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)
    print(tag_dictionary.idx2item)

    # 4. initialize embeddings: classic word vectors stacked with
    # contextual string embeddings and a custom domain embedding
    embeddings = StackedEmbeddings(embeddings=[
        WordEmbeddings('glove'),
        # contextual string embeddings, forward and backward;
        # chars_per_chunk=64 trades speed for lower GPU memory use
        FlairEmbeddings('news-forward', use_cache=True, chars_per_chunk=64),
        FlairEmbeddings('news-backward', use_cache=True, chars_per_chunk=64),
        WordEmbeddings(embeddings="embeddings/markup2vec.wv"),
    ])

    # 5. initialize sequence tagger
    tagger: SequenceTagger = SequenceTagger(hidden_size=hidden_size,
                                            embeddings=embeddings,
                                            tag_dictionary=tag_dictionary,
                                            tag_type=tag_type)

    # 6. initialize trainer
    trainer: ModelTrainer = ModelTrainer(tagger, corpus)

    # 7. start training; checkpoint=True lets training resume after a crash
    trainer.train(model_dir,
                  max_epochs=max_epochs,
                  mini_batch_size=mini_batch_size,
                  eval_mini_batch_size=32,
                  checkpoint=True)
# Script entry point: report the runtime environment, then start training.
if __name__ == '__main__':
    cuda_available = torch.cuda.is_available()
    print("PyTorch Version: ", torch.__version__)
    print("CUDA available: ", cuda_available)
    train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment