Last active
November 26, 2018 22:04
-
-
Save oscarknagg/a42c34ba23078dd9b124765395b496c7 to your computer and use it in GitHub Desktop.
Key functionality for Prototypical Networks (Snell et al., 2017)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def proto_net_episode(model: Module,
                      optimiser: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
    """Performs a single training or evaluation episode for a Prototypical Network.

    The batch is embedded, the first `n_shot * k_way` embeddings are averaged
    per class into prototypes, and every remaining embedding is classified by a
    softmax over its negative squared Euclidean distance to each prototype.

    # Arguments
        model: Prototypical Network (embedding model) to be trained or evaluated.
        optimiser: Optimiser used to take the gradient step. Only touched when
            `train` is True.
        loss_fn: Loss function called as `loss_fn(log_p_y, y_queries)`. It
            receives log-probabilities, so a negative log-likelihood loss fits
            directly.
        x: Input samples of the few shot classification task. The first
            `n_shot * k_way` samples must be the support set, grouped by class
            (n_shot consecutive samples per class); all remaining samples are
            treated as queries.
        y: Labels aligned element-wise with `x` (support labels first, then
            query labels). Only the query part is used.
            NOTE(review): the query labels presumably have to be class indices
            in [0, k_way) following the prototype order - confirm with caller.
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set. Kept for
            interface compatibility; it is not needed to split the batch.
        distance: Name of the distance metric. Currently not consulted: squared
            Euclidean distance is always used.
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Prototypical Network on this task
        y_pred: Predicted class probabilities for the query set on this task,
            with one row per query and one column per class.
    """
    if train:
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples (support and queries) in a single forward pass
    embeddings = model(x)

    # Support samples come first, queries after. Embeddings and labels must be
    # split at the same index so that each query stays aligned with its label.
    n_support = n_shot * k_way
    support = embeddings[:n_support]
    queries = embeddings[n_support:]
    y_queries = y[n_support:]

    # Reshape so the first dimension indexes by class then take the mean
    # over the n_shot examples to generate the "prototype" of each class
    prototypes = support.reshape(k_way, n_shot, -1).mean(dim=1)

    # Flatten queries the same way the prototypes were flattened so both sides
    # share the trailing embedding dimension
    queries = queries.flatten(start_dim=1)

    # Squared Euclidean distances between all queries and all prototypes.
    # Broadcasting (num_queries, 1, dim) against (1, k_way, dim) gives an
    # output of shape (num_queries, k_way)
    distances = (
        queries.unsqueeze(1) - prototypes.unsqueeze(0)
    ).pow(2).sum(dim=2)

    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y_queries)

    # Prediction probabilities are softmax over negative distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment