200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > 用Keras进行手写字体识别(MNIST数据集)

用Keras进行手写字体识别(MNIST数据集)

时间:2021-01-05 23:51:33

相关推荐

用Keras进行手写字体识别(MNIST数据集)

数据

首先加载数据

from keras.datasets import mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

接下来,看看这个数据集的基本情况:

train_images.shape

(60000, 28, 28)

len(train_labels)

60000

train_labels

array([5, 0, 4, …, 5, 6, 8], dtype=uint8)

test_images.shape

(10000, 28, 28)

len(test_labels)

10000

test_labels

array([7, 2, 1, …, 4, 5, 6], dtype=uint8)

网络构架

from keras import modelsfrom keras import layersnetwork = models.Sequential()network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))network.add(layers.Dense(10, activation='softmax'))

Keras可以帮助我们实现一层一层的连接起来,在本例中的网络包含2个Dense层,他们是密集连接(也叫全连接)的神经层。第二层是一个10路softmax层,他将返回一个由10个概率值(总和为1)组成的数组。每个概率值表示当前数字图像属于10个数字类别中的某一个的概率。

编译

要想训练网络,我们还需要设置编译步骤的三个参数:

损失函数

优化器(optimizer):基于训练数据和损失函数来更新网络的机制

在训练和测试过程中需要监控的指标(metric):本例只关心精度,即正确分类的图像所占的比例。

pile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])

图像数据预处理

在训练之前,需要对图像数据预处理,将其变换成网络要求的形状,并缩放所有值都在[0,1]区间。比如,之前训练图像保存在一个uint8类型的数组中,其形状为(60000,28, 28),取值范围为[0, 255]。我们需要将其变换成一个float32数组,其形状为(60000, 28*28),取值范围为0-1

train_images = train_images.reshape((60000, 28 * 28)) #把一个图像变成一列数据用于学习train_images = train_images.astype('float32') / 255 #astype用于进行数据类型转换test_images = test_images.reshape((10000, 28 * 28))test_images = test_images.astype('float32') / 255

训练

from keras.utils import to_categoricaltrain_labels = to_categorical(train_labels)test_labels = to_categorical(test_labels)network.fit(train_images, train_labels, epochs=5, batch_size=128)

会有一个输出

看看测试集表现如何:

test_loss, test_acc = network.evaluate(test_images, test_labels)print('test_acc:', test_acc)

我们还可以看某一个具体的图像显示情况

digit = train_images[4]import matplotlib.pyplot as pltplt.imshow(digit, cmap = plt.cm.binary)plt.show()

更多精彩内容,欢迎关注我的微信公众号:数据瞎分析

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。