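# SBERT multi-task training model: a shared BERT sentence encoder feeding
# optional paraphrase (NLI), toxicity and NER-tagging heads, trained on a
# weighted sum of per-head losses.
#
# The gist references helpers defined elsewhere in the author's codebase
# (BertLayer, log_this, average_loss, softmax_loss, cosine_similarity,
# build_mlp). Minimal, assumption-labelled sketches of the small ones are
# provided below so the file is self-contained; they are illustrative, not
# the author's exact implementations. BertLayer (a Keras layer wrapping a
# TF-Hub BERT module) is not sketched and must be imported from the
# author's code.
import tensorflow as tf
from tensorflow.keras import layers


def log_this(msg):
    # Hypothetical stand-in for the author's logging helper.
    print(msg)


def average_loss(y_true, y_pred):
    # The train model's single "output" is the combined loss tensor, so the
    # Keras loss simply averages it and ignores the dummy targets (assumed).
    return tf.reduce_mean(y_pred)


def cosine_similarity(tensors):
    # Cosine similarity between two batches of sentence embeddings.
    a, b = tensors
    a = tf.math.l2_normalize(a, axis=-1)
    b = tf.math.l2_normalize(b, axis=-1)
    return tf.reduce_sum(a * b, axis=-1, keepdims=True)


def softmax_loss(tensors):
    # Sketch of a triplet softmax loss over (anchor, positive, negative)
    # embeddings: score both pairs by cosine similarity and push the
    # positive pair above the negative with cross-entropy. SBERT-style
    # losses often scale the logits by a temperature; omitted here.
    anc, pos, neg = tensors
    pos_sim = cosine_similarity([anc, pos])
    neg_sim = cosine_similarity([anc, neg])
    logits = tf.concat([pos_sim, neg_sim], axis=-1)
    labels = tf.zeros(tf.shape(logits)[:1], dtype=tf.int32)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.expand_dims(loss, -1)

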
class SBERT:
    """Multi-task SBERT: a shared BERT sentence encoder with optional
    paraphrase (NLI), toxicity-classification and NER-tagging heads,
    trained on a weighted sum of the per-head losses."""

    def __init__(self, config):
        self.loss = 0
        self.metrics = []
        self.inputs = []
        self.config = config
        self.build()

    def build(self):
        self.saver_dict = {}
        self.build_body()
        if self.config.use_par_head:
            self.build_nli_head()
        if self.config.use_toxic_head:
            self.build_toxic_head()
        if self.config.use_ner_head:
            self.build_tag_head()
        self.compile_model()

    def compile_model(self):
        log_this("Compiling")
        # The combined weighted loss tensor is the model's only output;
        # the `average_loss` criterion then just reduces it to a scalar.
        self.train_model = tf.keras.models.Model(
            inputs=self.inputs, outputs=[self.loss])
        opt = tf.keras.optimizers.Adam(learning_rate=self.config.lr)
        self.train_model.compile(
            optimizer=opt,
            loss=average_loss,
            metrics=self.metrics)
        log_this("The model is built")

    def build_body(self):
        # Shared mean-pooled BERT encoder over raw string inputs.
        self.nlu_encoder = BertLayer(
            self.config.module_path, self.config.ctx_len,
            n_tune_layers=self.config.n_tune, do_preprocessing=True,
            pooling='mean', tune_embeddings=self.config.tune_embs,
            trainable=self.config.train_bert)

    def build_tag_head(self):
        log_this("Building tagger head")
        tag_input = layers.Input(shape=(1,), dtype=tf.string)
        tag_label = layers.Input(
            shape=(self.config.ctx_len, self.config.n_tags), dtype=tf.float32)
        # Temporarily switch the encoder to dict mode to get per-token
        # outputs instead of the pooled sentence embedding.
        self.nlu_encoder.as_dict = True
        inp_tok_encoded = self.nlu_encoder(tag_input)['token_output']
        self.nlu_encoder.as_dict = False
        tag_mlp = self.build_mlp(
            2, self.config.dim, self.config.dim, self.config.n_tags,
            name="ner", dropout_rate=self.config.head_dropout_rate)
        # Apply the same MLP classifier at every token position.
        tag_pred = tf.keras.layers.TimeDistributed(tag_mlp)(inp_tok_encoded)
        tag_loss = tf.keras.losses.categorical_crossentropy(tag_label, tag_pred)
        self.tag_model = tf.keras.models.Model(
            inputs=[tag_input], outputs=[tag_pred], name='tagger_model')
        self.inputs += [tag_input, tag_label]
        self.loss += self.config.tagger_loss_weight * tag_loss

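    def build_mlp(self, n_layers, in_dim, hidden_dim, out_dim,
                  name=None, dropout_rate=0.0):
        # build_mlp is not included in the gist. This is a minimal sketch
        # matching the call sites above (layer count, input/hidden/output
        # dims, name, dropout); assumed to be a softmax-output MLP head.
        # `in_dim` is accepted for signature compatibility; Keras infers
        # the input shape on the first call.
        mlp = tf.keras.Sequential(name=name)
        for _ in range(n_layers - 1):
            mlp.add(layers.Dense(hidden_dim, activation='relu'))
            mlp.add(layers.Dropout(dropout_rate))
        mlp.add(layers.Dense(out_dim, activation='softmax'))
        return mlp
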
    def build_nli_head(self):
        log_this("Building paraphraser head")
        # Triplet inputs: anchor, paraphrase (positive) and non-paraphrase
        # (negative) raw strings.
        anc_input = layers.Input(shape=(1,), dtype=tf.string)
        pos_input = layers.Input(shape=(1,), dtype=tf.string)
        neg_input = layers.Input(shape=(1,), dtype=tf.string)
        anc_encoded = self.nlu_encoder(anc_input)
        pos_encoded = self.nlu_encoder(pos_input)
        if self.config.train_bert:
            # The negative branch and the paraphrase loss are only built
            # when BERT itself is trainable (neg_encoded is undefined
            # otherwise, so the loss lines must live inside this branch).
            neg_encoded = self.nlu_encoder(neg_input)
            par_loss = tf.keras.layers.Lambda(softmax_loss)(
                [anc_encoded, pos_encoded, neg_encoded])
            self.loss += self.config.paraphrase_loss_weight * par_loss
        self.nli_encoder_model = tf.keras.models.Model(
            inputs=[pos_input], outputs=[pos_encoded])
        sim = tf.keras.layers.Lambda(cosine_similarity, name='similarity')(
            [anc_encoded, pos_encoded])
        self.sim_model = tf.keras.models.Model(
            inputs=[anc_input, pos_input], outputs=[sim])
        self.inputs += [anc_input, pos_input, neg_input]

    def build_toxic_head(self):
        log_this("Building toxic head")
        sent_input = layers.Input(shape=(1,), dtype=tf.string)
        sent_label = layers.Input(
            shape=(self.config.n_toxic_tags,), dtype=tf.float32)
        sents_encoded = self.nlu_encoder(sent_input)
        tox_mlp = self.build_mlp(
            2, self.config.dim, self.config.dim, self.config.n_toxic_tags,
            name="toxic", dropout_rate=self.config.head_dropout_rate)
        pred = tox_mlp(sents_encoded)
        tox_loss = tf.keras.losses.categorical_crossentropy(sent_label, pred)
        # Keep the per-example loss 2-D so it broadcasts against the other
        # per-head loss terms.
        tox_loss = tf.reshape(tox_loss, (-1, 1))
        self.tox_model = tf.keras.models.Model(
            inputs=[sent_input], outputs=[pred], name='toxic_model')
        self.inputs += [sent_input, sent_label]
        self.loss += self.config.toxic_loss_weight * tox_loss
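

if __name__ == '__main__':
    from types import SimpleNamespace

    # Hypothetical usage. Field names follow the `self.config` references
    # above; the values are illustrative, and `module_path` must point at
    # the TF-Hub BERT module expected by BertLayer.
    config = SimpleNamespace(
        module_path='path/to/bert_module', ctx_len=128, dim=768,
        n_tune=4, tune_embs=False, train_bert=True, lr=2e-5,
        use_par_head=True, use_toxic_head=True, use_ner_head=True,
        n_tags=9, n_toxic_tags=6, head_dropout_rate=0.1,
        paraphrase_loss_weight=1.0, toxic_loss_weight=1.0,
        tagger_loss_weight=1.0)

    sbert = SBERT(config)
    sbert.train_model.summary()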