@mvivarelli
Last active November 10, 2021 10:23
weighted_categorical_crossentropy
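import numpy as np
# assumes the Keras 2 backend API (tf.keras in TensorFlow 2.x up to 2.15); Keras 3 moved these ops to keras.ops
from tensorflow.keras import backend as K

def weighted_categorical_crossentropy(weights):
    # weights: one weight per class, shape (num_classes,)
    weights = K.variable(weights)

    def loss(y_true, y_pred):
        # scale predictions so that the class probabilities of each sample sum to 1
        y_pred = y_pred / K.sum(y_pred, axis=-1, keepdims=True)
        # clip to prevent NaNs and Infs in the log terms below
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        # weighted log-probability of the true class
        lss = y_true * K.log(y_pred) * weights
        # weighted log(1 - p) for every other class; this term makes the loss a
        # class-weighted sum of per-class binary cross-entropies
        lss += (1 - y_true) * K.log(1 - y_pred) * weights
        # negate and sum over classes: one loss value per sample
        return -K.sum(lss, -1)

    return loss

# example: three classes, with class 2 weighted 20x more than class 0
weights = np.array([0.5, 5, 10])
loss = weighted_categorical_crossentropy(weights)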
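To check by hand what the returned loss computes, here is a minimal NumPy mirror of it. It assumes 1e-7 as the clipping constant (the Keras default returned by K.epsilon()). Both function names are hypothetical helpers, not part of Keras. Because of the (1 - y_true) term, the gist's loss also penalises probability placed on the wrong classes. The variant `weighted_cce_only_np` keeps only the true-class term (plain weighted categorical cross-entropy) for comparison.

```python
import numpy as np

EPS = 1e-7  # assumed to match the Keras default K.epsilon()

def weighted_loss_np(y_true, y_pred, weights):
    """Hypothetical NumPy mirror of the Keras loss above; one value per sample."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    # renormalise and clip exactly as the Keras version does
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    y_pred = np.clip(y_pred, EPS, 1 - EPS)
    # true-class term plus the weighted log(1 - p) term for the other classes
    lss = y_true * np.log(y_pred) * weights
    lss += (1 - y_true) * np.log(1 - y_pred) * weights
    return -lss.sum(axis=-1)

def weighted_cce_only_np(y_true, y_pred, weights):
    """Hypothetical variant: only the weighted true-class term."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    y_pred = np.clip(y_pred, EPS, 1 - EPS)
    return -(y_true * np.log(y_pred) * np.asarray(weights, dtype=np.float64)).sum(axis=-1)

# one sample whose true class is class 2
y_true = np.array([[0, 0, 1]])
y_pred = np.array([[0.2, 0.3, 0.5]])
w = np.array([0.5, 5, 10])
print(weighted_loss_np(y_true, y_pred, w))      # gist's loss: -(10 log 0.5 + 0.5 log 0.8 + 5 log 0.7)
print(weighted_cce_only_np(y_true, y_pred, w))  # true-class term only: -10 log 0.5
```

Because predictions are renormalised first, scaling y_pred by a constant leaves either loss unchanged.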