首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】30.DataLoader中的采样器Sampler

本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





在上一篇文章中,介绍过DataLoader构造函数中的sampler参数用于控制数据加载时的采样策略。

PyTorch提供了几种默认的Sampler实现,如SequentialSampler(顺序采样)和RandomSampler(随机采样)。

需要注意的是,在DataLoader中,shuffle参数与Sampler是冲突的,两者只能指定一个。

  • 如果指定了shuffle=True且未指定Sampler,则DataLoader会使用默认的RandomSampler
  • 如果指定了shuffle=False且未指定Sampler,则DataLoader使用SequentialSampler
首先使用ImageFolder创建自定义数据集用于后续演示。文件结构如下所示,总共包含10张图像
└── dogcat_2
    ├── cat
    │   ├── cat.1.jpg
    │   ├── cat.2.jpg
    │   ├── cat.3.jpg
    │   ├── cat.4.jpg
    │   └── cat.5.jpg
    └── dog
        ├── dog.1.jpg
        ├── dog.2.jpg
        ├── dog.3.jpg
        ├── dog.4.jpg
        └── dog.5.jpg
代码如下所示:
# 创建一个转换流水线 
transform = transforms.Compose([
    transforms.Resize(224), # 将图像大小调整为224x224像素  
    transforms.CenterCrop(224), # 在图像中心裁剪出224x224的区域
    transforms.ToTensor() # 将PIL图像或NumPy ndarray转换为PyTorch张量
])

# 创建自定义数据集
dataset = ImageFolder('data/dogcat_2/' , transform=transform)

01

SequentialSampler(顺序采样)


在创建DataLoader时,如果指定shuffle=False且不提供自定义的Sampler,那么默认就会使用SequentialSampler也可以显式地创建一个SequentialSampler实例并将其传递给DataLoader

from torch.utils.data.sampler import SequentialSampler

# 创建 SequentialSampler 实例
sampler = SequentialSampler(dataset)

# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)

# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
    print(batch_labels.tolist())
输出结果如下:
[000]
[001]
[111]
[1]
通过输出结果,可以看出SequentialSampler会按照数据集中的顺序依次采样数据。每次迭代时,它会返回下一个索引,直到遍历完整个数据集。
02

RandomSampler(随机采样)


在创建DataLoader时,如果指定shuffle=True且不提供自定义的Sampler,那么默认会使用RandomSampler。也可以显式地创建一个RandomSampler实例并将其传递给DataLoader

from torch.utils.data.sampler import RandomSampler

# 创建 SequentialSampler 实例
sampler = RandomSampler(dataset)

# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)

# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
    print(batch_labels.tolist())

输出结果如下:

[011]
[010]
[001]
[1]

通过输出结果,可以看出RandomSampler会随机地从数据集中抽取样本,每次迭代返回的索引是随机的,从而实现了数据的随机化加载。

03

WeightedRandomSampler(权重采样)


当数据集存在类别不平衡问题时,或者某些样本比其他样本更重要时,可以使用WeightedRandomSampler。通过为不同的样本分配不同的权重,可以调整模型在训练过程中对各类样本的关注度。

要使用WeightedRandomSampler需要显式地创建一个WeightedRandomSampler实例,并将其传递给DataLoader

  • 计算样本权重:

根据数据集中每个样本所属类别的数量,可以计算出每个样本的权重。常见的计算方法包括使用倒数、平衡因子等。
  • 创建WeightedRandomSampler对象:

使用torch.utils.data.WeightedRandomSampler类创建一个采样器对象。传入计算好的样本权重。
  • 创建数据加载器:

将采样器对象作为参数传入torch.utils.data.DataLoader类,用于创建数据加载器。数据加载器会根据采样器对象的权重进行样本选择。
WeightedRandomSampler包含以下参数:
  • weights:用于采样的权重序列,权重与样本一一对应,表示每个样本被采样的相对概率。
  • num_samples:采样的数量,即最终希望从采样器中获取的样本数量。
  • replacement:是否可放回采样。如果为True,则样本可以被重复采样;如果为False,则每个样本在一次采样过程中只能被选择一次。
  • generator:用于采样的生成器,通常不需要用户自行设置。

本次,数据集只包含dog、cat两类,因此本次假设dog的取出概率cat的两倍:

# 计算样本权重
weights = [2 if label==1 else 1 for data,label in dataset]

from torch.utils.data.sampler import WeightedRandomSampler

# 创建 SequentialSampler 实例
sampler = WeightedRandomSampler(weights , num_samples=12 , replacement=True)

# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)

# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
    print(batch_labels.tolist())

输出结果如下:

[110]
[111]
[000]
[111]

通过输出结果,可以看出dog、cat的样本比例约为2:1。另外,虽然数据集中只有10张图像,但是最终返回了12个样本,这是因为设置了replacement=True,允许有样本被重复采样。

如果,replacement=False,那么当某一类的样本被全部选取完,但是样本数仍未达到设置的num_samples值时,采样器不会再从该类中选取图像,此时会报错(RuntimeError)。

replacement=False情况下,num_samples的值为样本总数时,为了不重复选取图像,Sampler将返回每一个样本,此时weights参数不在生效。

from torch.utils.data.sampler import WeightedRandomSampler

# 创建 SequentialSampler 实例,num_samples=10
sampler = WeightedRandomSampler(weights , num_samples=10 , replacement=False)

# 实例化DataLoader
dataloader = DataLoader(dataset , batch_size=3 ,sampler=sampler)

# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
    print(batch_labels.tolist())

输出结果如下:

[101]
[010]
[110]
[0]

04

自定义Sampler


Sampler本质上是一个迭代器基类,通过提供一个__iter__方法来产生迭代索引值,以及一个__len__方法来返回迭代器的长度。

用户可以根据需要自定义Sampler,只需实现__iter__()__len__()方法即可。自定义的Sampler可以与DataLoader一起使用,以实现对数据加载的精细控制。
在下面这个例子中,我们将创建一个简单的Sampler,它交替地从数据集中抽取样本,模拟一种“交替采样”策略。
from torch.utils.data import Sampler 

# 自定义的交替采样器
class AlternatingSampler(Sampler):
    """
    首先定义了一个AlternatingSampler类,该类继承了Sampler基类。
    在__init__方法中,我们遍历整个数据集,将所有样本按照类别分组,并存储每个类别的样本索引。
    在__iter__方法中,我们交替地从每个类别中选取样本,从而实现交替采样的功能。
    "
""
    def __init__(self, data_source):
        self.data_source = data_source
        # 循环遍历数据源,为每个类别创建一个索引列表。
        # 这些列表保存在 self.class_indices 字典中,其中键是类别标签,值是属于该类别的样本索引列表。
        # {0: [0, 1, 2, 3, 4], 1: [5, 6, 7, 8, 9]}
        self.class_indices = {}
        for idx, (_, label) in enumerate(self.data_source):
            if label not in self.class_indices:
                self.class_indices[label] = []
            self.class_indices[label].append(idx)

    def __iter__(self):
        # 初始化一个空列表,用于存储交替采样的索引
        indices = []
        # 获取每个类别的索引
        # 本次的数据源只包含dog、cat两个类别,标签分别为 0 和 1。
        class0_indices = self.class_indices[0]
        class1_indices = self.class_indices[1]

        # 循环交替地从 class0_indices 和 class1_indices 中取出索引,并添加到 indices 列表中。
        # 循环的次数是两个类别中样本数量较大值。
        max_len = max(len(class0_indices), len(class1_indices))
        for i in range(max_len):
            if i < len(class0_indices):
                indices.append(class0_indices[i])
            if i < len(class1_indices):
                indices.append(class1_indices[i])

        # 将 indices 列表转换为迭代器并返回。
        return iter(indices)

    def __len__(self):
        return len(self.data_source)

# 创建Sampler实例  
sampler = AlternatingSampler(dataset)

# 使用Sampler创建DataLoader  
dataloader = DataLoader(dataset, batch_size=3, sampler=sampler)  

# 遍历 DataLoader,并打印每个批次的label数据
for batch_datas , batch_labels in dataloader:
    print(batch_labels.tolist())
输出结果如下:
[010]
[101]
[010]
[1]
通过输出结果,可以看出采样器会交替地从每个类别中选取样本,从而实现交替采样的功能。

更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments