热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

K-fold交叉验证实现python。-K-foldcrossvalidationimplementationpython

Iamtryingtoimplementthek-foldcross-validationalgorithminpython.IknowSKLearnprovidesan

I am trying to implement the k-fold cross-validation algorithm in python. I know SKLearn provides an implementation but still... This is my code as of right now.

我正在尝试在python中实现k-fold交叉验证算法。我知道SKLearn提供了一个实现,但是…这是我现在的代码。

from sklearn import metrics
import numpy as np

class Cross_Validation:

@staticmethod
def partition(vector, fold, k):
    size = vector.shape[0]
    start = (size/k)*fold
    end = (size/k)*(fold+1)
    validation = vector[start:end]
    if str(type(vector)) == "":
        indices = range(start, end)
        mask = np.ones(vector.shape[0], dtype=bool)
        mask[indices] = False
        training = vector[mask]
    elif str(type(vector)) == "":
        training = np.concatenate((vector[:start], vector[end:]))
    return training, validation

@staticmethod
def Cross_Validation(learner, k, examples, labels):
    train_folds_score = []
    validation_folds_score = []
    for fold in range(0, k):
        training_set, validation_set = Cross_Validation.partition(examples, fold, k)
        training_labels, validation_labels = Cross_Validation.partition(labels, fold, k)
        learner.fit(training_set, training_labels)
        training_predicted = learner.predict(training_set)
        validation_predicted = learner.predict(validation_set)
        train_folds_score.append(metrics.accuracy_score(training_labels, training_predicted))
        validation_folds_score.append(metrics.accuracy_score(validation_labels, validation_predicted))
    return train_folds_score, validation_folds_score

The learner parameter is a classifier from SKlearn library, k is the number of folds, examples is a sparse matrix produced by the CountVectorizer (again SKlearn) that is the representation of the bag of words. For example:

学习者参数是SKlearn库中的分类器,k是折叠数,例子是由CountVectorizer(再次SKlearn)制作的稀疏矩阵,它是单词包的表示形式。例如:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from Cross_Validation import Cross_Validation as cv

vectorizer = CountVectorizer(stop_words='english', lowercase=True, min_df=2, analyzer="word")
data = vectorizer.fit_transform("""textual data""")
clfMNB = MultinomialNB(alpha=.0001)
score = cv.Cross_Validation(clfMNB, 10, data, labels)
print "Train score" + str(score[0])
print "Test score" + str(score[1])

I'm assuming there is some logic error somewhere since the scores are 95% on the training set (as expected) but practically 0 on the test test, but I can't find it.

我假设有一些逻辑上的错误,因为在训练集上的分数是95%(如预期的),但是在测试测试中几乎是0,但是我找不到。

I hope I was clear. Thanks in advance.

我希望我是清白的。提前谢谢。

________________________________EDIT___________________________________

________________________________EDIT___________________________________

This is the code that loads the text into the vector that can be passed to the vectorizer. It also returns the label vector.

这是将文本加载到可以传递给vectorizer的向量的代码。它还返回标签向量。

from nltk.tokenize import word_tokenize
from Categories_Data import categories
import numpy as np
import codecs
import glob
import os
import re

class Data_Preprocessor:

def tokenize(self, text):
    tokens = word_tokenize(text)
    alpha = [t for t in tokens if unicode(t).isalpha()]
    return alpha

def header_not_fully_removed(self, text):
    if ":" in text.splitlines()[0]:
        return len(text.splitlines()[0].split(":")[0].split()) == 1
    else:
        return False

def strip_newsgroup_header(self, text):
    _before, _blankline, after = text.partition('\n\n')
    if len(after) > 0 and self.header_not_fully_removed(after):
        after = self.strip_newsgroup_header(after)
    return after

def strip_newsgroup_quoting(self, text):
    _QUOTE_RE = re.compile(r'(writes in|writes:|wrote:|says:|said:'r'|^In article|^Quoted from|^\||^>)')
    good_lines = [line for line in text.split('\n')
        if not _QUOTE_RE.search(line)]
    return '\n'.join(good_lines)

def strip_newsgroup_footer(self, text):
    lines = text.strip().split('\n')
    for line_num in range(len(lines) - 1, -1, -1):
        line = lines[line_num]
        if line.strip().strip('-') == '':
            break
    if line_num > 0:
        return '\n'.join(lines[:line_num])
    else:
        return text

def raw_to_vector(self, path, to_be_stripped=["header", "footer", "quoting"], noise_threshold=-1):
    base_dir = os.getcwd()
    train_data = []
    label_data = []
    for category in categories:
        os.chdir(base_dir)
        os.chdir(path+"/"+category[0])
        for filename in glob.glob("*"):
            with codecs.open(filename, 'r', encoding='utf-8', errors='replace') as target:
                data = target.read()
                if "quoting" in to_be_stripped:
                    data = self.strip_newsgroup_quoting(data)
                if "header" in to_be_stripped:
                    data = self.strip_newsgroup_header(data)
                if "footer" in to_be_stripped:
                    data = self.strip_newsgroup_footer(data)
                if len(data) > noise_threshold:
                    train_data.append(data)
                    label_data.append(category[1])
    os.chdir(base_dir)
    return np.array(train_data), np.array(label_data)

This is what "from Categories_Data import categories" imports...

这是“从分类数据导入类别”导入的内容……

categories = [
    ('alt.atheism',0),
    ('comp.graphics',1),
    ('comp.os.ms-windows.misc',2),
    ('comp.sys.ibm.pc.hardware',3),
    ('comp.sys.mac.hardware',4),
    ('comp.windows.x',5),
    ('misc.forsale',6),
    ('rec.autos',7),
    ('rec.motorcycles',8),
    ('rec.sport.baseball',9),
    ('rec.sport.hockey',10),
    ('sci.crypt',11),
    ('sci.electronics',12),
    ('sci.med',13),
    ('sci.space',14),
    ('soc.religion.christian',15),
    ('talk.politics.guns',16),
    ('talk.politics.mideast',17),
    ('talk.politics.misc',18),
    ('talk.religion.misc',19)
 ]

1 个解决方案

#1


2  

The reason why your validation score is low is subtle.

你的验证分数低的原因很微妙。

The issue is how you have partitioned the dataset. Remember, when doing cross-validation you should randomly split the dataset. It is the randomness that you are missing.

问题是如何划分数据集。记住,在进行交叉验证时,应该随机地分割数据集。这就是你缺少的随机性。

Your data is loaded category by category, which means in your input dataset, class labels and examples follow one after the other. By not doing the random split, you have completely removed a class which your model never sees during the training phase and hence you get a bad result on your test/validation phase.

您的数据是按类别加载的,这意味着在您的输入数据集中,类标签和示例会跟随一个接着一个。通过不进行随机分割,您已经完全删除了在训练阶段中您的模型从未看到的类,因此您将在测试/验证阶段得到一个糟糕的结果。

You can solve this by doing a random shuffle. So, do this:

你可以通过随机洗牌来解决这个问题。所以,这样做:

from sklearn.utils import shuffle    

processor = Data_Preprocessor()
td, tl = processor.raw_to_vector(path="C:/Users/Pankaj/Downloads/ng/")
vectorizer = CountVectorizer(stop_words='english', lowercase=True, min_df=2, analyzer="word")
data = vectorizer.fit_transform(td)
# Shuffle the data and labels
data, tl = shuffle(data, tl, random_state=0)
clfMNB = MultinomialNB(alpha=.0001)
score = Cross_Validation.Cross_Validation(clfMNB, 10, data, tl)

print("Train score" + str(score[0]))
print("Test score" + str(score[1]))

推荐阅读
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 向QTextEdit拖放文件的方法及实现步骤
    本文介绍了在使用QTextEdit时如何实现拖放文件的功能,包括相关的方法和实现步骤。通过重写dragEnterEvent和dropEvent函数,并结合QMimeData和QUrl等类,可以轻松实现向QTextEdit拖放文件的功能。详细的代码实现和说明可以参考本文提供的示例代码。 ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • Spring特性实现接口多类的动态调用详解
    本文详细介绍了如何使用Spring特性实现接口多类的动态调用。通过对Spring IoC容器的基础类BeanFactory和ApplicationContext的介绍,以及getBeansOfType方法的应用,解决了在实际工作中遇到的接口及多个实现类的问题。同时,文章还提到了SPI使用的不便之处,并介绍了借助ApplicationContext实现需求的方法。阅读本文,你将了解到Spring特性的实现原理和实际应用方式。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 展开全部下面的代码是创建一个立方体Thisexamplescreatesanddisplaysasimplebox.#Thefirstlineloadstheinit_disp ... [详细]
  • 标题: ... [详细]
  • Java学习笔记之面向对象编程(OOP)
    本文介绍了Java学习笔记中的面向对象编程(OOP)内容,包括OOP的三大特性(封装、继承、多态)和五大原则(单一职责原则、开放封闭原则、里式替换原则、依赖倒置原则)。通过学习OOP,可以提高代码复用性、拓展性和安全性。 ... [详细]
  • 本文讨论了clone的fork与pthread_create创建线程的不同之处。进程是一个指令执行流及其执行环境,其执行环境是一个系统资源的集合。在调用系统调用fork创建一个进程时,子进程只是完全复制父进程的资源,这样得到的子进程独立于父进程,具有良好的并发性。但是二者之间的通讯需要通过专门的通讯机制,另外通过fork创建子进程系统开销很大。因此,在某些情况下,使用clone或pthread_create创建线程可能更加高效。 ... [详细]
  • 深入理解Kafka服务端请求队列中请求的处理
    本文深入分析了Kafka服务端请求队列中请求的处理过程,详细介绍了请求的封装和放入请求队列的过程,以及处理请求的线程池的创建和容量设置。通过场景分析、图示说明和源码分析,帮助读者更好地理解Kafka服务端的工作原理。 ... [详细]
  • IB 物理真题解析:比潜热、理想气体的应用
    本文是对2017年IB物理试卷paper 2中一道涉及比潜热、理想气体和功率的大题进行解析。题目涉及液氧蒸发成氧气的过程,讲解了液氧和氧气分子的结构以及蒸发后分子之间的作用力变化。同时,文章也给出了解题技巧,建议根据得分点的数量来合理分配答题时间。最后,文章提供了答案解析,标注了每个得分点的位置。 ... [详细]
  • android listview OnItemClickListener失效原因
    最近在做listview时发现OnItemClickListener失效的问题,经过查找发现是因为button的原因。不仅listitem中存在button会影响OnItemClickListener事件的失效,还会导致单击后listview每个item的背景改变,使得item中的所有有关焦点的事件都失效。本文给出了一个范例来说明这种情况,并提供了解决方法。 ... [详细]
  • 本文介绍了C#中数据集DataSet对象的使用及相关方法详解,包括DataSet对象的概述、与数据关系对象的互联、Rows集合和Columns集合的组成,以及DataSet对象常用的方法之一——Merge方法的使用。通过本文的阅读,读者可以了解到DataSet对象在C#中的重要性和使用方法。 ... [详细]
  • Android源码深入理解JNI技术的概述和应用
    本文介绍了Android源码中的JNI技术,包括概述和应用。JNI是Java Native Interface的缩写,是一种技术,可以实现Java程序调用Native语言写的函数,以及Native程序调用Java层的函数。在Android平台上,JNI充当了连接Java世界和Native世界的桥梁。本文通过分析Android源码中的相关文件和位置,深入探讨了JNI技术在Android开发中的重要性和应用场景。 ... [详细]
author-avatar
呆保保_369
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有