-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
DataLoader:
在深度学习中,我们通常不会一个样本一个样本地处理数据,而是会批量处理数据,这样可以提高计算效率。同时,为了增加模型的泛化能力和避免过拟合,我们经常会在每个epoch开始时对数据进行随机打乱(shuffle)。
DataLoader(dataset, batch_size=1, shuffle=None, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None,
generator=None, *, prefetch_factor=None,
persistent_workers=False, pin_memory_device='')
其中主要参数含义如下:
-
dataset:加载的数据集,通常是Dataset类的实例。 -
batch_size:每个批次的样本数量。 -
shuffle:是否对数据进行洗牌操作,即打乱操作。 -
sampler:抽取样本的采样器。如果提供了sampler,shuffle参数必须设置为False。默认情况下,如果不指定sampler,并且shuffle为True,DataLoader 会随机打乱数据。通过自定义sampler,可以实现不同的数据抽样策略,例如按类别抽样、过采样/欠采样等。 -
num_workers:加载数据时的并发线程或进程数,0代表不使用多线程。 -
collate_fn:如何将多个样本拼接成一个batch。可以根据需要进行自定义,以实现复杂的数据处理逻辑。 -
pin_memory:当设置为True时,数据将被加载到固定的内存区域,这样数据可以从CPU更快地转移到GPU。 -
drop_last:当数据样本数量不能被批次大小整除时,是否丢弃最后一个不完整的批次。该参数通常用于确保每个批次都具有相同数量的样本。 -
timeout:进程读取数据的最大时间,若超时,则丢弃数据。 -
worker_init_fn:每个worker的初始化函数。该函数用于初始化每个工作进程的随机数生成器种子,以确保数据增强操作的可重复性。 -
multiprocessing_context:指定多进程的上下文语境。 -
generator:多进程使用的生成器。 -
prefetch_factor:指定每个worker应该预先加载的数据批次数量。 -
persistent_workers:设置为True时,worker进程在数据加载完毕后不会关闭,而是等待下一个数据加载任务。这可以减少因为频繁启动和关闭worker进程而产生的开销,但可能会增加内存使用。 -
pin_memory_device:指定用于固定内存的设备。
以下是一个DataLoader的使用示例:
-
创建自定义数据集:
首先使用ImageFolder生成的自定义数据集,然后实例化一个数据集对象。
# 创建一个转换流水线
transform = transforms.Compose([
transforms.Resize(224), # 将图像大小调整为224x224像素
transforms.CenterCrop(224), # 在图像中心裁剪出224x224的区域
transforms.ToTensor() # 将PIL图像或NumPy ndarray转换为PyTorch张量
])
# 创建自定义数据集
dataset = ImageFolder('data/dogcat_2/' , transform=transform)
-
实例化DataLoader:
# 实例化DataLoader :将实例化的数据集对象传递给DataLoader
dataloader = DataLoader(dataset ,
batch_size=3 ,
shuffle=True ,
num_workers=0 ,
drop_last=False)
-
使用DataLoader进行迭代:
# 使用DataLoader进行迭代
for batch in dataloader:
batch_datas , batch_labels = batch # 解包批次数据
print(batch_datas.size() , batch_labels.size())
# 输出结果:
# torch.Size([3, 3, 224, 224]) torch.Size([3])
# torch.Size([3, 3, 224, 224]) torch.Size([3])
# torch.Size([3, 3, 224, 224]) torch.Size([3])
# torch.Size([1, 3, 224, 224]) torch.Size([1])
# 使用DataLoader进行迭代
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, width
# 输出结果:torch.Size([3, 3, 224, 224])
数据集中存在损坏图像:
在之前文章中,介绍了当数据集中存在损坏图像时,可以在__getitem__()函数中加入错误处理机制。例如,捕获加载图像时抛出的异常,并返回None对象或者随机选择一张图片替代损坏的图像。
【深度学习(PyTorch篇)】27.自定义数据集中存在损坏图像
将实例化的数据集对象传递给DataLoader时,如果随机选择一张图片替代损坏的图像,则不会影响每个batch的样本数。如果返回None对象,则需要在DataLoader中实现自定义的collate_fn,从而将为None的数据过滤掉。
# 自定义数据集对象
class DogCat(Dataset):
def __init__(self,root,transforms=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root,img) for img in imgs]
self.transforms = transforms
def __getitem__(self,index):
img_path = self.imgs[index]
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
if self.transforms:
data = self.transforms(data)
return data,label
def __len__(self):
return len(self.imgs)
# 自定义数据集对象——继承前面实现的DogCat数据集
class NewDogCat(DogCat):
def __getitem__(self,index):
try:
# 调用父类的获取函数,即DogCat.__getitem__(self,index)
return super().__getitem__(index)
except:
# 直接返回None
return None, None
# 创建数据集实例
dataset = NewDogCat('./data/dogcat_wrong/' , transforms=transform)
# 导入默认的拼接方式
from torch.utils.data.dataloader import default_collate
# 实现自定义的collate_fn
def my_collate_fn(batch):
# batch是一个list,每个元素都是Dataset的返回值,形如(data ,label)
# 过滤为None的数据
batch = [_ for _ in batch if _[0] is not None]
if len(batch) ==0:
return torch.Tensor() # 返回一个空张量
return default_collate(batch) # 用默认方式拼接过滤后的batch数据
# 实例化DataLoader :将实例化的数据集对象传递给DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,shuffle=True , num_workers=0 , drop_last=False , collate_fn=my_collate_fn)
# 使用DataLoader进行迭代
for batch in dataloader:
batch_datas , batch_labels = batch # 解包批次数据
print(batch_datas.size() , batch_labels.size())
返回结果如下所示:
torch.Size([3, 3, 224, 224]) torch.Size([3])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([3, 3, 224, 224]) torch.Size([3])
torch.Size([2, 3, 224, 224]) torch.Size([2])
从上面的输出中可以看出,第2个batch的batch_size为2,这是因为有一张图像损坏被过滤掉了。第4个batch的batch_size也为2,这是因为自定义数据集中共有11张图像(包含损坏图像),无法整除以3(batch_size)。
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师