Scikit-learn GridSearch给出"ValueError:不支持多类格式"错误

 dongtiankzh 发布于 2022-12-18 12:19

我正在尝试使用GridSearch进行LinearSVC()的参数估计,如下所示 -

clf_SVM = LinearSVC()
params = {
          'C': [0.5, 1.0, 1.5],
          'tol': [1e-3, 1e-4, 1e-5],
          'multi_class': ['ovr', 'crammer_singer'],
          }
gs = GridSearchCV(clf_SVM, params, cv=5, scoring='roc_auc')
gs.fit(corpus1, y)

corpus1有形状(1726,7001),y有形状(1726,)

这是一个多类分类,y的值为0到3,包括两个值,即有四个类.

但这给了我以下错误 -

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
 in ()
      5           }
      6 gs = GridSearchCV(clf_SVM, params, cv=5, scoring='roc_auc')
----> 7 gs.fit(corpus1, y)

/usr/local/lib/python2.7/dist-packages/sklearn/grid_search.pyc in fit(self, X, y)
    594 
    595         """
--> 596         return self._fit(X, y, ParameterGrid(self.param_grid))
    597 
    598 

/usr/local/lib/python2.7/dist-packages/sklearn/grid_search.pyc in _fit(self, X, y, parameter_iterable)
    376                                     train, test, self.verbose, parameters,
    377                                     self.fit_params, return_parameters=True)
--> 378             for parameters in parameter_iterable
    379             for train, test in cv)
    380 

/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable)
    651             self._iterating = True
    652             for function, args, kwargs in iterable:
--> 653                 self.dispatch(function, args, kwargs)
    654 
    655             if pre_dispatch == "all" or n_jobs == 1:

/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.pyc in dispatch(self, func, args, kwargs)
    398         """
    399         if self._pool is None:
--> 400             job = ImmediateApply(func, args, kwargs)
    401             index = len(self._jobs)
    402             if not _verbosity_filter(index, self.verbose):

/usr/local/lib/python2.7/dist-packages/sklearn/externals/joblib/parallel.pyc in __init__(self, func, args, kwargs)
    136         # Don't delay the application, to avoid keeping the input
    137         # arguments in memory
--> 138         self.results = func(*args, **kwargs)
    139 
    140     def get(self):

/usr/local/lib/python2.7/dist-packages/sklearn/cross_validation.pyc in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters)
   1238     else:
   1239         estimator.fit(X_train, y_train, **fit_params)
-> 1240     test_score = _score(estimator, X_test, y_test, scorer)
   1241     if return_train_score:
   1242         train_score = _score(estimator, X_train, y_train, scorer)

/usr/local/lib/python2.7/dist-packages/sklearn/cross_validation.pyc in _score(estimator, X_test, y_test, scorer)
   1294         score = scorer(estimator, X_test)
   1295     else:
-> 1296         score = scorer(estimator, X_test, y_test)
   1297     if not isinstance(score, numbers.Number):
   1298         raise ValueError("scoring must return a number, got %s (%s) instead."

/usr/local/lib/python2.7/dist-packages/sklearn/metrics/scorer.pyc in __call__(self, clf, X, y)
    136         y_type = type_of_target(y)
    137         if y_type not in ("binary", "multilabel-indicator"):
--> 138             raise ValueError("{0} format is not supported".format(y_type))
    139 
    140         try:

ValueError: multiclass format is not supported

user1269942.. 10

从:

http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score

"注意:此实现仅限于标签指示符格式的二进制分类任务或多标签分类任务."

尝试:

y = label_binarize(y, classes=[0, 1, 2, 3])

在你训练之前.这将执行你的y的"一热"编码.

1 个回答
  • 从:

    http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score

    "注意:此实现仅限于标签指示符格式的二进制分类任务或多标签分类任务."

    尝试:

    y = label_binarize(y, classes=[0, 1, 2, 3])
    

    在你训练之前.这将执行你的y的"一热"编码.

    2022-12-18 12:20 回答
撰写答案
今天,你开发时遇到什么问题呢?
立即提问
热门标签
PHP1.CN | 中国最专业的PHP中文社区 | PNG素材下载 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有