Forked from padamson/rwml-R_figure4_20.R
Created November 23, 2018 11:47
ROC curves for each class of the MNIST 10-class classifier
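# Requires dplyr (%>%, filter, mutate) and ROCR (prediction, performance).
library(dplyr)
library(ROCR)

# Assumes mnistTest (test set with a `label` column) and mnist.kknn
# (a fitted kknn model) were created earlier in the script.
# Columns: true label, predicted label, then one probability column per digit.
mnistResultsDF <- data.frame(actual = mnistTest$label,
                             fit = mnist.kknn$fit,
                             mnist.kknn$prob)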
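# Plot one ROC curve per digit on shared axes and label each with its AUC.
# For each digit, keep only the rows where that digit is the true or the
# predicted label, treat "prediction was correct" as the positive class,
# and use the predicted probability of that digit as the score.
plotROCs <- function(df, digitList) {
  firstPlot <- TRUE
  legendList <- NULL
  for (digit in digitList) {
    dfDigit <- df %>%
      filter(as.character(actual) == as.character(digit) |
               as.character(fit) == as.character(digit)) %>%
      mutate(prediction = (as.character(actual) == as.character(fit)))
    # Probability columns start at column 3 (digit 0), assuming the class
    # levels are ordered 0 to 9, so the score column for `digit` is digit + 3.
    pred <- prediction(dfDigit[, digit + 3], dfDigit$prediction)
    perf <- performance(pred, "tpr", "fpr")
    auc <- performance(pred, "auc")
    legendList <- append(legendList,
                         paste0("Digit: ", digit, ", AUC: ",
                                round(auc@y.values[[1]], digits = 4)))
    # Draw the first curve on fresh axes, then overlay the rest.
    if (firstPlot) {
      plot(perf, colorize = FALSE, lty = digit + 1, col = digit + 1)
      firstPlot <- FALSE
    } else {
      plot(perf, colorize = FALSE, add = TRUE, lty = digit + 1, col = digit + 1)
    }
  }
  # Match legend colours and line types to the digits actually plotted.
  legend(x = 0.4, y = 0.6,
         legend = legendList,
         col = digitList + 1,
         lty = digitList + 1,
         bty = "n")
}

plotROCs(mnistResultsDF, 0:9)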
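
# Optional sanity check (a sketch, not part of the original gist):
# ROCR's "auc" measure should equal the probability that a randomly chosen
# positive row scores higher than a randomly chosen negative row, with ties
# counted as 1/2. The hypothetical helper aucRank() computes that directly in
# base R. It assumes `scores` is numeric and `labels` is logical with at least
# one TRUE and one FALSE. It compares every positive/negative pair, so its
# cost is quadratic in the number of rows.
aucRank <- function(scores, labels) {
  stopifnot(is.logical(labels), any(labels), any(!labels))
  pos <- scores[labels]   # scores of positive examples
  neg <- scores[!labels]  # scores of negative examples
  # Fraction of (positive, negative) pairs ranked correctly, ties worth 0.5.
  mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
}

# Example: aucRank(c(0.9, 0.8, 0.3, 0.1), c(TRUE, FALSE, TRUE, FALSE))  # 0.75
# Inside plotROCs(), aucRank(dfDigit[, digit + 3], dfDigit$prediction) should
# agree with auc@y.values[[1]] up to floating-point error.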