from typing import Iterable, List, Optional

import torch

from allennlp.data import Instance
from allennlp.data.iterators import DataIterator
from allennlp.models import Model
from allennlp.training.trainer import Trainer
from allennlp.training.trainer_base import TrainerBase

# ``MetaModel`` is project-specific and assumed to be defined elsewhere
# in this codebase; it is not part of allennlp.


@TrainerBase.register('metatrainer')
class MetaTrainer(Trainer):
    def __init__(self,
                 model: Model,
                 meta_model: MetaModel,
                 optimizer: torch.optim.Optimizer,
                 iterator: DataIterator,
                 train_datasets: List[Iterable[Instance]],
                 validation_datasets: Optional[Iterable[Instance]] = None,
                 # meta-learner parameters
                 meta_batches: int = 200,
                 inner_steps: int = 3,
                 tasks_per_batch: int = 2,
                 batch_norm: bool = True,
                 **kwargs) -> None:
"""
A metatrainer for doing meta-learning. It just takes a list of labeled datasets
and a ``DataIterator``, and uses the supplied meta-learner to learn the weights
for your model over some fixed number of epochs. You can also pass in a validation
datasets and enable early stopping. There are many other bells and whistles as well.
Parameters
----------
model : ``Model``, required.
"""
        # I am not calling move_to_gpu here, because if the model is
        # not already on the GPU then the optimizer is going to be wrong.
        super().__init__(model, optimizer, iterator, train_datasets, **kwargs)
        self.train_data = train_datasets
        self._validation_data = validation_datasets
        # Meta-trainer-specific parameters
        self.meta_batches = meta_batches
        self.tasks_per_batch = tasks_per_batch
        self.inner_steps = inner_steps
        self.step_size = 0.01  # hard-coded rather than taken from the constructor
        self.batch_norm = batch_norm
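

# --- Hypothetical usage sketch (not part of the original gist) ---
# Shows how the trainer might be wired up. The arguments (a built ``Model``,
# the project-specific ``MetaModel``, a ``Vocabulary``, and one list of
# ``Instance``s per task) are assumed to be constructed elsewhere.
from allennlp.data import Vocabulary
from allennlp.data.iterators import BasicIterator


def example_meta_training(model: Model,
                          meta_model: 'MetaModel',
                          vocab: Vocabulary,
                          task_datasets: List[List[Instance]]) -> None:
    iterator = BasicIterator(batch_size=32)
    iterator.index_with(vocab)  # iterators must be indexed with the vocabulary
    trainer = MetaTrainer(model=model,
                          meta_model=meta_model,
                          optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                          iterator=iterator,
                          train_datasets=task_datasets,
                          meta_batches=200,
                          inner_steps=3,
                          tasks_per_batch=2)
    trainer.train()  # inherited from ``Trainer`` unless overridden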
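
# --- Sketch: how these hyperparameters typically interact (an assumption) ---
# The constructor stores ``meta_batches``, ``tasks_per_batch``, ``inner_steps``,
# and a fixed ``step_size``, which matches the shape of a Reptile-style loop:
# adapt a copy of the weights on each sampled task for ``inner_steps`` steps,
# then move the real weights toward the average adapted weights by
# ``step_size``. The actual ``train()`` loop is not shown in this gist, so the
# plain-PyTorch function below is an illustrative sketch, not the gist's code.
import copy


def reptile_meta_batch(model: torch.nn.Module,
                       task_batches: List[List[tuple]],
                       inner_steps: int = 3,
                       inner_lr: float = 1e-3,
                       step_size: float = 0.01) -> None:
    original = {n: p.clone() for n, p in model.named_parameters()}
    delta = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    loss_fn = torch.nn.MSELoss()
    for batches in task_batches:                  # one inner loop per task
        adapted = copy.deepcopy(model)
        inner_opt = torch.optim.SGD(adapted.parameters(), lr=inner_lr)
        for _, (inputs, targets) in zip(range(inner_steps), batches):
            inner_opt.zero_grad()
            loss_fn(adapted(inputs), targets).backward()
            inner_opt.step()
        for n, p in adapted.named_parameters():   # accumulate average weight shift
            delta[n] += (p.detach() - original[n]) / len(task_batches)
    with torch.no_grad():                         # meta update: theta += step_size * mean shift
        for n, p in model.named_parameters():
            p.copy_(original[n] + step_size * delta[n])


# Tiny smoke test on two random regression tasks:
#     net = torch.nn.Linear(4, 1)
#     tasks = [[(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
#              for _ in range(2)]
#     reptile_meta_batch(net, tasks)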