本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
torchvision.datasets模块:
在Pytorch中,torchvision.datasets是一个非常重要的模块,它提供了一系列预定义的数据集,如 MNIST、CIFAR10、CIFAR100、ImageNet、COCO 等,这些数据集可以直接用于训练和测试机器学习模型。
torchvision.datasets提供了一些加载数据的函数及常用的数据集接口,方便用户快速加载数据集。下面以下载CIFAR10数据集为例,其中:
-
第一个参数:数据集的下载位置; -
第二个参数:指定下载训练集还是验证集; -
第三个参数:如果在第一个参数指定的位置找不到数据,是否允许PyTorch下载数据。
from torchvision import datasets
cifar10 = datasets.CIFAR10('./data/' , train=True , download=True)
在Python中,__mro__ 是“Method Resolution Order”的缩写,指的是方法解析顺序。当你调用一个对象的方法时,Python会按照一定的顺序在类的继承链中查找这个方法,这个顺序就是通过 __mro__ 来确定的。
每个类都有一个 __mro__ 属性,它是一个包含类及其所有超类的元组,按照特定的顺序排列。这个顺序是根据C3线性化算法来计算的,这是一种确保所有类都能被正确地、无歧义地解析其方法的算法。
# 通过__mro__ 属性,获取方法解析顺序
type(cifar10).__mro__
# 输出结果
# (torchvision.datasets.cifar.CIFAR10,
# torchvision.datasets.vision.VisionDataset,
# torch.utils.data.dataset.Dataset,
# typing.Generic,
# object)
-
__len__():返回数据中的样本总数;
-
__getitem__():根据给定的索引 index 返回一个样本。这个样本通常是一个包含数据和标签的元组 (data, label)。
# __len__()
len(cifar10) # 相当于调用cifar10.__len__()
# 输出结果:50000
# __getitem__()
img , label = cifar10[9] # 相当于调用cifar10.__getitem__(9)
class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
plt.figure(figsize=(3,3))
plt.axis('off')
plt.imshow(img)
plt.title(class_names[label])
程序输出如下图所示:
Dataset类:
-
图像数据:
dogcat
├── cat.1.jpg
├── cat.2.jpg
├── cat.3.jpg
├── cat.4.jpg
├── cat.5.jpg
├── dog.1.jpg
├── dog.2.jpg
├── dog.3.jpg
├── dog.4.jpg
└── dog.5.jpg
-
自定义数据集类:
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
# 自定义数据集对象
class DogCat(Dataset):
def __init__(self,root):
imgs = os.listdir(root)
# 指定所有图像的绝对路径,当调用__getitem__()时才会真正读取图像
self.imgs = [os.path.join(root,img) for img in imgs]
def __getitem__(self,index):
# 根据索引idx返回数据和标签
img_path = self.imgs[index]
# 获取图像对应标签
label = 1 if 'dog' in img_path.split('/')[-1] else 0
pil_img = Image.open(img_path)
array = np.asarray(pil_img)
data = torch.tensor(array)
return data,label
def __len__(self):
# 返回数据集的样本总数
return len(self.imgs)
-
访问自定义数据集数据:
# 创建数据集实例
dataset = DogCat('./data/dogcat/')
# 获取数据集的样本总数
len(dataset)
# 输出结果:10
# 根据索引获取图像数据和对应标签
img,label = dataset[0]
img.shape,label
# 输出结果:(torch.Size([280, 300, 3]), 0)
# 遍历自定义的数据集
for img,label in dataset:
print(img.shape ,img.float().mean() ,label)
# 输出结果:
# torch.Size([280, 300, 3]) tensor(71.6653) 0
# torch.Size([396, 312, 3]) tensor(131.8400) 0
# torch.Size([414, 500, 3]) tensor(156.6921) 0
# torch.Size([375, 499, 3]) tensor(96.8243) 0
# torch.Size([425, 320, 3]) tensor(158.8270) 0
# torch.Size([119, 158, 3]) tensor(90.2523) 1
# torch.Size([481, 500, 3]) tensor(144.1063) 1
# torch.Size([374, 500, 3]) tensor(119.4067) 1
# torch.Size([272, 265, 3]) tensor(100.7006) 1
# torch.Size([400, 300, 3]) tensor(128.1550) 1
ImageFolder:
ImageFolder(root,
transform=None,
target_transform=None,
loader=default_loader,
is_valid_file=None)
-
root:在root指定的路径下寻找图片。ImageFolder会在该路径下的每个子文件夹中查找图像文件,并假设每个子文件夹代表一个类别。 -
transform:对PIL Image进行数据增强的操作(例如,裁剪、缩放或归一化等预处理操作),会应用到从数据集中加载的每张图片上。如果未指定,则不进行任何转换。 -
target_transform:对标签(label)进行转换。类似于transform,但是这个转换是应用到标签上的,而不是图像数据。如果未指定,则不进行任何转换。 -
loader:指定加载图片的函数。默认情况下,它会加载为RGB格式的PIL Image对象。如果需要加载不同格式的图片或者以不同的方式加载图片,可以提供自定义的加载函数。 -
is_valid_file:获取图像路径,检查文件的有效性。
-
数据组织:
在使用ImageFolder时,需要将数据集组织成特定的结构。ImageFolder假设所有图像文件是按文件夹保存的,且文件夹名为类名。即每个类别的图像应该放在单独的子文件夹中,而这些子文件夹都位于同一个根目录下。例如,如果有猫和狗两个类别,那么文件结构应该如下,其中,dogcat_2是数据集的根目录,cat和dog是不同类别的子文件夹:
└── dogcat_2
├── cat
│ ├── cat.1.jpg
│ ├── cat.2.jpg
│ ├── cat.3.jpg
│ ├── cat.4.jpg
│ └── cat.5.jpg
└── dog
├── dog.1.jpg
├── dog.2.jpg
├── dog.3.jpg
├── dog.4.jpg
└── dog.5.jpg
-
创建ImageFolder实例:
在创建了适当的数据结构后,就可以创建ImageFolder实例了。
from torchvision.datasets import ImageFolder
# 创建ImageFolder实例
dataset = ImageFolder('data/dogcat_2/')
-
访问数据集:
创建ImageFolder实例后,可以通过索引来访问数据集中的图像和对应的标签。
image, label = dataset[4] # 获取第4个样本的图像和标签
len(dataset) # 获取数据集中的样本总数
-
imgs和class_to_idx属性:
-
imgs 属性:
# 获取所有图像文件的路径及对应的label
dataset.imgs
# 输出结果:
# [('data/dogcat_2/cat\cat.1.jpg', 0),
# ('data/dogcat_2/cat\cat.2.jpg', 0),
# ('data/dogcat_2/cat\cat.3.jpg', 0),
# ('data/dogcat_2/cat\cat.4.jpg', 0),
# ('data/dogcat_2/cat\cat.5.jpg', 0),
# ('data/dogcat_2/dog\dog.1.jpg', 1),
# ('data/dogcat_2/dog\dog.2.jpg', 1),
# ('data/dogcat_2/dog\dog.3.jpg', 1),
# ('data/dogcat_2/dog\dog.4.jpg', 1),
# ('data/dogcat_2/dog\dog.5.jpg', 1)]
-
class_to_idx属性:
在生成数据的label时,首先会按照文件夹名进行排序,然后将文件夹名保存为字典形式({类名:类序号})。
例如,假设有一个图像数据集,其中包含三个类别的图像:猫、狗和鸟。这些类别的图像分别存储在名为 “cat”, “dog”, 和 “bird” 的文件夹中。当使用ImageFolder加载数据集时,“cat” 类别被映射为索引 0,“dog” 类别被映射为索引 1,“bird” 类别被映射为索引 2。这样,当模型进行训练和预测时,它会使用这些索引值而不是类别的名称。
{
'cat': 0,
'dog': 1,
'bird': 2
}
因此,一般来说,最好将文件夹命名为从0开始的数字,这样就可以和ImageFolder实际的label保持一致。如果没有遵循这种命名方式,建议通过class_to_idx属性来了解label和文件夹名的映射关系。
# 通过self.class_to_idx属性获取label和文件夹名的映射关系
dataset.class_to_idx
# 输出结果:{'cat': 0, 'dog': 1}
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师