首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】26.Dataset类

本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





01

torchvision.datasets模块:


Pytorch中,torchvision.datasets是一个非常重要的模块,它提供了一系列预定义的数据集,MNISTCIFAR10CIFAR100ImageNetCOCO 等,这些数据集可以直接用于训练和测试机器学习模型。

torchvision.datasets提供了一些加载数据的函数及常用的数据集接口,方便用户快速加载数据集。下面以下载CIFAR10数据集为例,其中:

  • 第一个参数:数据集的下载位置;
  • 第二个参数:指定下载训练集还是验证集;
  • 第三个参数:如果在第一个参数指定的位置找不到数据,是否允许PyTorch下载数据。

from torchvision import datasets
cifar10 = datasets.CIFAR10('./data/' , train=True , download=True)
通过__mro__ 属性,获取方法解析顺序,可以看到以上数据集都是作为torch.utils.data.Dataset的子类返回的。

在Python中,__mro__ 是“Method Resolution Order”的缩写,指的是方法解析顺序。当你调用一个对象的方法时,Python会按照一定的顺序在类的继承链中查找这个方法,这个顺序就是通过 __mro__ 来确定的。

每个类都有一个 __mro__ 属性,它是一个包含类及其所有超类的元组,按照特定的顺序排列。这个顺序是根据C3线性化算法来计算的,这是一种确保所有类都能被正确地、无歧义地解析其方法的算法。

# 通过__mro__ 属性,获取方法解析顺序
type(cifar10).__mro__

# 输出结果
# (torchvision.datasets.cifar.CIFAR10,
#  torchvision.datasets.vision.VisionDataset,
#  torch.utils.data.dataset.Dataset,
#  typing.Generic,
#  object)
torch.utils.data.Dataset是一个需要实现2种函数的对象:
  • __len__():返回数据中的样本总数;

  • __getitem__()根据给定的索引 index 返回一个样本。这个样本通常是一个包含数据和标签的元组 (data, label)

# __len__()
len(cifar10) # 相当于调用cifar10.__len__()
# 输出结果:50000

# __getitem__()
img , label = cifar10[9] # 相当于调用cifar10.__getitem__(9)

class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']

plt.figure(figsize=(3,3))
plt.axis('off')
plt.imshow(img)
plt.title(class_names[label])

程序输出如下图所示:

02

Dataset类:


Dataset 类是 PyTorch 中的一个核心抽象类,用于表示数据集,并提供了一种统一的方式来处理数据。
Dataset 类本身并不包含数据,而是一个框架,指导你如何组织和访问数据。为了创建一个可用的数据集,需要继承 Dataset 类并重写__len__()__getitem__()两个方法。
通过继承 Dataset 类并实现其特定的方法,用户可以自定义数据集,使其能够与 PyTorch 的数据加载工具(如 DataLoader)无缝集成,以方便地进行批量加载、打乱和并行加载数据。
下面是一个简单的例子,展示了如何创建一个继承自 torch.utils.data.Dataset 的自定义数据集类:
  • 图像数据:
所有图像均存放在一个文件夹(dogcat)下,根据图像文件名可以获取图像对应的标签(cat/dog)。
dogcat
├── cat.1.jpg
├── cat.2.jpg
├── cat.3.jpg
├── cat.4.jpg
├── cat.5.jpg
├── dog.1.jpg
├── dog.2.jpg
├── dog.3.jpg
├── dog.4.jpg
└── dog.5.jpg
  • 自定义数据集类:
自定义的DogCat类需要继承torch.utils.data.Dataset,并且实现__len__ __getitem__ 方法。
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np

# 自定义数据集对象
class DogCat(Dataset):
    def __init__(self,root):
        imgs = os.listdir(root)
        # 指定所有图像的绝对路径,当调用__getitem__()时才会真正读取图像
        self.imgs = [os.path.join(root,img) for img in imgs]

    def __getitem__(self,index):
        # 根据索引idx返回数据和标签
        img_path = self.imgs[index]
        # 获取图像对应标签
        label = 1 if 'dog' in img_path.split('/')[-1else 0
        pil_img = Image.open(img_path)
        array = np.asarray(pil_img)
        data = torch.tensor(array)
        return data,label

    def __len__(self):
        # 返回数据集的样本总数
        return len(self.imgs)
  • 访问自定义数据集数据:
使用自定义的DogCat创建数据集实例即可访问数据集。
# 创建数据集实例  
dataset = DogCat('./data/dogcat/')

# 获取数据集的样本总数
len(dataset)
# 输出结果:10

# 根据索引获取图像数据和对应标签
img,label = dataset[0]
img.shape,label
# 输出结果:(torch.Size([280, 300, 3]), 0)

# 遍历自定义的数据集
for img,label in dataset:
    print(img.shape ,img.float().mean() ,label)
# 输出结果:
# torch.Size([280, 300, 3]) tensor(71.6653) 0
# torch.Size([396, 312, 3]) tensor(131.8400) 0
# torch.Size([414, 500, 3]) tensor(156.6921) 0
# torch.Size([375, 499, 3]) tensor(96.8243) 0
# torch.Size([425, 320, 3]) tensor(158.8270) 0
# torch.Size([119, 158, 3]) tensor(90.2523) 1
# torch.Size([481, 500, 3]) tensor(144.1063) 1
# torch.Size([374, 500, 3]) tensor(119.4067) 1
# torch.Size([272, 265, 3]) tensor(100.7006) 1
# torch.Size([400, 300, 3]) tensor(128.1550) 1

03

ImageFolder:


ImageFolderPyTorch中一个用于处理图像数据集的类,它提供了一种方便的方式来加载和处理具有特定文件结构的图像数据集。
其构造函数如下:
ImageFolder(root, 
            transform=None
            target_transform=None
            loader=default_loader,
            is_valid_file=None)
其中:
  • root:在root指定的路径下寻找图片。ImageFolder会在该路径下的每个子文件夹中查找图像文件,并假设每个子文件夹代表一个类别。
  • transform:对PIL Image进行数据增强的操作(例如,裁剪、缩放或归一化等预处理操作),会应用到从数据集中加载的每张图片上。如果未指定,则不进行任何转换。
  • target_transform:对标签(label)进行转换。类似于transform,但是这个转换是应用到标签上的,而不是图像数据。如果未指定,则不进行任何转换。
  • loader:指定加载图片的函数。默认情况下,它会加载为RGB格式的PIL Image对象。如果需要加载不同格式的图片或者以不同的方式加载图片,可以提供自定义的加载函数。
  • is_valid_file:获取图像路径,检查文件的有效性。
  • 数据组织:

在使用ImageFolder时,需要将数据集组织成特定的结构。ImageFolder假设所有图像文件是按文件夹保存的,且文件夹名为类名。即每个类别的图像应该放在单独的子文件夹中,而这些子文件夹都位于同一个根目录下。例如,如果有猫和狗两个类别,那么文件结构应该如下,其中,dogcat_2是数据集的根目录,catdog是不同类别的子文件夹

└── dogcat_2
    ├── cat
    │   ├── cat.1.jpg
    │   ├── cat.2.jpg
    │   ├── cat.3.jpg
    │   ├── cat.4.jpg
    │   └── cat.5.jpg
    └── dog
        ├── dog.1.jpg
        ├── dog.2.jpg
        ├── dog.3.jpg
        ├── dog.4.jpg
        └── dog.5.jpg
  • 创建ImageFolder实例:

在创建了适当的数据结构后,就可以创建ImageFolder实例了。

from torchvision.datasets import ImageFolder

# 创建ImageFolder实例
dataset = ImageFolder('data/dogcat_2/')
  • 访问数据集:

创建ImageFolder实例后,可以通过索引来访问数据集中的图像和对应的标签。

image, label = dataset[4]  # 获取第4个样本的图像和标签

len(dataset)  # 获取数据集中的样本总数
  • imgs和class_to_idx属性:

  • imgs 属性:
当创建一个 ImageFolder 实例时,这个实例会自动扫描指定目录下的所有图像,并根据子文件夹的名称分配标签。然后,它将所有图像的路径和对应的标签存储在 dataset.imgs 属性中。
因此,dataset.imgs 将返回一个包含所有图像路径和对应标签的列表。每个元素是一个元组,其中第一个元素是图像的路径,第二个元素是该图像的类别标签(通常是整数)。
# 获取所有图像文件的路径及对应的label
dataset.imgs

#
 输出结果:
# [('data/dogcat_2/cat\cat.1.jpg', 0),
#  ('data/dogcat_2/cat\cat.2.jpg', 0),
#  ('data/dogcat_2/cat\cat.3.jpg', 0),
#  ('data/dogcat_2/cat\cat.4.jpg', 0),
#  ('data/dogcat_2/cat\cat.5.jpg', 0),
#  ('data/dogcat_2/dog\dog.1.jpg', 1),
#  ('data/dogcat_2/dog\dog.2.jpg', 1),
#  ('data/dogcat_2/dog\dog.3.jpg', 1),
#  ('data/dogcat_2/dog\dog.4.jpg', 1),
#  ('data/dogcat_2/dog\dog.5.jpg', 1)]
  • class_to_idx性:

在生成数据的label时,首先会按照文件夹名进行排序,然后将文件夹名保存为字典形式({类名:类序号})。

例如,假设有一个图像数据集,其中包含三个类别的图像:猫、狗和鸟。这些类别的图像分别存储在名为 “cat”, “dog”, 和 “bird” 的文件夹中。当使用ImageFolder加载数据集时,“cat” 类别被映射为索引 0,“dog” 类别被映射为索引 1,“bird” 类别被映射为索引 2。这样,当模型进行训练和预测时,它会使用这些索引值而不是类别的名称。

{  
    'cat': 0,  
    'dog': 1,  
    'bird': 2  
}

因此,一般来说,最好将文件夹命名为从0开始的数字,这样就可以和ImageFolder实际的label保持一致。如果没有遵循这种命名方式,建议通过class_to_idx属性来了解label和文件夹名的映射关系。

# 通过self.class_to_idx属性获取label和文件夹名的映射关系
dataset.class_to_idx
# 输出结果:{'cat': 0, 'dog': 1}


更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments