Last active
May 6, 2020 04:46
-
-
Save joydeb28/6add8e80d657e27c87f86cc51110c2d3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DesignModel:
    """Build, train and save a BERT-based intent classifier.

    The Keras model is created by ``bert_model``; ``model_train`` and
    ``save_model`` require it to have been built first.
    """

    def __init__(self, inputs=None, labels=None):
        """Store the training inputs and labels.

        Args:
            inputs: list of the three model inputs, in the order
                (token ids, input masks, segment ids). Defaults to the
                module-level ``train_input_ids``, ``train_input_masks`` and
                ``train_segment_ids`` globals, as the original constructor did.
            labels: training labels; presumably integer class ids given the
                sparse categorical loss used below (confirm against the data
                loader). Defaults to the module-level ``train_labels`` global.
        """
        self.model = None
        # Fall back to module globals lazily so callers that pass data
        # explicitly do not need those globals to exist.
        self.train_data = (
            [train_input_ids, train_input_masks, train_segment_ids]
            if inputs is None
            else inputs
        )
        self.train_labels = train_labels if labels is None else labels

    def bert_model(self, max_seq_length):
        """Build and compile the classifier, storing it in ``self.model``.

        Three int32 inputs of length ``max_seq_length`` are fed to
        ``bert_model_obj.bert_module``. Its sequence output is averaged over
        the sequence axis, passed through dropout, and classified by a
        softmax layer with one unit per entry in
        ``load_data_obj.cat_to_intent``.

        Args:
            max_seq_length: length of each input sequence.
        """
        in_id = Input(shape=(max_seq_length,), dtype=tf.int32, name="input_ids")
        in_mask = Input(shape=(max_seq_length,), dtype=tf.int32, name="input_masks")
        in_segment = Input(shape=(max_seq_length,), dtype=tf.int32, name="segment_ids")
        bert_inputs = [in_id, in_mask, in_segment]
        # Only the second output (per-token sequence output) is used; the
        # first (presumably the pooled output) is discarded.
        _, bert_sequence_out = bert_model_obj.bert_module(bert_inputs)
        out = GlobalAveragePooling1D()(bert_sequence_out)
        out = Dropout(0.2)(out)
        out = Dense(
            len(load_data_obj.cat_to_intent),
            activation="softmax",
            name="dense_output",
        )(out)
        self.model = Model(inputs=bert_inputs, outputs=out)
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(1e-5),
            # The output layer already applies softmax, so the loss receives
            # probabilities, not logits. from_logits=True would make Keras
            # apply softmax a second time and distort the loss/gradients.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="acc")],
        )
        self.model.summary()

    def model_train(self, batch_size, num_epoch):
        """Fit the model on the stored training data.

        20% of the data is held out for validation, and samples are
        shuffled each epoch.

        Args:
            batch_size: samples per gradient update.
            num_epoch: number of training epochs.

        Raises:
            RuntimeError: if ``bert_model`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Call bert_model() before model_train().")
        print("Fitting to model")
        self.model.fit(
            self.train_data,
            self.train_labels,
            epochs=num_epoch,
            batch_size=batch_size,
            validation_split=0.2,
            shuffle=True,
        )
        print("Model Training complete.")

    def save_model(self, model, model_name):
        """Save ``self.model`` to ``model_name + ".h5"``.

        Args:
            model: unused; kept for backward compatibility. ``self.model``
                is what gets saved.
            model_name: output path without extension. Relative paths resolve
                against the current working directory.

        Raises:
            RuntimeError: if ``bert_model`` has not been called yet.
        """
        if self.model is None:
            raise RuntimeError("Call bert_model() before save_model().")
        path = model_name + ".h5"
        self.model.save(path)
        # Report the actual destination rather than a hard-coded folder name.
        print("Model saved to {}.".format(path))
# Driver: build the classifier on top of `bert_model_obj`, sized to its
# `max_len`, then train for a single epoch with batches of 32.
model_obj = DesignModel()
model_obj.bert_model(max_seq_length=bert_model_obj.max_len)
model_obj.model_train(batch_size=32, num_epoch=1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment