别再死记公式了!用Python+PyTorch亲手画图理解卷积的‘放大’与‘缩小’

张开发
2026/4/15 14:55:41 15 分钟阅读

分享文章

别再死记公式了!用Python+PyTorch亲手画图理解卷积的‘放大’与‘缩小’
别再死记公式了用PythonPyTorch亲手画图理解卷积的‘放大’与‘缩小’卷积神经网络CNN中的下采样和上采样概念常常让初学者感到困惑。与其死记硬背公式不如通过代码和可视化来直观理解这些操作的本质。本文将带你用Python和PyTorch亲手实现这些操作并通过动态可视化来观察特征图的变化过程。1. 准备工作搭建可视化环境在开始之前我们需要准备一个能够实时显示卷积操作效果的环境。推荐使用Jupyter Notebook配合Matplotlib进行交互式可视化。首先安装必要的库!pip install torch torchvision matplotlib numpy然后导入所需的模块import torch import torch.nn as nn import torch.nn.functional as F import matplotlib.pyplot as plt import numpy as np from matplotlib.animation import FuncAnimation为了更直观地观察卷积过程我们可以创建一个简单的可视化函数def visualize_convolution(input_tensor, kernel, output_tensor, title): fig, (ax1, ax2, ax3) plt.subplots(1, 3, figsize(15, 5)) ax1.imshow(input_tensor.squeeze(), cmapgray) ax1.set_title(Input) ax1.axis(off) ax2.imshow(kernel.squeeze(), cmapgray) ax2.set_title(Kernel) ax2.axis(off) ax3.imshow(output_tensor.squeeze(), cmapgray) ax3.set_title(Output) ax3.axis(off) plt.suptitle(title) plt.show()2. 下采样特征图的缩小过程下采样是CNN中常见的操作它通过卷积和池化等方式减小特征图的尺寸。让我们通过代码来观察这一过程。2.1 标准卷积的下采样效果首先创建一个简单的4×4输入矩阵input_data torch.tensor([ [1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16] ], dtypetorch.float32).unsqueeze(0).unsqueeze(0) # shape: [1, 1, 4, 4]定义一个3×3的卷积核步长(stride)为1conv nn.Conv2d(1, 1, kernel_size3, stride1, padding0, biasFalse) # 手动设置卷积核权重 with torch.no_grad(): conv.weight.data torch.ones_like(conv.weight.data)执行卷积操作并可视化output conv(input_data) visualize_convolution(input_data, conv.weight.data, output, Standard Convolution (stride1))观察输出结果你会发现特征图从4×4缩小到了2×2。这是因为在没有填充(padding0)的情况下3×3的卷积核在4×4的输入上只能滑动2×2次。2.2 增大步长的下采样效果现在让我们增大步长(stride)来观察更明显的下采样效果conv_stride2 nn.Conv2d(1, 1, kernel_size3, stride2, padding0, biasFalse) with torch.no_grad(): conv_stride2.weight.data torch.ones_like(conv_stride2.weight.data) output_stride2 conv_stride2(input_data) visualize_convolution(input_data, conv_stride2.weight.data, output_stride2, Convolution with stride2)这次输出变成了1×1的特征图。通过调整步长我们可以控制下采样的程度。3. 上采样特征图的放大过程上采样是下采样的逆过程常用于图像分割等需要输出与输入尺寸相同的任务中。PyTorch提供了几种上采样方法我们重点看看转置卷积。3.1 转置卷积的基本原理转置卷积Transposed Convolution常被误称为反卷积它实际上是一种特殊的正向卷积操作能够实现上采样。创建一个2×2的输入small_input torch.tensor([ [1, 2], [3, 4] ], dtypetorch.float32).unsqueeze(0).unsqueeze(0) # shape: [1, 1, 2, 2]定义转置卷积层trans_conv nn.ConvTranspose2d(1, 1, kernel_size3, stride1, padding0, biasFalse) with torch.no_grad(): trans_conv.weight.data torch.ones_like(trans_conv.weight.data)执行转置卷积并可视化output_trans trans_conv(small_input) visualize_convolution(small_input, trans_conv.weight.data, output_trans, Transposed Convolution (stride1))你会看到2×2的输入被放大到了4×4。这是因为转置卷积在输入元素之间插入了零值然后进行常规卷积操作。3.2 转置卷积的步长效应增大转置卷积的步长可以进一步放大特征图trans_conv_stride2 nn.ConvTranspose2d(1, 1, kernel_size3, stride2, padding0, biasFalse) with torch.no_grad(): trans_conv_stride2.weight.data torch.ones_like(trans_conv_stride2.weight.data) output_trans_stride2 trans_conv_stride2(small_input) visualize_convolution(small_input, trans_conv_stride2.weight.data, output_trans_stride2, Transposed Convolution (stride2))这次2×2的输入被放大到了5×5。理解转置卷积的关键在于认识到它实际上是在输入元素之间插入(stride-1)个零值然后进行常规卷积。4. 空洞卷积扩大感受野而不增加参数空洞卷积Dilated Convolution通过在卷积核元素之间插入空洞来扩大感受野同时不增加参数数量。4.1 基本空洞卷积实现创建一个7×7的输入large_input torch.tensor([ [1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14], [15, 16, 17, 18, 19, 20, 21], [22, 23, 24, 25, 26, 27, 28], [29, 30, 31, 32, 33, 34, 35], [36, 37, 38, 39, 40, 41, 42], [43, 44, 45, 46, 47, 48, 49] ], dtypetorch.float32).unsqueeze(0).unsqueeze(0)定义空洞卷积层空洞率2dilated_conv nn.Conv2d(1, 1, kernel_size3, stride1, padding2, dilation2, biasFalse) with torch.no_grad(): dilated_conv.weight.data torch.ones_like(dilated_conv.weight.data)执行空洞卷积并可视化output_dilated dilated_conv(large_input) visualize_convolution(large_input, dilated_conv.weight.data, output_dilated, Dilated Convolution (rate2))虽然使用了3×3的卷积核但由于空洞率为2实际感受野相当于5×5。观察输出结果你会发现中心像素受到了更广泛区域的影响。4.2 空洞卷积与标准卷积的对比为了更清楚地看到空洞卷积的效果我们可以创建一个动画来对比标准卷积和空洞卷积def create_comparison_animation(): fig, (ax1, ax2) plt.subplots(1, 2, figsize(12, 6)) # 标准卷积 standard_conv nn.Conv2d(1, 1, kernel_size3, stride1, padding1, biasFalse) with torch.no_grad(): standard_conv.weight.data torch.ones_like(standard_conv.weight.data) # 空洞卷积 dilated_conv nn.Conv2d(1, 1, kernel_size3, stride1, padding2, dilation2, biasFalse) with torch.no_grad(): dilated_conv.weight.data torch.ones_like(dilated_conv.weight.data) def update(i): ax1.clear() ax2.clear() # 在输入上标记当前卷积核位置 marked_input large_input.clone() h, w marked_input.shape[-2:] # 标准卷积的覆盖区域 center_h, center_w i//w, i%w for dh in [-1, 0, 1]: for dw in [-1, 0, 1]: nh, nw center_h dh, center_w dw if 0 nh h and 0 nw w: marked_input[0, 0, nh, nw] 10 # 高亮显示 ax1.imshow(marked_input.squeeze(), cmapgray) ax1.set_title(Standard Conv Coverage) # 空洞卷积的覆盖区域 marked_input_dilated large_input.clone() for dh in [-2, 0, 2]: for dw in [-2, 0, 2]: nh, nw center_h dh, center_w dw if 0 nh h and 0 nw w: marked_input_dilated[0, 0, nh, nw] 10 # 高亮显示 ax2.imshow(marked_input_dilated.squeeze(), cmapgray) ax2.set_title(Dilated Conv Coverage (rate2)) plt.suptitle(fPosition {i}: ({center_h}, {center_w})) anim FuncAnimation(fig, update, frames49, interval200) plt.close() return anim # 显示动画 create_comparison_animation()这个动画清晰地展示了标准卷积和空洞卷积在感受野上的差异。虽然两者都使用3×3的卷积核但空洞卷积能够覆盖更大的区域。5. 综合应用构建简单的上采样-下采样网络现在让我们把这些概念结合起来构建一个简单的网络先下采样再上采样一张图片class SimpleUpDownNet(nn.Module): def __init__(self): super().__init__() self.down1 nn.Conv2d(1, 1, kernel_size3, stride2, padding1) self.down2 nn.Conv2d(1, 1, kernel_size3, stride2, padding1) self.up1 nn.ConvTranspose2d(1, 1, kernel_size3, stride2, padding1, output_padding1) self.up2 nn.ConvTranspose2d(1, 1, kernel_size3, stride2, padding1, output_padding1) def forward(self, x): x F.relu(self.down1(x)) x F.relu(self.down2(x)) x F.relu(self.up1(x)) x self.up2(x) return x # 加载测试图像 from skimage.data import camera test_image torch.from_numpy(camera()).float().unsqueeze(0).unsqueeze(0) / 255.0 # 创建并运行网络 net SimpleUpDownNet() with torch.no_grad(): output_image net(test_image) # 可视化结果 plt.figure(figsize(10, 5)) plt.subplot(1, 2, 1) plt.imshow(test_image.squeeze(), cmapgray) plt.title(Original Image) plt.axis(off) plt.subplot(1, 2, 2) plt.imshow(output_image.squeeze(), cmapgray) plt.title(After Down-Up Sampling) plt.axis(off) plt.show()通过这个简单的例子你可以看到下采样和上采样操作对图像的影响。虽然最终图像尺寸恢复了但一些细节信息在过程中丢失了。

更多文章