如何修正线性SVM的误报率?

 Tibetan-妍自_557 发布于 2023-01-30 17:35

我是一个SVM新手,这是我的用例:我有很多不平衡的数据要使用线性SVM进行二进制分类.我需要修正某些值的误报率,并测量每个值的相应误差.我正在使用类似下面的代码使用scikit-learn svm实现:

# define training data
X = [[0, 0], [1, 1]]
y = [0, 1]

# define and train the SVM
clf = svm.LinearSVC(C=0.01, class_weight='auto') #auto for unbalanced distributions
clf.fit(X, y)

# compute false positives and false negatives
predictions = [clf.predict(ex) for ex in X]    
false_positives = [(a, b) for (a, b) in zip(predictions,y) if a != b and b == 0]
false_negatives = [(a, b) for (a, b) in zip(predictions,y) if a != b and b == 1] 

有没有办法使用分类器的参数(或几个参数),以便有效地修复测量指标?

2 个回答
  • class_weights参数允许您向上或向下推动此误报率.让我用一个日常的例子来说明这是如何工作的.假设您拥有一个夜总会,并且您在两个限制条件下运营:

      您希望尽可能多的人进入俱乐部(付费客户)

      你不希望任何未成年人进入,因为这会让你陷入困境

    平均每天,(比如)只有5%的人试图进入俱乐部将是未成年人.你面临着一个选择:宽容或严格.前者会使你的利润增加5%,但你冒着昂贵的诉讼风险.后者将不可避免地意味着一些超过法定年龄的人将被拒绝入境,这也将花费你的钱.你想调整relative cost宽大与严格.注意:你不能直接控制有多少未成年人进入俱乐部,但你可以控制你的保镖有多严格.

    这里有一些Python,它显示了在更改相对重要性时会发生什么.

    from collections import Counter
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.svm import LinearSVC
    
    data = load_iris()
    
    # remove a feature to make the problem harder
    # remove the third class for simplicity
    X = data.data[:100, 0:1] 
    y = data.target[:100] 
    # shuffle data
    indices = np.arange(y.shape[0])
    np.random.shuffle(indices)
    X = X[indices, :]
    y = y[indices]
    
    for i in range(1, 20):
        clf = LinearSVC(class_weight={0: 1, 1: i})
        clf = clf.fit(X[:50, :], y[:50])
        print i, Counter(clf.predict(X[50:]))
        # print clf.decision_function(X[50:])
    

    哪个输出

    1 Counter({1: 22, 0: 28})
    2 Counter({1: 31, 0: 19})
    3 Counter({1: 39, 0: 11})
    4 Counter({1: 43, 0: 7})
    5 Counter({1: 43, 0: 7})
    6 Counter({1: 44, 0: 6})
    7 Counter({1: 44, 0: 6})
    8 Counter({1: 44, 0: 6})
    9 Counter({1: 47, 0: 3})
    10 Counter({1: 47, 0: 3})
    11 Counter({1: 47, 0: 3})
    12 Counter({1: 47, 0: 3})
    13 Counter({1: 47, 0: 3})
    14 Counter({1: 47, 0: 3})
    15 Counter({1: 47, 0: 3})
    16 Counter({1: 47, 0: 3})
    17 Counter({1: 48, 0: 2})
    18 Counter({1: 48, 0: 2})
    19 Counter({1: 48, 0: 2})
    

    注意分类为0减少的数据点的数量是类1增加的相对权重.假设您有计算资源和时间来训练和评估10个分类器,您可以绘制每个分类器的精确度和召回率,并得到如下图所示的数字(在互联网上无耻地被盗).然后,您可以使用它来确定class_weights用例的正确值.

    精确召回权衡

    2023-01-30 17:36 回答
  • LinearSVCsklearn中的预测方法如下所示

    def predict(self, X):
        """Predict class labels for samples in X.
    
        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = [n_samples, n_features]
            Samples.
    
        Returns
        -------
        C : array, shape = [n_samples]
            Predicted class label per sample.
        """
        scores = self.decision_function(X)
        if len(scores.shape) == 1:
            indices = (scores > 0).astype(np.int)
        else:
            indices = scores.argmax(axis=1)
        return self.classes_[indices]
    

    因此,除了mbatchkarov建议你可以通过改变分类器所说的某个类是一类或另一类的边界来改变分类器(真正的任何分类器)所做的决定.

    from collections import Counter
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.svm import LinearSVC
    
    data = load_iris()
    
    # remove a feature to make the problem harder
    # remove the third class for simplicity
    X = data.data[:100, 0:1] 
    y = data.target[:100] 
    # shuffle data
    indices = np.arange(y.shape[0])
    np.random.shuffle(indices)
    X = X[indices, :]
    y = y[indices]
    
    decision_boundary = 0
    print Counter((clf.decision_function(X[50:]) > decision_boundary).astype(np.int8))
    Counter({1: 27, 0: 23})
    
    decision_boundary = 0.5
    print Counter((clf.decision_function(X[50:]) > decision_boundary).astype(np.int8))
    Counter({0: 39, 1: 11})
    

    您可以根据需要优化决策边界.

    2023-01-30 17:36 回答
撰写答案
今天,你开发时遇到什么问题呢?
立即提问
热门标签
PHP1.CN | 中国最专业的PHP中文社区 | PNG素材下载 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有