Last active
April 9, 2018 01:38
-
-
Save YiqinZhao/d89b933b93740250e379a81ec34b1152 to your computer and use it in GitHub Desktop.
Automatic confusion matrix calculation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Automatic Keras callback class for:
    - Confusion matrix calculation
    - Weighted accuracy (WA): overall fraction of correctly predicted samples
    - Unweighted accuracy (UA): mean of the per-class recall values
    - Label distribution
    - Prediction correctness (correct predictions per class)

Dependencies: pandas, numpy

Usage:
    Import it first:
        from libTJNU.SavePredictResult import SavePredictResult

    Append it to your Keras callback list together with the test data
    and labels, for example:
        callbacks_list = [SavePredictResult(test_data=(X_test, y_test))]
        model.fit(X_train, y_train, batch_size=batch_size,
                  epochs=epoch, validation_data=(X_test, y_test),
                  shuffle=True, callbacks=callbacks_list)

Example output:
    ---------------------------------------------------------
     Result for current model
    =========================================================
    - WA :  0.6487804878048781
    - UA :  0.5528546658259773
    - Y Data length :  [104  15  61  25]
    - Correctness :  [79  0 30 24]
    - Confusion Matrix :
    Classified as ->      0     1      2      3  Recall
    0                    79     0      1     24   76.0%
    1                     4     0      0     11    0.0%
    2                    31     0     30      0   49.2%
    3                     1     0      0     24   96.0%
    Precision         68.7%  nan%  96.8%  40.7%    0.0%
    ---------------------------------------------------------
'''
import numpy as np
import pandas as pd
from keras.callbacks import Callback
def _to_class_indices(labels):
    """Convert labels to a 1-D array of integer class indices.

    Rows of one-hot vectors / per-class scores (last axis wider than 1) are
    reduced with argmax; anything else is treated as class indices already.
    """
    labels = np.asarray(labels)
    if labels.ndim > 1 and labels.shape[-1] > 1:
        return np.argmax(labels, axis=-1).reshape(-1)
    return labels.reshape(-1).astype(np.int64)


def _confusion_stats(y_true, y_pred):
    """Compute accuracy figures and a confusion matrix from class indices.

    Args:
        y_true: integer array of ground-truth class indices.
        y_pred: integer array of predicted class indices, same shape.

    Returns:
        dict with keys:
          'wa'        -- overall accuracy (correct / total); nan if empty.
          'ua'        -- mean per-class recall over classes present in
                         y_true; nan if there are none.
          'dst'       -- per-class sample counts (matrix row sums).
          'prd'       -- per-class correct-prediction counts (the diagonal).
          'mtx'       -- confusion matrix, rows = true class,
                         columns = predicted class.
          'recall'    -- per-class recall (nan where a class has no samples).
          'precision' -- per-class precision (nan where never predicted).

    Raises:
        ValueError: if y_true and y_pred have different shapes.
    """
    y_true = np.asarray(y_true, dtype=np.int64).reshape(-1)
    y_pred = np.asarray(y_pred, dtype=np.int64).reshape(-1)
    if y_true.shape != y_pred.shape:
        raise ValueError('y_true and y_pred must have the same length')

    # Size the matrix from BOTH label sets so predictions of a class that
    # never occurs in y_true are not silently dropped (which would make the
    # row sums, and therefore recall, wrong).
    if y_true.size == 0:
        n_classes = 0
    else:
        n_classes = int(max(y_true.max(), y_pred.max())) + 1

    # Encode every (true, pred) pair as one flat index, count them in a single
    # pass, and reshape the counts into the n_classes x n_classes matrix.
    mtx = np.bincount(n_classes * y_true + y_pred,
                      minlength=n_classes * n_classes)
    mtx = mtx.reshape(n_classes, n_classes)

    prd = np.diag(mtx).copy()
    dst = mtx.sum(axis=1)
    # 0/0 -> nan is the intended value for empty rows/columns; don't warn.
    with np.errstate(divide='ignore', invalid='ignore'):
        recall = prd / dst
        precision = prd / mtx.sum(axis=0)

    total = dst.sum()
    wa = prd.sum() / total if total else float('nan')
    present = dst > 0
    ua = np.mean(recall[present]) if present.any() else float('nan')
    return {'wa': wa, 'ua': ua, 'dst': dst, 'prd': prd, 'mtx': mtx,
            'recall': recall, 'precision': precision}


def _confusion_table(mtx, recall, precision):
    """Render the confusion matrix as a DataFrame with Recall/Precision margins."""
    labels = list(range(mtx.shape[0]))
    table = pd.DataFrame(mtx.astype('str'), index=labels, columns=labels)
    table['Recall'] = ['%.1f%%' % (r * 100) for r in recall]
    # The bottom-right cell has no metric; it is filled with 0 -> '0.0%' to
    # keep the historical output layout.
    table.loc['Precision'] = ['%.1f%%' % (p * 100) for p in precision] + ['0.0%']
    table.columns.names = ['Classified as ->']
    return table


class SavePredictResult(Callback):
    """Keras callback that prints evaluation metrics on held-out data.

    After every epoch the model predicts on ``test_data`` and the callback
    prints WA (overall accuracy), UA (mean per-class recall), the per-class
    sample counts, the per-class correct counts, and a confusion matrix with
    recall/precision margins.

    Args:
        test_data: ``(x, y)`` tuple. ``x`` is passed to ``model.predict``;
            ``y`` is either one-hot / per-class rows (class taken by argmax)
            or a vector of integer class indices.
        batch_size: batch size used for ``model.predict`` (default 100).
    """

    def __init__(self, test_data, batch_size=100):
        super(SavePredictResult, self).__init__()
        self.test_data = test_data
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the stored test data and print the metrics report.

        ``epoch`` and ``logs`` are part of the Keras callback signature and
        are not used here.
        """
        x, y = self.test_data
        scores = self.model.predict(x, batch_size=self.batch_size)
        y_true = _to_class_indices(y)
        y_pred = np.argmax(scores, axis=-1)
        stats = _confusion_stats(y_true, y_pred)

        print('---------------------------------------------------------')
        print(' Result for current model')
        print('=========================================================')
        print('- WA : ', stats['wa'])
        print('- UA : ', stats['ua'])
        print('- Y Data length : ', stats['dst'])
        print('- Correctness : ', stats['prd'])
        print('- Confusion Matrix :' )
        print(_confusion_table(stats['mtx'], stats['recall'],
                               stats['precision']))
        print('---------------------------------------------------------')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment