热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

matlab画混淆矩阵加入值,混淆矩阵的绘制(Plotaconfusionmatrix)|文艺数学君

摘要这一篇简单介绍一下混淆矩阵的计算和绘制,混淆矩阵可以用来判断模型预测的结果。介绍这一篇主要介绍一下绘制混淆矩阵(confusionmatrix)的方

摘要这一篇简单介绍一下混淆矩阵的计算和绘制,混淆矩阵可以用来判断模型预测的结果。

介绍

这一篇主要介绍一下绘制混淆矩阵(confusion matrix)的方式。通常在看model的效果的时候,我们会使用混淆矩阵来进行检测。

主要参考资料 :

具体绘制方式

混淆矩阵的计算

混淆矩阵就是我们会计算最后分类错误的个数, 如计算将class1分为class2的个数,以此类推。

我们可以使用下面的方式来进行混淆矩阵的计算。

# 绘制混淆矩阵

def confusion_matrix(preds, labels, conf_matrix):

preds = torch.argmax(preds, 1)

for p, t in zip(preds, labels):

conf_matrix[p, t] += 1

return conf_matrix

conf_matrix = torch.zeros(10, 10)

for data, target in test_loader:

output = fullModel(data.to(device))

conf_matrix = confusion_matrix(output, target, conf_matrix)

最后得到的conf_matrix就是混淆矩阵的值。

484101d1f0c608b3380c62cf8aaac97d.png

混淆矩阵的可视化

有了上面的混淆矩阵中具体的值,下面就是进行可视化的步骤。可视化我们使用seaborn来进行完成。因为我这里conf_matrix的值是tensor, 所以需要先转换为Numpy.

import seaborn as sn

df_cm = pd.DataFrame(conf_matrix.numpy(),

index = [i for i in list(Attack2Index.keys())],

columns = [i for i in list(Attack2Index.keys())])

plt.figure(figsize = (10,7))

sn.heatmap(df_cm, annot=True, cmap="BuPu")

最终的混淆矩阵的图如下所示:

aacff6386a52802b01ebbfc7ba46c3cc.png

混淆矩阵的可视化(进行美化)

当然, 我们还可以对混淆矩阵做更多的处理, 使得显示的时候能更加好看一些. 下面的绘制混淆矩阵的函数我是在下面的链接里看到的, 最终的效果很是不错。

这里简单贴一下代码,可以方便直接进行使用。

import itertools

# 绘制混淆矩阵

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):

"""

This function prints and plots the confusion matrix.

Normalization can be applied by setting `normalize=True`.

Input

- cm : 计算出的混淆矩阵的值

- classes : 混淆矩阵中每一行每一列对应的列

- normalize : True:显示百分比, False:显示个数

"""

if normalize:

cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

print("Normalized confusion matrix")

else:

print('Confusion matrix, without normalization')

print(cm)

plt.imshow(cm, interpolation='nearest', cmap=cmap)

plt.title(title)

plt.colorbar()

tick_marks = np.arange(len(classes))

plt.xticks(tick_marks, classes, rotation=45)

plt.yticks(tick_marks, classes)

fmt = '.2f' if normalize else 'd'

thresh = cm.max() / 2.

for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):

plt.text(j, i, format(cm[i, j], fmt),

horizOntalalignment="center",

color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()

plt.ylabel('True label')

plt.xlabel('Predicted label')

测试数据如下所示:

cnf_matrix = np.array([[8707, 64, 731, 164, 45],

[1821, 5530, 79, 0, 28],

[266, 167, 1982, 4, 2],

[691, 0, 107, 1930, 26],

[30, 0, 111, 17, 42]])

attack_types = ['Normal', 'DoS', 'Probe', 'R2L', 'U2R']

我们分别测试normalize=True/False的效果。

plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')

a50697e330ac8a3a1198a5c5b3659b22.png

plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=False, title='Normalized confusion matrix')

eca635d3661de47fe0acfb46626d6655.png


推荐阅读
author-avatar
SU大肥婆_545
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有