本系列文章配套代码获取有以下两种途径:
-
通过百度网盘获取:
链接:https://pan.baidu.com/s/1XuxKa9_G00NznvSK0cr5qw?pwd=mnsj
提取码:mnsj
-
前往GitHub获取:
https://github.com/returu/PyTorch
grid_sample():
torch.nn.functional.grid_sample(input,
grid,
mode='bilinear',
padding_mode='zeros',
align_corners=False)
参数说明:
-
input:输入张量,形状为 (N, C, H_in, W_in),其中 N 是批量大小,C 是通道数,H_in 和 W_in 分别是输入图像的高和宽。 -
grid:采样网格,通常由affine_grid()函数生成。对于 2D 图像,grid 的形状应为 (N, H_out, W_out, 2),其中 N 是 batch 的大小,H_out 和 W_out 分别是输出图像的高度和宽度。最后的 2 表示每个网格点的 (x, y) 坐标。其中,坐标系统是基于输入图像或特征图的,其范围是 [-1, 1]。这意味着 (x, y) 或 (x, y, z) 坐标在这个范围内。例如,对于 2D 图像,坐标 (-1, -1) 对应于输入图像的左上角,而 (1, 1) 对应于右下角。 -
mode:插值模式,可以是 ‘bilinear’(双线性插值)或 ‘bicubic’(双三次插值)。默认为 ‘bilinear’。 -
padding_mode:当坐标超出图像边界时的填充模式。可以是 ‘zeros’(默认值,用0填充)、‘border’(使用边界像素值填充)、‘reflection’(使用镜像填充)或 ‘wrap’(使用图像另一边的像素填充)。 -
align_corners:如果为 True,则输入和输出图像的四个角将被对齐,从而保持角点的像素值。默认为 False。
该函数的返回值是一个形状为 (N, C, H_out, W_out) 的张量,表示经过空间变换后的图像。
仿射变换:
-
仿射变换矩阵:
-
仿射变换公式:
x' = A * x + B * y + Tx
y' = C * x + D * y + Ty
其中,
-
(x, y) 是原始坐标系中点的坐标。 -
(x’, y’) 是仿射变换后点的新坐标。 -
A、B、C 和 D 是控制旋转、缩放和剪切的矩阵元素。 -
Tx 和 Ty 是平移的量。
这两个公式描述了仿射变换对坐标点的影响。通过适当地设置矩阵元素和平移量,你可以实现各种类型的仿射变换,包括平移、旋转、缩放和剪切。
示例:
-
计算仿射变换矩阵:首先,需要根据所需的几何变换(如旋转、缩放、平移)计算一个2×3的仿射变换矩阵。 -
生成采样网格:使用affine_grid()函数和仿射变换矩阵生成采样网格。这个网格包含了变换后每个像素应该从原图像的哪个位置采样的信息。 -
应用grid_sample():将待处理图像和采样网格传入grid_sample()函数,它会根据网格中的信息对特征图进行重采样,从而实现变换。
-
1、读取图片:
本次仍以经典的Lena图作为处理图像。
lena = Image.open('lena.png')
# 将其转换为PyTorch张量
to_tensor = transforms.ToTensor()
lena_tensor = to_tensor(lena).unsqueeze(0)
2、计算仿射变换矩阵:
-
绕原点旋转 𝜃 弧度的变换矩阵;
-
在y轴方向上沿x轴剪切的变换矩阵。
# 旋转(rotate):绕原点旋转 𝜃 弧度的变换矩阵
# 定义一个旋转角度(逆时针旋转90度),将角度转换为弧度
angle = -90 * math.pi / 180
# 矩阵的第一行对应于x轴的变换,第二行对应于y轴的变换.这里只涉及旋转,所以平移部分为0
theta_1 = torch.tensor([[math.cos(angle) , math.sin(-angle) , 0] ,
[math.sin(-angle) , math.cos(angle) , 0]] , dtype=torch.float)
# 剪切(shear):在y轴方向上沿x轴剪切变换矩阵
theta_2 = torch.tensor([[1 , 0 , 0] ,
[0.5 , 1 , 0]] , dtype=torch.float)
3、生成采样网格:
使用F.affine_grid()函数根据仿射变换矩阵和输入图像的大小来创建一个仿射网格。
# 使用F.affine_grid函数根据仿射变换矩阵和输入图像的大小来创建一个仿射网格
# 这个网格定义了输出图像中每个像素点在输入图像中的对应位置
# grid形状为(N, H_out, W_out, 2)
grid_1 = F.affine_grid(theta_1.unsqueeze(0) , lena_tensor.size())
grid_2 = F.affine_grid(theta_2.unsqueeze(0) , lena_tensor.size())
4、应用grid_sample():
将待处理图像和采样网格传入grid_sample()函数,它会根据网格中的信息对特征图进行重采样,从而实现变换。
# 使用F.grid_sample函数根据前面创建的仿射网格对输入图像进行采样
out_1 = F.grid_sample(lena_tensor , grid=grid_1 , mode='bilinear')
out_2 = F.grid_sample(lena_tensor , grid=grid_2 , mode='bilinear')
5、可视化变换结果:
# 显示旋转后的图像
to_pil = transforms.ToPILImage()
img_1 = to_pil(out_1.data.squeeze(0))
img_2 = to_pil(out_2.data.squeeze(0))
# 显示中文字符
plt.rcParams['font.family'] = ['SimHei']
# 创建一个一行三列的子图布局
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
# 原图
axs[0].imshow(lena)
axs[0].set_title('Lena')
axs[0].axis('off') # 关闭坐标轴
# 旋转结果图像
axs[1].imshow(img_1)
axs[1].set_title('旋转(rotate)')
axs[1].axis('off')
# 剪切结果图像
axs[2].imshow(img_2)
axs[2].set_title('剪切(shear)')
axs[2].axis('off')
更多内容可以前往官网查看:
https://pytorch.org/
本篇文章来源于微信公众号: 码农设计师