• 技术文章 >Python技术 >Python基础教程

    如何实现python绘制混淆矩阵?

    2020-11-04 11:10:38原创8800
    大家从python基础到如今的入门,想必都对python有一定基础,今天小编给大家带来一个关于python的高阶内容——绘制混淆矩阵,一起来看下吧~

    介绍:

    混淆矩阵通过表示正确/不正确标签的计数来表示模型在表格格式中的准确性。

    计算/绘制混淆矩阵:

    以下是计算混淆矩阵的过程。

    您需要一个包含预期结果值的测试数据集或验证数据集。

    然后将这些数字组织成表格或矩阵,如下所示:

    然后将正确和不正确分类的计数填入表格中。

    Reading混淆矩阵:

    一个类的正确预测的总数进入该类值的预期行,以及该类值的预测列。

    以同样的方式,一个类别的不正确预测总数进入该类别值的预期行,以及该类别值的预测列。

    对角元素表示预测标签等于真实标签的点的数量,而非对角线元素是分类器错误标记的元素。混淆矩阵的对角线值越高越好,表明许多正确的预测。

    Python绘制混淆矩阵 :

    1

    2

    3

    4

    5

    6

    7

    8

    9

    10

    11

    12

    13

    14

    15

    16

    17

    18

    19

    20

    21

    22

    23

    24

    25

    26

    27

    28

    29

    30

    31

    32

    33

    34

    35

    36

    37

    38

    39

    40

    41

    42

    43

    44

    45

    46

    47

    48

    49

    50

    51

    52

    53

    54

    55

    56

    57

    58

    59

    60

    61

    62

    63

    64

    65

    66

    67

    68

    69

    70

    71

    72

    73

    74

    75

    76

    77

    78

    79

    80

    81

    82

    83

    84

    85

    86

    87

    88

    89

    90

    91

    92

    93

    94

    95

    96

    97

    98

    99

    100

    101

    102

    103

    104

    105

    106

    107

    108

    109

    110

    111

    112

    113

    114

    115

    116

    117

    import itertools

      

    import numpy as np

      

    import matplotlib.pyplot as plt

      

    from sklearn import svm, datasets

      

    from sklearn.model_selection import train_test_split

      

    from sklearn.metrics import confusion_matrix

      

    # import some data to play with

      

    iris = datasets.load_iris()

      

    X = iris.data

      

    y = iris.target

      

    class_names = iris.target_names

      

    # Split the data into a training set and a test set

      

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

      

    # Run classifier, using a model that is too regularized (C too low) to see

      

    # the impact on the results

      

    classifier = svm.SVC(kernel='linear', C=0.01)

      

    y_pred = classifier.fit(X_train, y_train).predict(X_test)

      

    def plot_confusion_matrix(cm, classes,

      

    normalize=False,

      

    title='Confusion matrix',

      

    cmap=plt.cm.Blues):

      

    """

      

    This function prints and plots the confusion matrix.

      

    Normalization can be applied by setting `normalize=True`.

      

    """

      

    if normalize:

      

    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

      

    print("Normalized confusion matrix")

      

    else:

      

    print('Confusion matrix, without normalization')

      

    print(cm)

      

    plt.imshow(cm, interpolation='nearest', cmap=cmap)

      

    plt.title(title)

      

    plt.colorbar()

      

    tick_marks = np.arange(len(classes))

      

    plt.xticks(tick_marks, classes, rotation=45)

      

    plt.yticks(tick_marks, classes)

      

    fmt = '.2f' if normalize else 'd'

      

    thresh = cm.max() / 2.

      

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):

      

    plt.text(j, i, format(cm[i, j], fmt),

      

    horizontalalignment="center",

      

    color="white" if cm[i, j] > thresh else "black")

      

    color="white" if cm[i, j] > thresh else "black")

      

    plt.tight_layout()

      

    plt.ylabel('True label')

      

    plt.xlabel('Predicted label')

      

    # Compute confusion matrix

      

    cnf_matrix = confusion_matrix(y_test, y_pred)

      

    np.set_printoptions(precision=2)

      

    # Plot non-normalized confusion matrix

      

    plt.figure()

      

    plot_confusion_matrix(cnf_matrix, classes=class_names,

      

    title='Confusion matrix, without normalization')

      

    # Plot normalized confusion matrix

      

    plt.figure()

      

    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,

      

    title='Normalized confusion matrix')

      

    plt.show()

    1

    2

    3

    4

    5

    6

    7

    8

    9

    10

    11

    12

    13

    14

    15

    Confusion matrix, without normalization

      

    [[13 0 0]

      

    [ 0 10 6]

      

    [ 0 0 9]]

      

    Normalized confusion matrix

      

    [[ 1. 0. 0. ]

      

    [ 0. 0.62 0.38]

      

    [ 0. 0. 1. ]]

    好了,大家可以消化学习下哦~如需了解更多python实用知识,点击进入PyThon学习网教学中心

    专题推荐:实现python绘制混淆矩阵
    上一篇:关于python matplotlib绘图使用详解 下一篇:如何使用python中的opengl?

    相关文章推荐

    • python计算三角形面积详细代码演示• 关于python property用法操作详解• 如何使用Python中item()方法遍历字典?• 关于python matplotlib绘图使用详解

    全部评论我要评论

    © 2021 Python学习网 苏ICP备2021003149号-1

  • 取消发布评论
  • 

    Python学习网