import logging
from pathlib import Path

import torch

from farm.data_handler.data_silo import DataSilo, DataSiloForCrossVal
from farm.data_handler.processor import TextClassificationProcessor
from farm.modeling.optimization import initialize_optimizer
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.language_model import LanguageModel
from farm.modeling.prediction_head import TextClassificationHead
from farm.modeling.tokenization import Tokenizer
from farm.train import Trainer, EarlyStopping
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings
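
# Model under test: use exactly one of the following (local checkpoint path or model hub name)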
#lang_model = "./models/dbmdz-bert-base-german-uncased"
lang_model = "bert-base-german-dbmdz-uncased"
#lang_model = "./models/german-nlp-group-electra-base-german-uncased"
#lang_model = "german-nlp-group/electra-base-german-uncased"
def doc_classification_crossvalidation():
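    """Run k-fold cross-validation for GermEval 2018 coarse-grained
    (OTHER vs. OFFENSE) document classification with FARM.
    """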
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)
    logging.getLogger("transformers").setLevel(logging.WARNING)
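
    # Cross-validation and training settings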
    xval_folds = 5
    xval_stratified = True
    metric_name = "f1_macro"
    save_dir = Path("./saved_models/electra-bert-test")

    set_all_seeds(seed=42)
    device, n_gpu = initialize_device_settings(use_cuda=True)
    n_epochs = 3
    batch_size = 32
    evaluate_every = 100
    use_amp = None
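
    # Tokenizer, processor, and data silos for the GermEval 2018 coarse-grained task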
    tokenizer = Tokenizer.load(pretrained_model_name_or_path=lang_model)

    label_list = ["OTHER", "OFFENSE"]
    processor = TextClassificationProcessor(tokenizer=tokenizer,
                                            max_seq_len=64,
                                            data_dir=Path("./data/germeval18"),
                                            label_list=label_list,
                                            metric=metric_name,
                                            label_column_name="coarse_label"
                                            )

    data_silo = DataSilo(processor=processor, batch_size=batch_size)
    silos = DataSiloForCrossVal.make(data_silo, n_splits=xval_folds)
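
    # Trains a fresh model on one fold's train/dev data and evaluates it on that fold's test set.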
    def train_on_split(silo_to_use, n_fold, save_dir):
        logger.info(f"############ Crossvalidation: Fold {n_fold} ############")

        # Load a fresh language model for every fold; class weights are
        # computed once on the full data silo.
        language_model = LanguageModel.load(lang_model)
        prediction_head = TextClassificationHead(
            class_weights=data_silo.calculate_class_weights(task_name="text_classification"),
            num_labels=len(label_list))

        model = AdaptiveModel(
            language_model=language_model,
            prediction_heads=[prediction_head],
            embeds_dropout_prob=0.2,
            lm_output_types=["per_sequence"],
            device=device)

        model, optimizer, lr_schedule = initialize_optimizer(
            model=model,
            learning_rate=0.5e-5,
            device=device,
            n_batches=len(silo_to_use.loaders["train"]),
            n_epochs=n_epochs,
            use_amp=use_amp)

        earlystopping = EarlyStopping(
            metric=metric_name,
            mode="max",
            save_dir=save_dir,
            patience=5,
        )

        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            data_silo=silo_to_use,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=device,
            early_stopping=earlystopping,
        )

        trainer.train()

        es_result = earlystopping.best_so_far
        result = trainer.test_result[0][metric_name]
        print("result from early stopping (on dev set):", es_result)
        print("result from test set (with best model loaded):", result)
        input("Please compare the early stopping result (dev set) with the test set result, then press Enter to continue...")

        return model
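
    # Cross-validation loop: train and evaluate one model per fold.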
    for num_fold, silo in enumerate(silos):
        model = train_on_split(silo, num_fold, save_dir)

        # empty the cache to avoid memory leaks and CUDA OOM across multiple folds
        model.cpu()
        torch.cuda.empty_cache()

if __name__ == "__main__":
    doc_classification_crossvalidation()
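
# To run: place the GermEval 2018 train.tsv and test.tsv files (FARM's default
# filenames, an assumption here) into ./data/germeval18 and execute this script
# in a Python environment with farm and torch installed.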