首页人工智能MXNet4.图像分类数据集Fash...

4.图像分类数据集Fashion-MNIST

大多数的框架入门实例都是采用手写数字识别数据集MNIST,不过因为大部分的模型在MNIST数据集上的分类精度都超过了95%,不能很好的反映算法之间的差异性,因为MXNet系列采用了一个更加复杂的Fashion-MNIST数据集。
Fashion-MNIST数据集一共包含10个类别,分别为t-shirt(T恤), trouser(裤子), pullover(套衫), dress(连衣裙), coat(外套),sandal(凉鞋), shirt(衬衫), sneaker(运动鞋), bag(包), ankle boot(短靴)。

import d2lzh as d2l
from mxnet.gluon import data as gdata
import sys

mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)


# print(len(mnist_train),len(mnist_test)) # (60000, 10000)
# feature,label = mnist_train[0]
# print(feature.shape,label) # ((28, 28, 1), 2)

# 将数值标签转换为文本标签
def get_fashion_mnist_labels(labels):
	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
	return [text_labels[int(i)] for i in labels]


# 可以在一行内画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
	d2l.use_svg_display()
	# _表示忽略(不使用)的变量
	_, figs = d2l.plt.subplots(1, len(images), figsize=(12, 12)) # 设置图的尺寸
	for f, img, lbl in zip(figs, images, labels):
		f.imshow(img.reshape(28, 28).asnumpy())
		f.set_title(lbl) # 标题显示为图像的文本标签
		f.axes.get_xaxis().set_visible(False) # 坐标轴不可见
		f.axes.get_yaxis().set_visible(False) # 坐标轴不可见
	d2l.plt.show()

# 查看训练样本中前10个样本的图像内容及文本标签
X, y = mnist_train[0:10]
show_fashion_mnist(X, get_fashion_mnist_labels(y))

Reference:
《动手学深度学习》

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments