本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
如果数据集中包含损坏的图像而无法读取等问题,需要采取一些策略来处理这些图像,否则在__getitem__()函数中会抛出异常。
# 自定义数据集对象
class DogCat(Dataset):
def __init__(self,root):
imgs = os.listdir(root)
self.imgs = [os.path.join(root,img) for img in imgs]
def __getitem__(self,index):
img_path = self.imgs[index]
label = 1 if 'dog' in img_path.split('/')[-1] else 0
pil_img = Image.open(img_path)
array = np.asarray(pil_img)
data = torch.tensor(array)
return data,label
def __len__(self):
return len(self.imgs)
# 创建数据集实例
dataset = DogCat('./data/dogcat_wrong/')
# 根据索引获取损坏的图像时会抛出错误
img,label = dataset[10]
img.shape,label
# 输出结果:
# UnidentifiedImageError: cannot identify image file './data/dogcat_wrong/dog.wrong.jpg'
对于图像损坏或数据集加载异常等情况,可以通过以下方式处理。
删除损坏的图像:
在数据集准备阶段,进行一次全面的数据清洗操作,检查每个图像文件是否可读、是否完整,以及是否符合预期的格式。
def check_images_in_directory(directory):
supported_formats = ('.jpg', '.jpeg', '.png', '.gif', '.bmp') # 支持的图像格式
damaged_images = [] # 用于存储损坏的图像文件路径
# 遍历指定目录下的所有文件
for filename in os.listdir(directory):
if filename.lower().endswith(supported_formats):
file_path = os.path.join(directory, filename)
try:
# 尝试打开图像文件
with Image.open(file_path) as img:
img.verify() # 验证图像文件的完整性
except (IOError, SyntaxError) as e:
# 如果无法打开或验证图像,则将其标记为损坏
damaged_images.append(file_path)
print(f"Damaged image detected: {file_path}")
return damaged_images
# 使用示例
directory_to_check = './data/dogcat_wrong/' # 图像文件夹路径
damaged = check_images_in_directory(directory_to_check)
# 输出结果:
# Damaged image detected: ./data/dogcat_wrong/dog.wrong.jpg
读取时过滤掉损坏的图像:
# 自定义数据集对象——继承前面实现的DogCat数据集
class NewDogCat(DogCat):
def __getitem__(self,index):
try:
# 调用父类的获取函数,即DogCat.__getitem__(self,index)
return super().__getitem__(index)
except:
# # 直接返回None
# return None, None
# 随机选择一张图片替代损坏的图像
new_index = random.randint(0, len(self) - 1)
# 确保新索引与原始索引不同
if new_index == index:
new_index = (index + 1) % len(self)
return self[new_index]
# 创建数据集实例
dataset = NewDogCat('./data/dogcat_wrong/')
# 根据索引获取损坏的图像时会随机选择一张图片
img,label = dataset[10]
img.shape,label
# 输出结果:(torch.Size([374, 500, 3]), 1)
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师