Convolution visualization: https://animatedai.github.io/
# Part 1: Import the required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
print("PyTorch version:", torch.__version__)
print("Device:", "CUDA" if torch.cuda.is_available() else "CPU")
PyTorch version: 2.7.1
Device: CPU
# Part 2: Load the MNIST dataset
def load_data():
    transform = transforms.ToTensor()
    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               download=True,
                                               transform=transform)
    test_dataset = torchvision.datasets.MNIST(root='./data',
                                              train=False,
                                              transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=64,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=64,
                                              shuffle=False)
    return train_loader, test_loader

# Load the data
train_loader, test_loader = load_data()
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
Training samples: 60000
Test samples: 10000
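One detail worth noting: transforms.ToTensor() converts each PIL image to a float32 tensor with channel-first shape [C, H, W] and pixel values rescaled from [0, 255] to [0.0, 1.0]. A quick check against the loaded dataset (a minimal sketch, assuming the train_loader built above):

# Inspect a single sample straight from the underlying dataset
img, label = train_loader.dataset[0]
print(img.dtype, img.shape)                 # torch.float32, torch.Size([1, 28, 28])
print(img.min().item(), img.max().item())   # values lie in [0.0, 1.0]
print(label)                                # the integer class label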
# Part 3: Inspect sample data
def show_sample_data(dataloader, num_samples=8):
    data, labels = next(iter(dataloader))
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(num_samples):
        ax = axes[i // 4, i % 4]
        ax.imshow(data[i][0], cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')
    plt.suptitle('Sample MNIST Images')
    plt.tight_layout()
    plt.show()
    print(f"Image shape: {data[0].shape}")  # [1, 28, 28]
    print(f"Batch shape: {data.shape}")     # [64, 1, 28, 28]

show_sample_data(train_loader)
[Figure: Sample MNIST Images]
Image shape: torch.Size([1, 28, 28])
Batch shape: torch.Size([64, 1, 28, 28])
# Part 4: Define the CNN model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # First conv layer: 1 → 16 channels, 5x5 kernel
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 → 14x14
        # Second conv layer: 16 → 32 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 → 7x7
        # Fully connected layers
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 10 digit classes
        # Activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # Conv block 1
        x = self.conv1(x)          # [64, 1, 28, 28] → [64, 16, 28, 28]
        x = self.relu(x)
        x = self.pool1(x)          # [64, 16, 28, 28] → [64, 16, 14, 14]
        # Conv block 2
        x = self.conv2(x)          # [64, 16, 14, 14] → [64, 32, 14, 14]
        x = self.relu(x)
        x = self.pool2(x)          # [64, 32, 14, 14] → [64, 32, 7, 7]
        # Flatten
        x = x.view(x.size(0), -1)  # [64, 32, 7, 7] → [64, 1568]
        # Fully connected layers
        x = self.fc1(x)            # [64, 1568] → [64, 128]
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)            # [64, 128] → [64, 10]
        return x

# Create the model and inspect its parameters
model = SimpleCNN()
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
SimpleCNN(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
)
Total parameters: 215,370
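As a sanity check on that total, the count can be recomputed by hand from the layer sizes above (weights plus biases per layer):

# Recompute the parameter count layer by layer
conv1_params = 16 * (1 * 5 * 5) + 16    # 416: 16 kernels of 1x5x5, plus 16 biases
conv2_params = 32 * (16 * 5 * 5) + 32   # 12,832
fc1_params = 1568 * 128 + 128           # 200,832: the fully connected layers dominate
fc2_params = 128 * 10 + 10              # 1,290
print(conv1_params + conv2_params + fc1_params + fc2_params)  # 215370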
# Part 5: Analyze Conv2D weight shapes
print("=== Conv2D Weight Shape Analysis ===")
print(f"Conv1 weight shape: {model.conv1.weight.shape}")
print("Meaning: [16 output channels, 1 input channel, 5 height, 5 width]")
print(f"\nConv2 weight shape: {model.conv2.weight.shape}")
print("Meaning: [32 output channels, 16 input channels, 5 height, 5 width]")

print("\n=== Data Flow Through Model ===")
sample_input = torch.randn(1, 1, 28, 28)
print(f"Input shape: {sample_input.shape}")
with torch.no_grad():
    x = model.conv1(sample_input)
    print(f"After Conv1: {x.shape}")
    x = model.relu(x)
    x = model.pool1(x)
    print(f"After Pool1: {x.shape}")
    x = model.conv2(x)
    print(f"After Conv2: {x.shape}")
    x = model.relu(x)
    x = model.pool2(x)
    print(f"After Pool2: {x.shape}")
    x = x.view(x.size(0), -1)
    print(f"After Flatten: {x.shape}")
    x = model.fc1(x)
    print(f"After FC1: {x.shape}")
    x = model.fc2(x)
    print(f"Final Output: {x.shape}")
=== Conv2D Weight Shape Analysis ===
Conv1 weight shape: torch.Size([16, 1, 5, 5])
Meaning: [16 output channels, 1 input channel, 5 height, 5 width]
Conv2 weight shape: torch.Size([32, 16, 5, 5])
Meaning: [32 output channels, 16 input channels, 5 height, 5 width]
=== Data Flow Through Model ===
Input shape: torch.Size([1, 1, 28, 28])
After Conv1: torch.Size([1, 16, 28, 28])
After Pool1: torch.Size([1, 16, 14, 14])
After Conv2: torch.Size([1, 32, 14, 14])
After Pool2: torch.Size([1, 32, 7, 7])
After Flatten: torch.Size([1, 1568])
After FC1: torch.Size([1, 128])
Final Output: torch.Size([1, 10])
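The spatial sizes above all follow the standard convolution/pooling output formula out = floor((in + 2*padding - kernel) / stride) + 1. A minimal sketch applying it to this model's layers:

# out = floor((in + 2*pad - k) / stride) + 1
def conv_out(size, kernel, stride=1, padding=0):
    return (size + 2 * padding - kernel) // stride + 1

h = conv_out(28, kernel=5, padding=2)  # Conv1: (28 + 4 - 5) // 1 + 1 = 28
h = conv_out(h, kernel=2, stride=2)    # Pool1: (28 - 2) // 2 + 1 = 14
h = conv_out(h, kernel=5, padding=2)   # Conv2: padding=2 keeps the size at 14
h = conv_out(h, kernel=2, stride=2)    # Pool2: 7
print(h, 32 * h * h)                   # 7 1568, the flatten size used by fc1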
# Part 6: Define the training function
def train_model(model, train_loader, epochs=3):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.train()
    train_losses = []
    for epoch in range(epochs):
        epoch_loss = 0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            # Track training accuracy
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            # Print progress every 200 batches
            if batch_idx % 200 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        avg_loss = epoch_loss / len(train_loader)
        accuracy = 100 * correct / total
        train_losses.append(avg_loss)
        print(f'Epoch {epoch+1} Complete: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%')
    return train_losses
# Part 7: Train the model
print("Starting training...")
train_losses = train_model(model, train_loader, epochs=3)
# Plot the training loss
plt.figure(figsize=(8, 5))
plt.plot(train_losses, marker='o')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
Starting training...
Epoch 1, Batch 0, Loss: 2.3039
Epoch 1, Batch 200, Loss: 0.3510
Epoch 1, Batch 400, Loss: 0.1626
Epoch 1, Batch 600, Loss: 0.0927
Epoch 1, Batch 800, Loss: 0.0329
Epoch 1 Complete: Loss=0.2248, Accuracy=93.09%
Epoch 2, Batch 0, Loss: 0.0343
Epoch 2, Batch 200, Loss: 0.0524
Epoch 2, Batch 400, Loss: 0.0231
Epoch 2, Batch 600, Loss: 0.0102
Epoch 2, Batch 800, Loss: 0.0253
Epoch 2 Complete: Loss=0.0660, Accuracy=97.89%
Epoch 3, Batch 0, Loss: 0.0603
Epoch 3, Batch 200, Loss: 0.0493
Epoch 3, Batch 400, Loss: 0.0388
Epoch 3, Batch 600, Loss: 0.0033
Epoch 3, Batch 800, Loss: 0.1216
Epoch 3 Complete: Loss=0.0472, Accuracy=98.53%
[Figure: Training Loss]
# Part 8: Test the model
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            # Per-digit accuracy
            for i in range(target.size(0)):
                label = target[i]
                class_correct[label] += (predicted[i] == label).item()
                class_total[label] += 1
    print(f'Overall Test Accuracy: {100 * correct / total:.2f}%')
    print('\nPer-digit Recognition Accuracy:')
    for i in range(10):
        if class_total[i] > 0:
            acc = 100 * class_correct[i] / class_total[i]
            print(f'Digit {i}: {acc:.1f}% ({class_correct[i]}/{class_total[i]})')

test_model(model, test_loader)
Overall Test Accuracy: 98.94%
Per-digit Recognition Accuracy:
Digit 0: 99.4% (974/980)
Digit 1: 99.7% (1132/1135)
Digit 2: 99.6% (1028/1032)
Digit 3: 99.2% (1002/1010)
Digit 4: 97.7% (959/982)
Digit 5: 98.7% (880/892)
Digit 6: 99.1% (949/958)
Digit 7: 98.4% (1012/1028)
Digit 8: 98.3% (957/974)
Digit 9: 99.2% (1001/1009)
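The per-digit numbers are the diagonal of a confusion matrix; to also see which digits get mistaken for which, a minimal sketch in plain PyTorch (reusing the model and test_loader above) is:

# Optional: 10x10 confusion matrix (rows = true digit, columns = predicted digit)
confusion = torch.zeros(10, 10, dtype=torch.long)
model.eval()
with torch.no_grad():
    for data, target in test_loader:
        predicted = model(data).argmax(dim=1)
        for t, p in zip(target, predicted):
            confusion[t, p] += 1
print(confusion)  # off-diagonal entries reveal the common mix-ups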
# Part 9: Visualize prediction results
def visualize_predictions(model, test_loader, num_samples=16):
    model.eval()
    data, target = next(iter(test_loader))
    with torch.no_grad():
        output = model(data)
        _, predicted = torch.max(output, 1)
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(num_samples):
        ax = axes[i // 4, i % 4]
        ax.imshow(data[i][0], cmap='gray')
        true_label = target[i].item()
        pred_label = predicted[i].item()
        color = 'green' if true_label == pred_label else 'red'
        ax.set_title(f'True: {true_label}, Pred: {pred_label}', color=color)
        ax.axis('off')
    plt.suptitle('Prediction Results (Green=Correct, Red=Wrong)')
    plt.tight_layout()
    plt.show()

visualize_predictions(model, test_loader)
[Figure: Prediction Results (Green=Correct, Red=Wrong)]
# Part 10: Analyze the learned convolution kernels
def analyze_conv_kernels(model):
    conv1_weight = model.conv1.weight.data  # [16, 1, 5, 5]
    print(f"First layer kernels shape: {conv1_weight.shape}")
    print("These are the learned feature detectors!")
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    for i in range(16):
        ax = axes[i // 8, i % 8]
        kernel = conv1_weight[i, 0].numpy()
        ax.imshow(kernel, cmap='RdBu')
        ax.set_title(f'Kernel {i+1}')
        ax.axis('off')
    plt.suptitle('Learned Conv Kernels (Feature Detectors)')
    plt.tight_layout()
    plt.show()

analyze_conv_kernels(model)
First layer kernels shape: torch.Size([16, 1, 5, 5])
These are the learned feature detectors!
[Figure: Learned Conv Kernels (Feature Detectors)]
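To make the "feature detector" reading concrete, a single learned kernel can be applied to one image directly with torch.nn.functional.conv2d; high values in the response map mark locations that match the kernel's pattern. A minimal sketch using the first kernel:

import torch.nn.functional as F

data, _ = next(iter(test_loader))
image = data[0:1]                        # [1, 1, 28, 28]
kernel = model.conv1.weight.data[0:1]    # [1, 1, 5, 5]: the first learned kernel
bias = model.conv1.bias.data[0:1]
response = F.conv2d(image, kernel, bias, padding=2)  # [1, 1, 28, 28]
print(response.shape)
# plt.imshow(response[0, 0], cmap='viridis') shows where this pattern fires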
# Part 11: Visualize feature maps
def visualize_feature_maps(model, test_loader):
    model.eval()
    data, _ = next(iter(test_loader))
    sample_image = data[0:1]  # take the first image
    with torch.no_grad():
        # Feature maps after the first conv layer
        x = model.conv1(sample_image)
        x = model.relu(x)
        feature_maps = x[0]  # [16, 28, 28]
    # Show the original image plus the 16 feature maps
    fig, axes = plt.subplots(3, 6, figsize=(15, 8))
    # Original image
    axes[0, 0].imshow(sample_image[0, 0], cmap='gray')
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    # Feature maps (17 panels in an 18-slot grid)
    for i in range(16):
        if i < 5:      # rest of the first row
            row, col = 0, i + 1
        elif i < 11:   # second row
            row, col = 1, i - 5
        else:          # third row
            row, col = 2, i - 11
        axes[row, col].imshow(feature_maps[i], cmap='viridis')
        axes[row, col].set_title(f'Feature {i+1}')
        axes[row, col].axis('off')
    axes[2, 5].axis('off')  # hide the one unused panel
    plt.suptitle('Feature Maps after First Conv Layer')
    plt.tight_layout()
    plt.show()

visualize_feature_maps(model, test_loader)
[Figure: Feature Maps after First Conv Layer]
# Part 12: Compare feature activations across digits
def compare_digit_features(model, test_loader):
    model.eval()
    # Collect one sample image for each of five different digits
    digits_found = {}
    with torch.no_grad():
        for data, target in test_loader:
            for i, label in enumerate(target):
                digit = label.item()
                if digit not in digits_found and len(digits_found) < 5:
                    digits_found[digit] = data[i:i+1]
                if len(digits_found) == 5:
                    break
            if len(digits_found) == 5:
                break
    fig, axes = plt.subplots(5, 6, figsize=(15, 12))
    for row, (digit, image) in enumerate(digits_found.items()):
        # Original image
        axes[row, 0].imshow(image[0, 0], cmap='gray')
        axes[row, 0].set_title(f'Digit {digit}')
        axes[row, 0].axis('off')
        # Feature maps after the first conv layer
        with torch.no_grad():
            features = model.relu(model.conv1(image))[0]
        # Show the first five feature maps
        for col in range(1, 6):
            axes[row, col].imshow(features[col-1], cmap='viridis')
            axes[row, col].set_title(f'Feature {col}')
            axes[row, col].axis('off')
    plt.suptitle('How Different Digits Activate Different Features')
    plt.tight_layout()
    plt.show()

compare_digit_features(model, test_loader)
[Figure: How Different Digits Activate Different Features]
# Part 13: Experiment summary
print("=== MNIST CNN Experiment Summary ===")
print("\n1. Understanding Conv2D shapes:")
print("   - Conv1: [16, 1, 5, 5] = [output channels, input channels, height, width]")
print("   - Conv2: [32, 16, 5, 5] = [output channels, input channels, height, width]")
print("\n2. Takeaways:")
print("   - A CNN learns useful feature detectors automatically")
print("   - Different kernels specialize in detecting different patterns")
print("   - Feature maps show what the network 'sees'")
print("   - Deeper layers learn more complex features")
print("\n3. Performance:")
print("   - Reaches 98%+ accuracy in just 3 epochs")
print(f"   - Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print("   - Far fewer parameters than a comparable fully connected network")
print("\n4. Key insights:")
print("   - Conv2D is well suited to image recognition tasks")
print("   - Weight sharing makes CNNs highly parameter-efficient")
print("   - Visualization reveals what the model has learned")
print("   - Convolution kernels are automatically learned feature detectors")
print("\n5. Core Conv2D mechanics:")
print("   - Input: [batch, 1, 28, 28] grayscale images")
print("   - Conv1: 1 → 16 channels, learning 16 basic features")
print("   - Conv2: 16 → 32 channels, combining them into 32 more complex features")
print("   - Finally: 32 feature maps of 7x7 → fully connected layers → 10 classes")
print("\n6. Why CNNs work so well:")
print("   - Local connectivity: each unit looks only at nearby pixels, matching image structure")
print("   - Weight sharing: the same feature can be detected anywhere in the image")
print("   - Translation invariance: a digit shifted within the image is still recognized")
print("   - Hierarchical features: layer-by-layer abstraction from simple edges to complex shapes")
=== MNIST CNN Experiment Summary ===
1. Understanding Conv2D shapes:
   - Conv1: [16, 1, 5, 5] = [output channels, input channels, height, width]
   - Conv2: [32, 16, 5, 5] = [output channels, input channels, height, width]
2. Takeaways:
   - A CNN learns useful feature detectors automatically
   - Different kernels specialize in detecting different patterns
   - Feature maps show what the network 'sees'
   - Deeper layers learn more complex features
3. Performance:
   - Reaches 98%+ accuracy in just 3 epochs
   - Total parameters: 215,370
   - Far fewer parameters than a comparable fully connected network
4. Key insights:
   - Conv2D is well suited to image recognition tasks
   - Weight sharing makes CNNs highly parameter-efficient
   - Visualization reveals what the model has learned
   - Convolution kernels are automatically learned feature detectors
5. Core Conv2D mechanics:
   - Input: [batch, 1, 28, 28] grayscale images
   - Conv1: 1 → 16 channels, learning 16 basic features
   - Conv2: 16 → 32 channels, combining them into 32 more complex features
   - Finally: 32 feature maps of 7x7 → fully connected layers → 10 classes
6. Why CNNs work so well:
   - Local connectivity: each unit looks only at nearby pixels, matching image structure
   - Weight sharing: the same feature can be detected anywhere in the image
   - Translation invariance: a digit shifted within the image is still recognized
   - Hierarchical features: layer-by-layer abstraction from simple edges to complex shapes
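As a closing sketch of the weight-sharing point in item 6: a convolution layer is translation-equivariant, i.e. shifting the input shifts its response map by the same amount (pooling and the later layers then supply the classifier's tolerance to such shifts). A minimal check with torch.roll, comparing away from the wrap-around border:

# Shift the input 3 pixels to the right; the conv1 response shifts identically
data, _ = next(iter(test_loader))
image = data[0:1]
shifted = torch.roll(image, shifts=3, dims=3)
with torch.no_grad():
    out_orig = model.conv1(image)
    out_shift = model.conv1(shifted)
# Exclude columns near the wrapped border, where roll and zero-padding differ
diff = (torch.roll(out_orig, shifts=3, dims=3) - out_shift)[:, :, :, 5:-5].abs().max()
print(f"Max difference away from the border: {diff.item():.1e}")  # ~0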