200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > 图像识别:利用KNN实现手写数字识别(mnist数据集)

图像识别:利用KNN实现手写数字识别(mnist数据集)

时间:2020-04-20 10:40:33

相关推荐

图像识别:利用KNN实现手写数字识别(mnist数据集)

图像识别:利用KNN实现手写数字识别(mnist数据集)

步骤:

1、数据的加载(trainSize和testSize不要设置的太大)

2、k值的设定(不宜过大)

3、KNN的核心:距离的计算

4、k个最近的图片-->根据下标寻找对应的标签

5、根据标签转化成相应的数字

6、检测概率统计

在我看来,KNN算法最大的优点是简单,准确率较高;

最大的缺点是:当数据量较大时,计算量成倍增长,测试集与训练集之间的任意两个元素之间都要计算距离。

注意1:trainSize和testSize不要设置的太大,如果过大,数据处理中产生更加庞大的数据,内存溢出,导致程序崩溃。

注意2:k值的设定太大会提高计算机的计算量,而且会一定程度上降低准确率。

import tensorflow as tfimport numpy as npfrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('D:/MNIST_data', one_hot=True)trainNum = 55000testNum = 10000trainSize = 500testSize = 5k = 4# data 分解trainIndex = np.random.choice(trainNum, trainSize, replace=False)testIndex = np.random.choice(testNum, testSize, replace=False)trainData = mnist.train.images[trainIndex] # 训练图片trainLabel = mnist.train.labels[trainIndex] # 训练标签testData = mnist.test.images[testIndex] # 测试图片testLabel = mnist.test.labels[testIndex] # 测试标签# 利用placeholder来完成数据的加载trainDataInput = tf.placeholder(shape=[None, 784], dtype=tf.float32)trainLabelInput = tf.placeholder(shape=[None, 10], dtype=tf.float32)testDataInput = tf.placeholder(shape=[None, 784], dtype=tf.float32)testLabelInput = tf.placeholder(shape=[None, 10], dtype=tf.float32)# KNN的距离f1 = tf.expand_dims(testDataInput, 1) # 维度扩展f2 = tf.subtract(trainDataInput, f1) # 二者之差f3 = tf.reduce_sum(tf.abs(f2), reduction_indices=2)f4 = tf.negative(f3) # 取反f5, f6 = tf.nn.top_k(f4, k=k) # 最大的四个值 f5表示的是数据 f6表示的该数据所处的下标f7 = tf.gather(trainLabelInput, f6) # 根据f6下标去寻找trainLabelInput中对应的标签f8 = tf.reduce_sum(f7, reduction_indices=1)f9 = tf.argmax(f8, dimension=1)with tf.Session() as sess:p9 = sess.run(f9, feed_dict={trainDataInput: trainData, testDataInput: testData, trainLabelInput: trainLabel})p10 = np.argmax(testLabel, axis=1)print('预测值:', p9)print('真实值:', p10)j = 0for i in range(0, testSize):if p10[i] == p9[i]:j += 1print('accuracy:', j*100/testSize)

作死设置了一回,电脑是游戏本,屏幕出现卡顿,加速球爆满,还好运行出来了

设置的过高,直接报错,资源耗尽。

训练集数量和K值该如何确定

从上图可以看出,trainSize不是设置的越高越好,在满足较高准确率的同时,又不能使计算量过于庞大,需要把握训练数据集的大小。

从上图可以看出,K值的设置过大反而会在一定程度上降低预测的准确率 ,所以设置k值时,需要对数据集有一定的了解,并且在一定的范围内取值。

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。