首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】29.DataLoader

本系列文章配套代码获取有以下两种途径:
  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





01

DataLoader


Pytorch中,Dataset用于表示数据集,并提供了一种方式来访问数据集中的数据。当实现一个自定义的 Dataset 类时,通常需要重写 __len__() __getitem__() 这两个方法。即,Dataset 只负责数据的抽象和单个样本的访问,每次调用 __getitem__() 时,它只会返回一个样本。

在深度学习中,我们通常不会一个样本一个样本地处理数据,而是会批量处理数据,这样可以提高计算效率。同时,为了增加模型的泛化能力和避免过拟合,我们经常会在每个epoch开始时对数据进行随机打乱(shuffle)。

考虑到上述需求,PyTorch提供了DataLoader这个工具。DataLoader是一个可以加载数据集、对数据进行批处理、打乱以及使用多线程进行数据加载的迭代器。其构造函数如下:
DataLoader(dataset, batch_size=1, shuffle=None, sampler=None
           batch_sampler=None, num_workers=0, collate_fn=None
           pin_memory=False, drop_last=False, timeout=0
           worker_init_fn=None, multiprocessing_context=None,
           generator=None, *, prefetch_factor=None
           persistent_workers=False, pin_memory_device='')

其中主要参数含义如下:

  • dataset加载的数据集,通常是Dataset类的实例。
  • batch_size每个批次的样本数量。
  • shuffle是否对数据进行洗牌操作,即打乱操作。
  • sampler抽取样本的采样器。如果提供了samplershuffle参数必须设置为False默认情况下,如果不指定sampler,并且shuffleTrueDataLoader 会随机打乱数据。通过自定义sampler,可以实现不同的数据抽样策略,例如按类别抽样、过采样/欠采样等。
  • num_workers加载数据时的并发线程或进程数,0代表不使用多线程。
  • collate_fn:如何将多个样本拼接成一个batch。可以根据需要进行自定义,以实现复杂的数据处理逻辑。
  • pin_memory当设置为True时,数据将被加载到固定的内存区域,这样数据可以从CPU更快地转移到GPU
  • drop_last当数据样本数量不能被批次大小整除时,是否丢弃最后一个不完整的批次。该参数通常用于确保每个批次都具有相同数量的样本。
  • timeout:进程读取数据的最大时间,若超时,则丢弃数据。
  • worker_init_fn:每个worker的初始化函数。该函数用于初始化每个工作进程的随机数生成器种子,以确保数据增强操作的可重复性。
  • multiprocessing_context:指定多进程的上下文语境。
  • generator:多进程使用的生成器。
  • prefetch_factor指定每个worker应该预先加载的数据批次数量。
  • persistent_workers设置为True时,worker进程在数据加载完毕后不会关闭,而是等待下一个数据加载任务。这可以减少因为频繁启动和关闭worker进程而产生的开销,但可能会增加内存使用。
  • pin_memory_device指定用于固定内存的设备。

以下是一个DataLoader的使用示例:

  • 创建自定义数据集:

首先使用ImageFolder生成的自定义数据集,然后实例化一个数据集对象

# 创建一个转换流水线 
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为224x224像素  
    transforms.CenterCrop(224), # 在图像中心裁剪出224x224的区域
    transforms.ToTensor() # 将PIL图像或NumPy ndarray转换为PyTorch张量
])

# 创建自定义数据集
dataset = ImageFolder('data/dogcat_2/' , transform=transform)
  • 实例化DataLoader:

将实例化的数据集对象传递给DataLoader
# 实例化DataLoader :将实例化的数据集对象传递给DataLoader
dataloader = DataLoader(dataset , 
                        batch_size=3 ,
                        shuffle=True , 
                        num_workers=0 , 
                        drop_last=False)
  • 使用DataLoader进行迭代:

DataLoader是一个可迭代(iterable)对象,可以像使用迭代器一样使用。
# 使用DataLoader进行迭代
for batch in dataloader:
    batch_datas , batch_labels = batch # 解包批次数据  
    print(batch_datas.size() , batch_labels.size())
# 输出结果:
# torch.Size([3, 3, 224, 224]) torch.Size([3])
# torch.Size([3, 3, 224, 224]) torch.Size([3])
# torch.Size([3, 3, 224, 224]) torch.Size([3])
# torch.Size([1, 3, 224, 224]) torch.Size([1])
从上面的输出中可以看出,最后一个batchbatch_size为1,这是因为自定义数据集中共有10张图像,无法整除以3(batch_size)。
也可以使用以下方式进行迭代:
# 使用DataLoader进行迭代
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, width
# 输出结果:torch.Size([3, 3, 224, 224])

02

数据集中存在损坏图


在之前文章中,介绍了当数据集中存在损坏图像时,可以__getitem__()函数中加入错误处理机制。例如,捕获加载图像时抛出的异常,并返回None对象或者随机选择一张图片替代损坏的图像。

【深度学习(PyTorch篇)】27.自定义数据集中存在损坏图像

将实例化的数据集对象传递给DataLoader时,如果随机选择一张图片替代损坏的图像,则不会影响每个batch的样本数。如果返回None象,则需要DataLoader中实现自定义的collate_fn,从而将为None的数据过滤掉。

# 自定义数据集对象
class DogCat(Dataset):
    def __init__(self,root,transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root,img) for img in imgs]
        self.transforms = transforms

    def __getitem__(self,index):
        img_path = self.imgs[index]
        label = 1 if 'dog' in img_path.split('/')[-1else 0
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data,label

    def __len__(self):
        return len(self.imgs)

# 自定义数据集对象——继承前面实现的DogCat数据集
class NewDogCat(DogCat):
    def __getitem__(self,index):
        try:
            # 调用父类的获取函数,即DogCat.__getitem__(self,index)
            return super().__getitem__(index)
        except:
            # 直接返回None
            return NoneNone

# 创建数据集实例  
dataset = NewDogCat('./data/dogcat_wrong/' , transforms=transform)

# 导入默认的拼接方式
from torch.utils.data.dataloader import default_collate

# 实现自定义的collate_fn
def my_collate_fn(batch):
    # batch是一个list,每个元素都是Dataset的返回值,形如(data ,label)
    # 过滤为None的数据
    batch = [_ for _ in batch if _[0is not None]
    if len(batch) ==0:
        return torch.Tensor() # 返回一个空张量
    return default_collate(batch) # 用默认方式拼接过滤后的batch数据

# 实例化DataLoader :将实例化的数据集对象传递给DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,shuffle=True , num_workers=0 , drop_last=False , collate_fn=my_collate_fn)

# 使用DataLoader进行迭代
for batch in dataloader:
    batch_datas , batch_labels = batch # 解包批次数据  
    print(batch_datas.size() , batch_labels.size())

返回结果如下所示:

torch.Size([3, 3, 224, 224]torch.Size([3])
torch.Size([2, 3, 224, 224]torch.Size([2])
torch.Size([3, 3, 224, 224]torch.Size([3])
torch.Size([2, 3, 224, 224]torch.Size([2])

从上面的输出中可以看出,第2个batchbatch_size为2,这是因为有一张图像损坏被过滤掉了。第4个batchbatch_size为2,这是因为自定义数据集中共有11张图像(包含损坏图像),无法整除以3(batch_size)。



更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments