热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

关于KNN的python3实现

关于KNN,有幸看到这篇文章,写的很好,这里就不在赘述。直接贴上代码了,有小的改动。(原来是python2版本的,这里改为python3的,主要就是print)环境:win732bit+

  关于KNN,有幸看到这篇文章,写的很好,这里就不在赘述。直接贴上代码了,有小的改动。(原来是python2版本的,这里改为python3的,主要就是print)

  环境:win7 32bit + spyder + anaconda3.5

  一、初阶

# -*- coding: utf-8 -*-
"""
Created on Sun Nov 6 16:09:00 2016

@author: Administrator
"""

#Input:
#newInput:待测的数据点(1xM)
#dataSet:已知的数据(NxM)
#labels:已知数据的标签(1xM)
#k:选取的最邻近数据点的个数
#
#Output:
#待测数据点的分类标签
#

from numpy import *

# creat a dataset which contain 4 samples with 2 class
def createDataSet():
# creat a matrix: each row as a sample
group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels


#classify using KNN
def KNNClassify(newInput, dataSet, labels, k):
numSamples = dataSet.shape[0] # row number
# step1:calculate Euclidean distance
# tile(A, reps):Constract an array by repeating A reps times
diff = tile(newInput, (numSamples, 1)) - dataSet
squreDiff = diff**2
squreDist = sum(squreDiff, axis=1) # sum if performed by row
distance = squreDist ** 0.5

#step2:sort the distance
# argsort() returns the indices that would sort an array in a ascending order
sortedDistIndices = argsort(distance)

classCount = {}
for i in range(k):
# choose the min k distance
voteLabel = labels[sortedDistIndices[i]]

#step4:count the times labels occur
# when the key voteLabel is not in dictionary classCount,
# get() will return 0
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
#step5:the max vote class will return
maxCount = 0
for k, v in classCount.items():
if v > maxCount:
maxCount = v
maxIndex = k

return maxIndex


# test

dataSet, labels = createDataSet()

testX = array([1.2, 1.0])
k = 3
outputLabel = KNNClassify(testX, dataSet, labels, 3)

print("Your input is:", testX, "and classified to class: ", outputLabel)


testX = array([0.1, 0.3])
k = 3
outputLabel = KNNClassify(testX, dataSet, labels, 3)

print("Your input is:", testX, "and classified to class: ", outputLabel)

  运行结果:

 

  二、进阶

  用到的手写识别数据库资料在这里下载。关于资料的介绍在上面的博文也已经介绍的很清楚了。

# -*- coding: utf-8 -*-
"""
Created on Sun Nov 6 16:09:00 2016

@author: Administrator
"""

#Input:
#newInput:待测的数据点(1xM)
#dataSet:已知的数据(NxM)
#labels:已知数据的标签(1xM)
#k:选取的最邻近数据点的个数
#
#Output:
#待测数据点的分类标签
#

from numpy import *



#classify using KNN
def KNNClassify(newInput, dataSet, labels, k):
numSamples = dataSet.shape[0] # row number
# step1:calculate Euclidean distance
# tile(A, reps):Constract an array by repeating A reps times
diff = tile(newInput, (numSamples, 1)) - dataSet
squreDiff = diff**2
squreDist = sum(squreDiff, axis=1) # sum if performed by row
distance = squreDist ** 0.5

#step2:sort the distance
# argsort() returns the indices that would sort an array in a ascending order
sortedDistIndices = argsort(distance)

classCount = {}
for i in range(k):
# choose the min k distance
voteLabel = labels[sortedDistIndices[i]]

#step4:count the times labels occur
# when the key voteLabel is not in dictionary classCount,
# get() will return 0
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
#step5:the max vote class will return
maxCount = 0
for k, v in classCount.items():
if v > maxCount:
maxCount = v
maxIndex = k

return maxIndex



# convert image to vector
def img2vector(filename):
rows = 32
cols = 32
imgVector = zeros((1, rows * cols))
fileIn = open(filename)
for row in range(rows):
lineStr = fileIn.readline()
for col in range(cols):
imgVector[0, row * 32 + col] = int(lineStr[col])

return imgVector


# load dataSet
def loadDataSet():
## step 1: Getting training set
print("---Getting training set...")
dataSetDir = 'F:\\Techonolgoy\\算法学习\\KNN\\进阶\\'
trainingFileList = os.listdir(dataSetDir + 'trainingDigits') # load the training set
numSamples = len(trainingFileList)

train_x = zeros((numSamples, 1024))
train_y = []
for i in range(numSamples):
filename = trainingFileList[i]

# get train_x
train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename)

# get label from file name such as "1_18.txt"
label = int(filename.split('_')[0]) # return 1
train_y.append(label)

## step 2: Getting testing set
print("---Getting testing set...")
testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set
numSamples = len(testingFileList)
test_x = zeros((numSamples, 1024))
test_y = []
for i in range(numSamples):
filename = testingFileList[i]

# get train_x
test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename)

# get label from file name such as "1_18.txt"
label = int(filename.split('_')[0]) # return 1
test_y.append(label)

return train_x, train_y, test_x, test_y

# test hand writing class
def testHandWritingClass():
## step 1: load data
print("step 1: load data...")
train_x, train_y, test_x, test_y = loadDataSet()

## step 2: training...
print("step 2: training...")
pass

## step 3: testing
print("step 3: testing...")
numTestSamples = test_x.shape[0]
matchCount = 0
for i in range(numTestSamples):
predict = KNNClassify(test_x[i], train_x, train_y, 3)
if predict == test_y[i]:
matchCount += 1
accuracy = float(matchCount) / numTestSamples

## step 4: show the result
print("step 4: show the result...")
print('The classify accuracy is: %.2f%%' % (accuracy * 100))



testHandWritingClass()

  运行结果:

 


推荐阅读
  • 本文主要解析了Open judge C16H问题中涉及到的Magical Balls的快速幂和逆元算法,并给出了问题的解析和解决方法。详细介绍了问题的背景和规则,并给出了相应的算法解析和实现步骤。通过本文的解析,读者可以更好地理解和解决Open judge C16H问题中的Magical Balls部分。 ... [详细]
  • This article discusses the efficiency of using char str[] and char *str and whether there is any reason to prefer one over the other. It explains the difference between the two and provides an example to illustrate their usage. ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • 本文介绍了设计师伊振华受邀参与沈阳市智慧城市运行管理中心项目的整体设计,并以数字赋能和创新驱动高质量发展的理念,建设了集成、智慧、高效的一体化城市综合管理平台,促进了城市的数字化转型。该中心被称为当代城市的智能心脏,为沈阳市的智慧城市建设做出了重要贡献。 ... [详细]
  • IhaveconfiguredanactionforaremotenotificationwhenitarrivestomyiOsapp.Iwanttwodiff ... [详细]
  • 前景:当UI一个查询条件为多项选择,或录入多个条件的时候,比如查询所有名称里面包含以下动态条件,需要模糊查询里面每一项时比如是这样一个数组条件:newstring[]{兴业银行, ... [详细]
  • vue使用
    关键词: ... [详细]
  • 在Android开发中,使用Picasso库可以实现对网络图片的等比例缩放。本文介绍了使用Picasso库进行图片缩放的方法,并提供了具体的代码实现。通过获取图片的宽高,计算目标宽度和高度,并创建新图实现等比例缩放。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文介绍了P1651题目的描述和要求,以及计算能搭建的塔的最大高度的方法。通过动态规划和状压技术,将问题转化为求解差值的问题,并定义了相应的状态。最终得出了计算最大高度的解法。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • FeatureRequestIsyourfeaturerequestrelatedtoaproblem?Please ... [详细]
  • 本文介绍了一个题目的解法,通过二分答案来解决问题,但困难在于如何进行检查。文章提供了一种逃逸方式,通过移动最慢的宿管来锁门时跑到更居中的位置,从而使所有合格的寝室都居中。文章还提到可以分开判断两边的情况,并使用前缀和的方式来求出在任意时刻能够到达宿管即将锁门的寝室的人数。最后,文章提到可以改成O(n)的直接枚举来解决问题。 ... [详细]
  • 本文讨论了一个数列求和问题,该数列按照一定规律生成。通过观察数列的规律,我们可以得出求解该问题的算法。具体算法为计算前n项i*f[i]的和,其中f[i]表示数列中有i个数字。根据参考的思路,我们可以将算法的时间复杂度控制在O(n),即计算到5e5即可满足1e9的要求。 ... [详细]
  • Android自定义控件绘图篇之Paint函数大汇总
    本文介绍了Android自定义控件绘图篇中的Paint函数大汇总,包括重置画笔、设置颜色、设置透明度、设置样式、设置宽度、设置抗锯齿等功能。通过学习这些函数,可以更好地掌握Paint的用法。 ... [详细]
author-avatar
书友62423539
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有