本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
在上一篇文章中,介绍过DataLoader构造函数中的sampler参数用于控制数据加载时的采样策略。
PyTorch提供了几种默认的Sampler实现,如SequentialSampler(顺序采样)和RandomSampler(随机采样)。
需要注意的是,在DataLoader中,shuffle参数与Sampler是冲突的,两者只能指定一个。
-
如果指定了shuffle=True且未指定Sampler,则DataLoader会使用默认的RandomSampler。 -
如果指定了shuffle=False且未指定Sampler,则DataLoader会使用SequentialSampler。
└── dogcat_2
├── cat
│ ├── cat.1.jpg
│ ├── cat.2.jpg
│ ├── cat.3.jpg
│ ├── cat.4.jpg
│ └── cat.5.jpg
└── dog
├── dog.1.jpg
├── dog.2.jpg
├── dog.3.jpg
├── dog.4.jpg
└── dog.5.jpg
# 创建一个转换流水线
transform = transforms.Compose([
transforms.Resize(224), # 将图像大小调整为224x224像素
transforms.CenterCrop(224), # 在图像中心裁剪出224x224的区域
transforms.ToTensor() # 将PIL图像或NumPy ndarray转换为PyTorch张量
])
# 创建自定义数据集
dataset = ImageFolder('data/dogcat_2/' , transform=transform)
SequentialSampler(顺序采样):
在创建DataLoader时,如果指定shuffle=False且不提供自定义的Sampler,那么默认就会使用SequentialSampler。也可以显式地创建一个SequentialSampler实例并将其传递给DataLoader。
from torch.utils.data.sampler import SequentialSampler
# 创建 SequentialSampler 实例
sampler = SequentialSampler(dataset)
# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)
# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
print(batch_labels.tolist())
[0, 0, 0]
[0, 0, 1]
[1, 1, 1]
[1]
RandomSampler(随机采样):
在创建DataLoader时,如果指定shuffle=True且不提供自定义的Sampler,那么默认会使用RandomSampler。也可以显式地创建一个RandomSampler实例并将其传递给DataLoader。
from torch.utils.data.sampler import RandomSampler
# 创建 SequentialSampler 实例
sampler = RandomSampler(dataset)
# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)
# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
print(batch_labels.tolist())
输出结果如下:
[0, 1, 1]
[0, 1, 0]
[0, 0, 1]
[1]
通过输出结果,可以看出RandomSampler会随机地从数据集中抽取样本,每次迭代返回的索引是随机的,从而实现了数据的随机化加载。
WeightedRandomSampler(权重采样):
当数据集存在类别不平衡问题时,或者某些样本比其他样本更重要时,可以使用WeightedRandomSampler。通过为不同的样本分配不同的权重,可以调整模型在训练过程中对各类样本的关注度。
要使用WeightedRandomSampler,需要显式地创建一个WeightedRandomSampler实例,并将其传递给DataLoader:
-
计算样本权重:
-
创建WeightedRandomSampler对象:
-
创建数据加载器:
-
weights:用于采样的权重序列,权重与样本一一对应,表示每个样本被采样的相对概率。 -
num_samples:采样的数量,即最终希望从采样器中获取的样本数量。 -
replacement:是否可放回采样。如果为True,则样本可以被重复采样;如果为False,则每个样本在一次采样过程中只能被选择一次。 -
generator:用于采样的生成器,通常不需要用户自行设置。
本次,数据集只包含dog、cat两类,因此本次假设dog的取出概率为cat的两倍:
# 计算样本权重
weights = [2 if label==1 else 1 for data,label in dataset]
from torch.utils.data.sampler import WeightedRandomSampler
# 创建 SequentialSampler 实例
sampler = WeightedRandomSampler(weights , num_samples=12 , replacement=True)
# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)
# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
print(batch_labels.tolist())
输出结果如下:
[1, 1, 0]
[1, 1, 1]
[0, 0, 0]
[1, 1, 1]
通过输出结果,可以看出dog、cat的样本比例约为2:1。另外,虽然数据集中只有10张图像,但是最终返回了12个样本,这是因为设置了replacement=True,允许有样本被重复采样。
如果,replacement=False,那么当某一类的样本被全部选取完,但是样本数仍未达到设置的num_samples值时,采样器不会再从该类中选取图像,此时会报错(RuntimeError)。
在replacement=False情况下,num_samples的值为样本总数时,为了不重复选取图像,Sampler将返回每一个样本,此时weights参数不在生效。
from torch.utils.data.sampler import WeightedRandomSampler
# 创建 SequentialSampler 实例,num_samples=10
sampler = WeightedRandomSampler(weights , num_samples=10 , replacement=False)
# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)
# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
print(batch_labels.tolist())
输出结果如下:
[1, 0, 1]
[0, 1, 0]
[1, 1, 0]
[0]
自定义Sampler:
Sampler本质上是一个迭代器基类,通过提供一个__iter__方法来产生迭代索引值,以及一个__len__方法来返回迭代器的长度。
from torch.utils.data import Sampler
# 自定义的交替采样器
class AlternatingSampler(Sampler):
"""
首先定义了一个AlternatingSampler类,该类继承了Sampler基类。
在__init__方法中,我们遍历整个数据集,将所有样本按照类别分组,并存储每个类别的样本索引。
在__iter__方法中,我们交替地从每个类别中选取样本,从而实现交替采样的功能。
"""
def __init__(self, data_source):
self.data_source = data_source
# 循环遍历数据源,为每个类别创建一个索引列表。
# 这些列表保存在 self.class_indices 字典中,其中键是类别标签,值是属于该类别的样本索引列表。
# {0: [0, 1, 2, 3, 4], 1: [5, 6, 7, 8, 9]}
self.class_indices = {}
for idx, (_, label) in enumerate(self.data_source):
if label not in self.class_indices:
self.class_indices[label] = []
self.class_indices[label].append(idx)
def __iter__(self):
# 初始化一个空列表,用于存储交替采样的索引
indices = []
# 获取每个类别的索引
# 本次的数据源只包含dog、cat两个类别,标签分别为 0 和 1。
class0_indices = self.class_indices[0]
class1_indices = self.class_indices[1]
# 循环交替地从 class0_indices 和 class1_indices 中取出索引,并添加到 indices 列表中。
# 循环的次数是两个类别中样本数量较大值。
max_len = max(len(class0_indices), len(class1_indices))
for i in range(max_len):
if i < len(class0_indices):
indices.append(class0_indices[i])
if i < len(class1_indices):
indices.append(class1_indices[i])
# 将 indices 列表转换为迭代器并返回。
return iter(indices)
def __len__(self):
return len(self.data_source)
# 创建Sampler实例
sampler = AlternatingSampler(dataset)
# 使用Sampler创建DataLoader
dataloader = DataLoader(dataset, batch_size=3, sampler=sampler)
# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
print(batch_labels.tolist())
[0, 1, 0]
[1, 0, 1]
[0, 1, 0]
[1]
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师