200字范文,内容丰富有趣,生活中的好帮手!
200字范文 > 图像分类数据集 (FASHION-MNIST)

图像分类数据集 (FASHION-MNIST)

时间:2022-09-30 01:02:17

相关推荐

图像分类数据集 (FASHION-MNIST)

文章目录

引入1 获取数据集2 简单操作3 读取小批量4 完整代码致谢

引入

图像分类数据集最常用的是手写数字识别数据集MNIST (1),但是大部分模型在其上的分类精度都超过了95%。为了更直观地观察算法之间的差异,将使用一个图像内容更加复杂的数据集[Fashion-MNIST (2)]。

接下来的部分将使用torchvision包,主要用于构建计算机视觉模型,主要由以下4部分组成:

代码已上传至github:

/InkiInki/Python/blob/master/Python1/deepLearning/ImageMnist.py

1 获取数据集

需要导入的包如下:

import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport timeimport sysfrom IPython import display

下面,将通过torchvision.datasets下载数据集,第一次调用时会自动从网上获取数据 (若出现速度较慢,请向后查看注意);通过参数train来指定获取训练集或者测试集;通过transform = transforms.Tensor()将数据转化为Tensor,如果不转换,则返回PIL图片。

transforms.Tensor()将尺寸为 (H×W×CH×W×CH×W×C)且数据位于 (0, 255)的PIL图片或数据类型为np.uint8的Numpy转换为尺寸为 (C×H×WC×H×WC×H×W)且数据类型为torch.float32且位于 (0.0, 1.0)的Tensor。

使用代码如下:

class ImageMnist():def __init__(self):self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=True, download=True, transform=transforms.ToTensor())self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=False, download=True, transform=transforms.ToTensor())if __name__ == "__main__":test = ImageDataSet()test.__init__()print(test.mnist_train)print(len(test.mnist_train), len(test.mnist_test))

运行结果:

Dataset FashionMNISTNumber of datapoints: 60000Root location: C:\Users\Administrator/DataSets/FashionMNISTSplit: TrainStandardTransformTransform: ToTensor()60000 10000

注意:

1)如果用像素值表示图片数据,那么一律将其类型设置成unit8,以避免不必要的bug;

2)第一次下载时速度也许很慢,推荐在cmd中输入以下代码,并复制出现的http链接下载:

import torchvisionimport torchvision.transforms as transformstorchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())

2 简单操作

可以通过下标来访问任意一个样本:

if __name__ == "__main__":test = ImageMnist()test.__init__()data, label = test.mnist_train[0]print(data.shape)print(label)

运行结果:

torch.Size([1, 28, 28]) # 分别对应通道数、图像高、图像宽9

Fashion-MNIST共10个类别,分别为t-shirt、trouser、pullover、dress、coat、sandal、shirt、sneaker、bag和ankle boot,以下函数可以将数值标签转换成相应的文本标签:

...def get_text_labels(self, labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]if __name__ == "__main__":test = ImageMnist()test.__init__()data, label = test.mnist_train[0]print(test.get_text_labels([label]))

运行结果:

['ankle boot']

现在定义一个可以在一行里画出多张图像和对应标签的函数:

...def show_mnist(self, images, labels):display.set_matplotlib_formats('svg')_, figs = plt.subplots(1, len(images), figsize=(12, 12))# zip()接受一系列可迭代对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axis('off')plt.show()if __name__ == "__main__":test = ImageMnist()test.__init__()x, y = [], []for i in range(10):x.append(test.mnist_train[i][0])y.append(test.mnist_train[i][1])test.show_mnist(x, test.get_text_labels(y))

运行结果:

3 读取小批量

torch的DataLoader中一个很方便的功能是运行使用多进程来加速读取数据,这里通过参数num_workers来设置4个进程读取数据。

...def data_iter(self, batch_size=256):if sys.platform.startswith('win'):num_workers = 0 # 0表示不需要额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(self.mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(self.mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterif __name__ == "__main__":start = time.time()test = ImageMnist()test.__init__()train_iter, test_iter = test.data_iter()for x, y in train_iter:continueprint("%.2f sec" % (time.time() - start))

运行结果:

6.65 sec

4 完整代码

'''@(#)test.pyThe class of test.Author: Yu-Xuan ZhangEmail: inki.yinji@Created on May 05, Last Modified on May 05, @author: inki'''import torchimport torchvisionimport torchvision.transforms as transformsimport matplotlib.pyplot as pltimport timeimport sysfrom IPython import displayclass ImageMnist():def __init__(self):self.mnist_train = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=True, download=True, transform=transforms.ToTensor())self.mnist_test = torchvision.datasets.FashionMNIST(root='~/DataSets/FashionMNIST',train=False, download=True, transform=transforms.ToTensor())def get_text_labels(self, labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]def show_mnist(self, images, labels):display.set_matplotlib_formats('svg')_, figs = plt.subplots(1, len(images), figsize=(12, 12))for f, img, lbl in zip(figs, images, labels):f.imshow(img.view((28, 28)).numpy())f.set_title(lbl)f.axis('off')plt.show()def data_iter(self, batch_size=256):if sys.platform.startswith('win'):num_workers = 0else:num_workers = 4train_iter = torch.utils.data.DataLoader(self.mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(self.mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iterif __name__ == "__main__":start = time.time()test = ImageMnist()test.__init__()train_iter, test_iter = test.data_iter()for x, y in train_iter:continueprint("%.2f sec" % (time.time() - start))

致谢

特别感谢李沐、Aston Zhang等老师的这本《动手学深度学习》一书~

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。