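# Multi-task SBERT built on a shared BERT sentence encoder (the author's
# BertLayer).
#
# The gist relies on helpers defined elsewhere in the accompanying code:
# BertLayer, log_this, average_loss, softmax_loss and cosine_similarity.
# BertLayer is left as an external dependency; the small helpers below are
# NOT part of the original gist -- they are hedged, minimal stand-ins
# inferred from how the class uses them, so the snippet is self-contained.
import tensorflow as tf
from tensorflow.keras import layers


def log_this(msg):
    # Assumed to be a thin logging wrapper.
    print(msg)


def average_loss(y_true, y_pred):
    # The model's output already *is* the combined loss tensor, so the
    # Keras loss just averages it and ignores the dummy target.
    return tf.reduce_mean(y_pred)


def cosine_similarity(tensors):
    # Cosine similarity between two pooled sentence embeddings.
    a, b = tensors
    a = tf.nn.l2_normalize(a, axis=-1)
    b = tf.nn.l2_normalize(b, axis=-1)
    return tf.reduce_sum(a * b, axis=-1, keepdims=True)


def softmax_loss(tensors):
    # One plausible triplet-style formulation: softmax cross-entropy that
    # pushes anchor-positive similarity above anchor-negative similarity.
    anc, pos, neg = tensors
    pos_sim = tf.reduce_sum(anc * pos, axis=-1, keepdims=True)
    neg_sim = tf.reduce_sum(anc * neg, axis=-1, keepdims=True)
    logits = tf.concat([pos_sim, neg_sim], axis=-1)
    labels = tf.concat([tf.ones_like(pos_sim), tf.zeros_like(neg_sim)], axis=-1)
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.expand_dims(loss, -1)  # (batch, 1), to broadcast with other loss terms
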
class SBERT:
    def __init__(self, config):
        self.loss = 0
        self.metrics = []
        self.inputs = []
        self.config = config
        self.build()
    def build(self):
        self.saver_dict = {}
        self.build_body()
        # Each head is optional; every enabled head registers its inputs and
        # adds a weighted term to the combined training loss.
        if self.config.use_par_head:
            self.build_nli_head()
        if self.config.use_toxic_head:
            self.build_toxic_head()
        if self.config.use_ner_head:
            self.build_tag_head()
        self.compile_model()
    def compile_model(self):
        log_this("Compiling")
        # The training model's single output is the combined loss tensor
        # itself, so it is trained against a dummy target with a loss that
        # simply averages the output.
        self.train_model = tf.keras.models.Model(inputs=self.inputs, outputs=[self.loss])
        opt = tf.keras.optimizers.Adam(learning_rate=self.config.lr)
        self.train_model.compile(
            optimizer=opt,
            loss=average_loss,
            metrics=self.metrics)
        log_this("The model is built")
    def build_body(self):
        # Shared BERT encoder with mean pooling; the same encoder backs
        # every task head.
        self.nlu_encoder = BertLayer(
            self.config.module_path, self.config.ctx_len,
            n_tune_layers=self.config.n_tune, do_preprocessing=True,
            pooling='mean', tune_embeddings=self.config.tune_embs,
            trainable=self.config.train_bert)
    def build_tag_head(self):
        log_this("Building tagger head")
        tag_input = layers.Input(shape=(1,), dtype=tf.string)
        tag_label = layers.Input(shape=(self.config.ctx_len, self.config.n_tags), dtype=tf.float32)
        # Temporarily switch the encoder to dict mode to get per-token
        # outputs instead of the pooled sentence vector.
        self.nlu_encoder.as_dict = True
        inp_tok_encoded = self.nlu_encoder(tag_input)['token_output']
        self.nlu_encoder.as_dict = False
        tag_mlp = self.build_mlp(
            2, self.config.dim, self.config.dim, self.config.n_tags,
            name="ner", dropout_rate=self.config.head_dropout_rate)
        tag_pred = tf.keras.layers.TimeDistributed(tag_mlp)(inp_tok_encoded)
        tag_loss = tf.keras.losses.categorical_crossentropy(tag_label, tag_pred)
        self.tag_model = tf.keras.models.Model(inputs=[tag_input], outputs=[tag_pred], name='tagger_model')
        self.inputs += [tag_input, tag_label]
        self.loss += self.config.tagger_loss_weight * tag_loss
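
    def build_mlp(self, n_layers, in_dim, hidden_dim, out_dim, name, dropout_rate=0.0):
        # NOTE: build_mlp is called above but not defined in the original
        # gist; this is a hedged sketch inferred from its call sites
        # (n_layers, input dim, hidden dim, output dim). The softmax output
        # matches the categorical cross-entropy used by both heads.
        mlp = tf.keras.Sequential(name=name)
        mlp.add(layers.InputLayer(input_shape=(in_dim,)))
        for _ in range(n_layers - 1):
            mlp.add(layers.Dense(hidden_dim, activation='relu'))
            mlp.add(layers.Dropout(dropout_rate))
        mlp.add(layers.Dense(out_dim, activation='softmax'))
        return mlp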
    def build_nli_head(self):
        log_this("Building paraphraser head")
        anc_input = layers.Input(shape=(1,), dtype=tf.string)
        pos_input = layers.Input(shape=(1,), dtype=tf.string)
        neg_input = layers.Input(shape=(1,), dtype=tf.string)
        anc_encoded = self.nlu_encoder(anc_input)
        pos_encoded = self.nlu_encoder(pos_input)
        if self.config.train_bert:
            # The triplet loss over (anchor, positive, negative) is only
            # added when BERT itself is being fine-tuned.
            neg_encoded = self.nlu_encoder(neg_input)
            par_loss = tf.keras.layers.Lambda(softmax_loss)([anc_encoded, pos_encoded, neg_encoded])
            self.loss += self.config.paraphrase_loss_weight * par_loss
        self.nli_encoder_model = tf.keras.models.Model(inputs=[pos_input], outputs=[pos_encoded])
        sim = tf.keras.layers.Lambda(cosine_similarity, name='similarity')([anc_encoded, pos_encoded])
        self.sim_model = tf.keras.models.Model(inputs=[anc_input, pos_input], outputs=[sim])
        self.inputs += [anc_input, pos_input, neg_input]
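
    # Inference sketch (not in the original gist): sim_model scores two raw
    # strings directly, e.g.
    #   score = sbert.sim_model.predict(
    #       [np.array([['how are you']]), np.array([['how do you do']])])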
    def build_toxic_head(self):
        log_this("Building toxic head")
        sent_input = layers.Input(shape=(1,), dtype=tf.string)
        sent_label = layers.Input(shape=(self.config.n_toxic_tags,), dtype=tf.float32)
        sents_encoded = self.nlu_encoder(sent_input)
        tox_mlp = self.build_mlp(
            2, self.config.dim, self.config.dim, self.config.n_toxic_tags,
            name="toxic", dropout_rate=self.config.head_dropout_rate)
        pred = tox_mlp(sents_encoded)
        tox_loss = tf.keras.losses.categorical_crossentropy(sent_label, pred)
        # Reshape to (batch, 1) so this term broadcasts against the other
        # loss components.
        tox_loss = tf.reshape(tox_loss, (-1, 1))
        self.tox_model = tf.keras.models.Model(inputs=[sent_input], outputs=[pred], name='toxic_model')
        self.inputs += [sent_input, sent_label]
        self.loss += self.config.toxic_loss_weight * tox_loss
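

# --- Usage sketch (not in the original gist) --------------------------------
# Illustrative configuration and training call, assuming BertLayer from the
# author's accompanying code. Every value below is a placeholder, and the
# dummy target exists only because Keras requires one.
if __name__ == '__main__':
    from types import SimpleNamespace
    import numpy as np

    config = SimpleNamespace(
        module_path='path/to/bert_module',  # placeholder TF-Hub module path
        ctx_len=128, dim=768, lr=2e-5, n_tune=4,
        tune_embs=False, train_bert=True,
        use_par_head=True, use_toxic_head=False, use_ner_head=False,
        n_tags=9, n_toxic_tags=6, head_dropout_rate=0.1,
        paraphrase_loss_weight=1.0, tagger_loss_weight=1.0,
        toxic_loss_weight=1.0)

    sbert = SBERT(config)
    anchors = np.array([['a cat sat on the mat']])
    positives = np.array([['a kitten rests on the rug']])
    negatives = np.array([['stock markets fell sharply']])
    dummy = np.zeros((1, 1), dtype=np.float32)  # ignored by average_loss
    sbert.train_model.fit([anchors, positives, negatives], dummy, epochs=1)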