Last active
May 6, 2020 17:33
-
-
Save joydeb28/0a5bfc7f45730a3a6f8b2dde5cb14656 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BertModel(object):
    """Wrap a TF-Hub BERT layer together with its matching WordPiece tokenizer.

    Converts raw sentences into the three int32 input arrays (token ids,
    attention masks, segment ids). Each row is zero-padded to ``self.max_len``
    positions.
    """

    def __init__(
        self,
        max_len=128,
        bert_path="https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1",
    ):
        """Load the BERT hub module and build a tokenizer from its assets.

        Args:
            max_len: Total encoded length per sentence, including the [CLS]
                and [SEP] markers.
            bert_path: TF-Hub handle of the BERT model to load.
        """
        self.max_len = max_len
        FullTokenizer = bert.bert_tokenization.FullTokenizer
        # trainable=True lets the hub layer's weights be updated during training.
        self.bert_module = hub.KerasLayer(bert_path, trainable=True)
        # Read the vocabulary and casing from the module itself so the
        # tokenizer matches the model that was loaded.
        self.vocab_file = self.bert_module.resolved_object.vocab_file.asset_path.numpy()
        self.do_lower_case = self.bert_module.resolved_object.do_lower_case.numpy()
        self.tokenizer = FullTokenizer(self.vocab_file, self.do_lower_case)

    def get_masks(self, tokens, max_seq_length):
        """Return the attention mask: 1 per real token, then 0 up to max_seq_length."""
        return [1] * len(tokens) + [0] * (max_seq_length - len(tokens))

    def get_segments(self, tokens, max_seq_length):
        """Return segment ids padded with 0 up to max_seq_length.

        Tokens up to and including the first "[SEP]" get segment 0. Every
        token after it gets segment 1.
        """
        segments = []
        segment_id = 0
        for token in tokens:
            segments.append(segment_id)
            if token == "[SEP]":
                segment_id = 1
        # Remaining (padding) positions are filled with segment 0.
        return segments + [0] * (max_seq_length - len(tokens))

    def get_ids(self, tokens, tokenizer, max_seq_length):
        """Map tokens to vocabulary ids via `tokenizer` and zero-pad to max_seq_length."""
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        return token_ids + [0] * (max_seq_length - len(token_ids))

    def get_input_data(self, sentence, maxlen):
        """Encode one sentence as [ids, mask, segments], each of length self.max_len.

        Args:
            sentence: Raw text to tokenize.
            maxlen: Maximum number of WordPiece tokens kept before [CLS]/[SEP]
                are added. It is capped at ``self.max_len - 2`` so the encoded
                row never exceeds ``self.max_len`` positions.

        Returns:
            A list ``[ids, mask, segment]`` of three int lists.
        """
        # Reserve two slots for the [CLS] and [SEP] markers.
        maxlen = max(0, min(maxlen, self.max_len - 2))
        sent_tokens = self.tokenizer.tokenize(sentence)[:maxlen]
        sent_tokens = ["[CLS]"] + sent_tokens + ["[SEP]"]
        ids = self.get_ids(sent_tokens, self.tokenizer, self.max_len)
        mask = self.get_masks(sent_tokens, self.max_len)
        segment = self.get_segments(sent_tokens, self.max_len)
        return [ids, mask, segment]

    def get_input_array(self, sentences):
        """Encode a batch of sentences into the three BERT input arrays.

        Returns:
            ``[input_ids, input_masks, input_segments]``. Each is an int32
            numpy array of shape ``(len(sentences), self.max_len)``, including
            when ``sentences`` is empty.
        """
        input_ids, input_masks, input_segments = [], [], []
        # position=0 / leave=True: draw the bar on line 0 and keep it after the loop ends.
        for sentence in tqdm(sentences, position=0, leave=True):
            ids, masks, segments = self.get_input_data(sentence, self.max_len - 2)
            input_ids.append(ids)
            input_masks.append(masks)
            input_segments.append(segments)
        # reshape keeps a 2-D (0, max_len) shape for an empty batch.
        return [
            np.asarray(rows, dtype=np.int32).reshape(-1, self.max_len)
            for rows in (input_ids, input_masks, input_segments)
        ]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment