分类任务中绘制混淆矩阵

在做分类任务时,我们通常需要绘制混淆矩阵来分析分类结果,sklearn包中提供了计算分类结果的混淆矩阵的函数from sklearn.metrics import confusion_matrix,但是将混淆矩阵绘制成图片还需要我们利用matplotlib库自己变吗实现。

下面是具体实现代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from sklearn.metrics import confusion_matrix
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def confusion_matrix_plot_matplotlib(y_true, y_pred,cmap=plt.cm.Blues):
print ('Drawing confusion_matrix...')
cm = confusion_matrix(y_true, y_pred)
plt.matshow(cm, cmap=cmap)
plt.colorbar()

for x in range(len(cm)):
for y in range(len(cm)):
plt.annotate(cm[x, y], xy=(x, y), horizontalalignment='center', verticalalignment='center')

plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
plt.savefig('../data/confusion_matrix.png')

参考:

百度知道

写的还不错?那就来个红包吧!
0%