大多数的框架入门实例都是采用手写数字识别数据集MNIST,不过因为大部分的模型在MNIST数据集上的分类精度都超过了95%,不能很好的反映算法之间的差异性,因为MXNet系列采用了一个更加复杂的Fashion-MNIST数据集。
Fashion-MNIST数据集一共包含10个类别,分别为t-shirt(T恤), trouser(裤子), pullover(套衫), dress(连衣裙), coat(外套),sandal(凉鞋), shirt(衬衫), sneaker(运动鞋), bag(包), ankle boot(短靴)。
import d2lzh as d2l
from mxnet.gluon import data as gdata
import sys
mnist_train = gdata.vision.FashionMNIST(train=True)
mnist_test = gdata.vision.FashionMNIST(train=False)
# print(len(mnist_train),len(mnist_test)) # (60000, 10000)
# feature,label = mnist_train[0]
# print(feature.shape,label) # ((28, 28, 1), 2)
# 将数值标签转换为文本标签
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
# 可以在一行内画出多张图像和对应标签的函数
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
# _表示忽略(不使用)的变量
_, figs = d2l.plt.subplots(1, len(images), figsize=(12, 12)) # 设置图的尺寸
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.reshape(28, 28).asnumpy())
f.set_title(lbl) # 标题显示为图像的文本标签
f.axes.get_xaxis().set_visible(False) # 坐标轴不可见
f.axes.get_yaxis().set_visible(False) # 坐标轴不可见
d2l.plt.show()
# 查看训练样本中前10个样本的图像内容及文本标签
X, y = mnist_train[0:10]
show_fashion_mnist(X, get_fashion_mnist_labels(y))
Reference:
《动手学深度学习》