首页人工智能Pytorch【深度学习(PyTorch...

【深度学习(PyTorch篇)】41.采样函数grid_sample()

本系列文章配套代码获取有以下两种途径:

  • 通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj 提取码:mnsj
  • 前往GitHub获取
https://github.com/returu/PyTorch





01

grid_sample()


torch.nn.functional用于对输入图像进行空间变换的grid_sample()函数也是一个常用函数,该函数通常用于实现一些高级的图像处理技术(仿射变换),比如图像扭曲、图像配准、图像拼接等。
仿射变换不仅可用于数据预处理和数据增强,还直接体现在某些神经网络层的设计中。
grid_sample()函数的基本形式如下:
torch.nn.functional.grid_sample(input, 
                                grid, 
                                mode='bilinear'
                                padding_mode='zeros'
                                align_corners=False)

参数说明:

  • input:输入张量,形状为 (N, C, H_in, W_in),其中 N 是批量大小,C 是通道数,H_in W_in 分别是输入图像的高和宽。
  • grid采样网格,通常由affine_grid()函数生成。对于 2D 图像,grid 的形状应为 (N, H_out, W_out, 2),其中 N batch 的大小,H_outW_out 分别是输出图像的高度和宽度。最后的 2 表示每个网格点的 (x, y) 坐标。其中,坐标系统是基于输入图像或特征图的,其范围是 [-1, 1]。这意味着 (x, y) (x, y, z) 坐标在这个范围内。例如,对于 2D 图像,坐标 (-1, -1) 对应于输入图像的左上角,而 (1, 1) 对应于右下角。
  • mode:插值模式,可以是 ‘bilinear’(双线性插值)或 ‘bicubic’(双三次插值)。默认为 ‘bilinear’
  • padding_mode:当坐标超出图像边界时的填充模式。可以是 ‘zeros’(默认值,用0填充)、‘border’(使用边界像素值填充)、‘reflection’(使用镜像填充)或 ‘wrap’(使用图像另一边的像素填充)。
  • align_corners:如果为 True,则输入和输出图像的四个角将被对齐,从而保持角点的像素值。默认为 False

该函数的返回值是一个形状为 (N, C, H_out, W_out) 的张量,表示经过空间变换后的图像。

02

仿射变换


要使用grid_sample()函数对输入图像进行空间变换,首先,需要根据所需的几何变换(如旋转、缩放、平移)计算一个2×3的仿射变换矩阵。
  • 仿射变换矩阵:
仿射变换矩阵是一个数学工具,用于描述二维或三维空间中的线性变换(如旋转、缩放、错切等)和平移变换。在二维空间中,仿射变换可以通过一个2×3的矩阵来表示,而在三维空间中,则使用一个3×4的矩阵。
二维空间中的仿射变换矩阵为例,该矩阵包含了平移、缩放、旋转和剪切等变换的参数。仿射变换矩阵的一般形式如下:

其中 (A, B) (C, D) 控制了图像的旋转和缩放,(Tx, Ty) 控制了图像的平移。
  • 仿射变换公式:
对于一个点 P(x, y) 在原始坐标系中,经过仿射变换后得到的新坐标 P'(x’, y’) 可以通过以下公式计算:
x' = A * x + B * y + Tx
y' = C * x + D * y + Ty

其中,

  • (x, y) 是原始坐标系中点的坐标。
  • (x’, y’) 是仿射变换后点的新坐标。
  • ABCD 是控制旋转、缩放和剪切的矩阵元素。
  • TxTy 是平移的量。

这两个公式描述了仿射变换对坐标点的影响。通过适当地设置矩阵元素和平移量,你可以实现各种类型的仿射变换,包括平移、旋转、缩放和剪切。

03

示例


使用grid_sample()函数对输入图像进行空间变换包括以下步骤:
  • 计算仿射变换矩阵:首先,需要根据所需的几何变换(如旋转、缩放、平移)计算一个2×3的仿射变换矩阵。
  • 生成采样网格:使用affine_grid()函数和仿射变换矩阵生成采样网格。这个网格包含了变换后每个像素应该从原图像的哪个位置采样的信息。
  • 应用grid_sample():将待处理图像和采样网格传入grid_sample()函数,它会根据网格中的信息对特征图进行重采样,从而实现变换。
  • 1、读取图片:

本次仍以经典的Lena图作为处理图像。

lena = Image.open('lena.png')

# 将其转换为PyTorch张量
to_tensor = transforms.ToTensor()
lena_tensor = to_tensor(lena).unsqueeze(0)

2、计算仿射变换矩阵:

本次将使用两个变换矩阵:
  • 绕原点旋转 𝜃 弧度的变换矩阵;

  • y轴方向上沿x轴剪切的变换矩阵。

# 旋转(rotate):绕原点旋转 𝜃 弧度的变换矩阵
# 定义一个旋转角度(逆时针旋转90度),将角度转换为弧度
angle = -90 * math.pi / 180
# 矩阵的第一行对应于x轴的变换,第二行对应于y轴的变换.这里只涉及旋转,所以平移部分为0
theta_1 = torch.tensor([[math.cos(angle) , math.sin(-angle) , 0] , 
                      [math.sin(-angle) , math.cos(angle) , 0]]
 , dtype=torch.float)


# 剪切(shear):在y轴方向上沿x轴剪切变换矩阵
theta_2 = torch.tensor([[1 , 0 , 0] , 
                      [0.5 , 1 , 0]]
 , dtype=torch.float)

3、生成采样网格:

使用F.affine_grid()函数根据仿射变换矩阵和输入图像的大小来创建一个仿射网格。

# 使用F.affine_grid函数根据仿射变换矩阵和输入图像的大小来创建一个仿射网格
# 这个网格定义了输出图像中每个像素点在输入图像中的对应位置
# grid形状为(N, H_out, W_out, 2)
grid_1 = F.affine_grid(theta_1.unsqueeze(0) , lena_tensor.size())

grid_2 = F.affine_grid(theta_2.unsqueeze(0) , lena_tensor.size())

4、应用grid_sample():

将待处理图像和采样网格传入grid_sample()函数,它会根据网格中的信息对特征图进行重采样,从而实现变换。

# 使用F.grid_sample函数根据前面创建的仿射网格对输入图像进行采样
out_1 = F.grid_sample(lena_tensor , grid=grid_1 , mode='bilinear')

out_2 = F.grid_sample(lena_tensor , grid=grid_2 , mode='bilinear')

5、可视化变换结果:

使用matplotlib分别可视化原图、旋转结果图像剪切结果图像。
# 显示旋转后的图像
to_pil = transforms.ToPILImage()

img_1 = to_pil(out_1.data.squeeze(0))
img_2 = to_pil(out_2.data.squeeze(0))

# 显示中文字符
plt.rcParams['font.family'] = ['SimHei'

# 创建一个一行三列的子图布局  
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(155))  

# 原图 
axs[0].imshow(lena)  
axs[0].set_title('Lena')  
axs[0].axis('off')  # 关闭坐标轴  

# 旋转结果图像
axs[1].imshow(img_1)  
axs[1].set_title('旋转(rotate)')  
axs[1].axis('off'

# 剪切结果图像
axs[2].imshow(img_2)  
axs[2].set_title('剪切(shear)')  
axs[2].axis('off'
可视化结果如下所示:



更多内容可以前往官网查看

https://pytorch.org/


本篇文章来源于微信公众号: 码农设计师

RELATED ARTICLES

欢迎留下您的宝贵建议

Please enter your comment!
Please enter your name here

- Advertisment -

Most Popular

Recent Comments