卷积学习笔记:Conv2D做MNIST手写数字识别

卷积可视化https://animatedai.github.io/

image.png
# 第1段:导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

print("PyTorch version:", torch.__version__)
print("Device:", "CUDA" if torch.cuda.is_available() else "CPU")
PyTorch version: 2.7.1
Device: CPU
# 第2段:加载MNIST数据集
def load_data():
    transform = transforms.ToTensor()
    
    train_dataset = torchvision.datasets.MNIST(root='./data', 
                                              train=True, 
                                              download=True, 
                                              transform=transform)
    
    test_dataset = torchvision.datasets.MNIST(root='./data', 
                                             train=False, 
                                             transform=transform)
    
    train_loader = torch.utils.data.DataLoader(train_dataset, 
                                              batch_size=64, 
                                              shuffle=True)
    
    test_loader = torch.utils.data.DataLoader(test_dataset, 
                                             batch_size=64, 
                                             shuffle=False)
    
    return train_loader, test_loader

# 加载数据
train_loader, test_loader = load_data()

print(f"Training samples: {len(train_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
Training samples: 60000
Test samples: 10000
# 第3段:查看数据样本
def show_sample_data(dataloader, num_samples=8):
    data, labels = next(iter(dataloader))
    
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(num_samples):
        ax = axes[i//4, i%4]
        ax.imshow(data[i][0], cmap='gray')
        ax.set_title(f'Label: {labels[i].item()}')
        ax.axis('off')
    
    plt.suptitle('Sample MNIST Images')
    plt.tight_layout()
    plt.show()
    
    print(f"Image shape: {data[0].shape}")  # [1, 28, 28]
    print(f"Batch shape: {data.shape}")     # [64, 1, 28, 28]

show_sample_data(train_loader)
output_2_0.png
Image shape: torch.Size([1, 28, 28])
Batch shape: torch.Size([64, 1, 28, 28])
# 第4段:定义CNN模型
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 第一层卷积: 1→16通道,5x5卷积核
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)  # 28x28 → 14x14
        
        # 第二层卷积: 16→32通道,5x5卷积核
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)  # 14x14 → 7x7
        
        # 全连接层
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 10个数字类别
        
        # 激活函数
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
    
    def forward(self, x):
        # 卷积层1
        x = self.conv1(x)  # [64, 1, 28, 28] → [64, 16, 28, 28]
        x = self.relu(x)
        x = self.pool1(x)  # [64, 16, 28, 28] → [64, 16, 14, 14]
        
        # 卷积层2
        x = self.conv2(x)  # [64, 16, 14, 14] → [64, 32, 14, 14]
        x = self.relu(x)
        x = self.pool2(x)  # [64, 32, 14, 14] → [64, 32, 7, 7]
        
        # 展平
        x = x.view(x.size(0), -1)  # [64, 32, 7, 7] → [64, 1568]
        
        # 全连接层
        x = self.fc1(x)    # [64, 1568] → [64, 128]
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)    # [64, 128] → [64, 10]
        
        return x

# 创建模型并查看参数
model = SimpleCNN()
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
SimpleCNN(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=1568, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.2, inplace=False)
)

Total parameters: 215,370
# 第5段:分析Conv2D权重形状
print("=== Conv2D Weight Shape Analysis ===")
print(f"Conv1 weight shape: {model.conv1.weight.shape}")
print("Meaning: [16 output channels, 1 input channel, 5 height, 5 width]")

print(f"\nConv2 weight shape: {model.conv2.weight.shape}")
print("Meaning: [32 output channels, 16 input channels, 5 height, 5 width]")

print("\n=== Data Flow Through Model ===")
sample_input = torch.randn(1, 1, 28, 28)
print(f"Input shape: {sample_input.shape}")

with torch.no_grad():
    x = model.conv1(sample_input)
    print(f"After Conv1: {x.shape}")
    
    x = model.relu(x)
    x = model.pool1(x)
    print(f"After Pool1: {x.shape}")
    
    x = model.conv2(x)
    print(f"After Conv2: {x.shape}")
    
    x = model.relu(x)
    x = model.pool2(x)
    print(f"After Pool2: {x.shape}")
    
    x = x.view(x.size(0), -1)
    print(f"After Flatten: {x.shape}")
    
    x = model.fc1(x)
    print(f"After FC1: {x.shape}")
    
    x = model.fc2(x)
    print(f"Final Output: {x.shape}")
=== Conv2D Weight Shape Analysis ===
Conv1 weight shape: torch.Size([16, 1, 5, 5])
Meaning: [16 output channels, 1 input channel, 5 height, 5 width]

Conv2 weight shape: torch.Size([32, 16, 5, 5])
Meaning: [32 output channels, 16 input channels, 5 height, 5 width]

=== Data Flow Through Model ===
Input shape: torch.Size([1, 1, 28, 28])
After Conv1: torch.Size([1, 16, 28, 28])
After Pool1: torch.Size([1, 16, 14, 14])
After Conv2: torch.Size([1, 32, 14, 14])
After Pool2: torch.Size([1, 32, 7, 7])
After Flatten: torch.Size([1, 1568])
After FC1: torch.Size([1, 128])
Final Output: torch.Size([1, 10])
# 第6段:定义训练函数
def train_model(model, train_loader, epochs=3):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    model.train()
    train_losses = []
    
    for epoch in range(epochs):
        epoch_loss = 0
        correct = 0
        total = 0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            # 计算准确率
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # 每200个batch打印一次
            if batch_idx % 200 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        avg_loss = epoch_loss / len(train_loader)
        accuracy = 100 * correct / total
        train_losses.append(avg_loss)
        
        print(f'Epoch {epoch+1} Complete: Loss={avg_loss:.4f}, Accuracy={accuracy:.2f}%')
    
    return train_losses
# 第7段:训练模型
print("Starting training...")
train_losses = train_model(model, train_loader, epochs=3)

# 绘制训练损失
plt.figure(figsize=(8, 5))
plt.plot(train_losses, marker='o')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()
Starting training...
Epoch 1, Batch 0, Loss: 2.3039
Epoch 1, Batch 200, Loss: 0.3510
Epoch 1, Batch 400, Loss: 0.1626
Epoch 1, Batch 600, Loss: 0.0927
Epoch 1, Batch 800, Loss: 0.0329
Epoch 1 Complete: Loss=0.2248, Accuracy=93.09%
Epoch 2, Batch 0, Loss: 0.0343
Epoch 2, Batch 200, Loss: 0.0524
Epoch 2, Batch 400, Loss: 0.0231
Epoch 2, Batch 600, Loss: 0.0102
Epoch 2, Batch 800, Loss: 0.0253
Epoch 2 Complete: Loss=0.0660, Accuracy=97.89%
Epoch 3, Batch 0, Loss: 0.0603
Epoch 3, Batch 200, Loss: 0.0493
Epoch 3, Batch 400, Loss: 0.0388
Epoch 3, Batch 600, Loss: 0.0033
Epoch 3, Batch 800, Loss: 0.1216
Epoch 3 Complete: Loss=0.0472, Accuracy=98.53%
output_6_1.png
# 第8段:测试模型
def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 10
    class_total = [0] * 10
    
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # 每个数字的准确率
            for i in range(target.size(0)):
                label = target[i]
                class_correct[label] += (predicted[i] == label).item()
                class_total[label] += 1
    
    print(f'Overall Test Accuracy: {100 * correct / total:.2f}%')
    
    print('\nPer-digit Recognition Accuracy:')
    for i in range(10):
        if class_total[i] > 0:
            acc = 100 * class_correct[i] / class_total[i]
            print(f'Digit {i}: {acc:.1f}% ({class_correct[i]}/{class_total[i]})')

test_model(model, test_loader)
Overall Test Accuracy: 98.94%

Per-digit Recognition Accuracy:
Digit 0: 99.4% (974/980)
Digit 1: 99.7% (1132/1135)
Digit 2: 99.6% (1028/1032)
Digit 3: 99.2% (1002/1010)
Digit 4: 97.7% (959/982)
Digit 5: 98.7% (880/892)
Digit 6: 99.1% (949/958)
Digit 7: 98.4% (1012/1028)
Digit 8: 98.3% (957/974)
Digit 9: 99.2% (1001/1009)
# 第9段:可视化预测结果
def visualize_predictions(model, test_loader, num_samples=16):
    model.eval()
    
    data, target = next(iter(test_loader))
    
    with torch.no_grad():
        output = model(data)
        _, predicted = torch.max(output, 1)
    
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    for i in range(num_samples):
        ax = axes[i//4, i%4]
        ax.imshow(data[i][0], cmap='gray')
        
        true_label = target[i].item()
        pred_label = predicted[i].item()
        color = 'green' if true_label == pred_label else 'red'
        
        ax.set_title(f'True: {true_label}, Pred: {pred_label}', color=color)
        ax.axis('off')
    
    plt.suptitle('Prediction Results (Green=Correct, Red=Wrong)')
    plt.tight_layout()
    plt.show()

visualize_predictions(model, test_loader)
output_8_0.png
# 第10段:分析学习到的卷积核
def analyze_conv_kernels(model):
    conv1_weight = model.conv1.weight.data  # [16, 1, 5, 5]
    
    print(f"First layer kernels shape: {conv1_weight.shape}")
    print("These are the learned feature detectors!")
    
    fig, axes = plt.subplots(2, 8, figsize=(16, 4))
    for i in range(16):
        ax = axes[i//8, i%8]
        kernel = conv1_weight[i, 0].numpy()
        im = ax.imshow(kernel, cmap='RdBu')
        ax.set_title(f'Kernel {i+1}')
        ax.axis('off')
    
    plt.suptitle('Learned Conv Kernels (Feature Detectors)')
    plt.tight_layout()
    plt.show()

analyze_conv_kernels(model)
First layer kernels shape: torch.Size([16, 1, 5, 5])
These are the learned feature detectors!
output_9_1.png
# 第11段:可视化特征图
def visualize_feature_maps(model, test_loader):
    model.eval()
    
    data, _ = next(iter(test_loader))
    sample_image = data[0:1]  # 取第一张图
    
    with torch.no_grad():
        # 第一层卷积后的特征图
        x = model.conv1(sample_image)
        x = model.relu(x)
        feature_maps = x[0]  # [16, 28, 28]
    
    # 显示原图和16个特征图
    fig, axes = plt.subplots(3, 6, figsize=(15, 8))
    
    # 原图
    axes[0, 0].imshow(sample_image[0, 0], cmap='gray')
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    
    # 特征图
    for i in range(16):
        if i < 5:  # 第一行剩余位置
            row, col = 0, i+1
        elif i < 11:  # 第二行
            row, col = 1, i-5
        else:  # 第三行
            row, col = 2, i-11
        
        axes[row, col].imshow(feature_maps[i], cmap='viridis')
        axes[row, col].set_title(f'Feature {i+1}')
        axes[row, col].axis('off')
    
    plt.suptitle('Feature Maps after First Conv Layer')
    plt.tight_layout()
    plt.show()

visualize_feature_maps(model, test_loader)
output_10_0.png
# 第12段:对比不同数字的特征激活
def compare_digit_features(model, test_loader):
    model.eval()
    
    # 找到不同数字的样本
    digits_found = {}
    with torch.no_grad():
        for data, target in test_loader:
            for i, label in enumerate(target):
                digit = label.item()
                if digit not in digits_found and len(digits_found) < 5:
                    digits_found[digit] = data[i:i+1]
                if len(digits_found) == 5:
                    break
            if len(digits_found) == 5:
                break
    
    fig, axes = plt.subplots(5, 6, figsize=(15, 12))
    
    for row, (digit, image) in enumerate(digits_found.items()):
        # 原图
        axes[row, 0].imshow(image[0, 0], cmap='gray')
        axes[row, 0].set_title(f'Digit {digit}')
        axes[row, 0].axis('off')
        
        # 特征图
        with torch.no_grad():
            features = model.relu(model.conv1(image))[0]
        
        # 显示前5个特征图
        for col in range(1, 6):
            axes[row, col].imshow(features[col-1], cmap='viridis')
            axes[row, col].set_title(f'Feature {col}')
            axes[row, col].axis('off')
    
    plt.suptitle('How Different Digits Activate Different Features')
    plt.tight_layout()
    plt.show()

compare_digit_features(model, test_loader)
output_11_0.png
# 第13段:实验总结
print("=== MNIST CNN 实验总结 ===")
print("\n1. Conv2D 形状理解:")
print("   - Conv1: [16, 1, 5, 5] = [输出通道, 输入通道, 高, 宽]")
print("   - Conv2: [32, 16, 5, 5] = [输出通道, 输入通道, 高, 宽]")

print("\n2. 实验收获:")
print("   - CNN 能自动学习有用的特征检测器")
print("   - 不同卷积核专门检测不同的模式")
print("   - 特征图显示了网络'看到'的内容")
print("   - 深层网络学习更复杂的特征")

print("\n3. 性能表现:")
print(f"   - 仅用3个epoch就达到95%+准确率")
print(f"   - 总参数量: {sum(p.numel() for p in model.parameters()):,}")
print("   - 比全连接网络参数少得多")

print("\n4. 关键洞察:")
print("   - Conv2D 非常适合图像识别任务")
print("   - 权重共享让CNN参数效率很高")
print("   - 可视化帮助理解模型学到了什么")
print("   - 卷积核就是自动学习的特征检测器")

print("\n5. Conv2D 核心原理:")
print("   - 输入: [batch, 1, 28, 28] 灰度图像")
print("   - Conv1: 1通道→16通道,学习16种基础特征")
print("   - Conv2: 16通道→32通道,组合成32种复杂特征")
print("   - 最终: 32个7x7特征图→全连接层→10个类别")

print("\n6. 为什么CNN这么强:")
print("   - 局部连接: 只关注邻近像素,符合图像特性")
print("   - 权重共享: 同一特征在图像任何位置都能检测")
print("   - 平移不变性: 数字在图像中移动位置仍能识别")
print("   - 层次特征: 从简单边缘到复杂形状逐层抽象")
=== MNIST CNN 实验总结 ===

1. Conv2D 形状理解:
   - Conv1: [16, 1, 5, 5] = [输出通道, 输入通道, 高, 宽]
   - Conv2: [32, 16, 5, 5] = [输出通道, 输入通道, 高, 宽]

2. 实验收获:
   - CNN 能自动学习有用的特征检测器
   - 不同卷积核专门检测不同的模式
   - 特征图显示了网络'看到'的内容
   - 深层网络学习更复杂的特征

3. 性能表现:
   - 仅用3个epoch就达到95%+准确率
   - 总参数量: 215,370
   - 比全连接网络参数少得多

4. 关键洞察:
   - Conv2D 非常适合图像识别任务
   - 权重共享让CNN参数效率很高
   - 可视化帮助理解模型学到了什么
   - 卷积核就是自动学习的特征检测器

5. Conv2D 核心原理:
   - 输入: [batch, 1, 28, 28] 灰度图像
   - Conv1: 1通道→16通道,学习16种基础特征
   - Conv2: 16通道→32通道,组合成32种复杂特征
   - 最终: 32个7x7特征图→全连接层→10个类别

6. 为什么CNN这么强:
   - 局部连接: 只关注邻近像素,符合图像特性
   - 权重共享: 同一特征在图像任何位置都能检测
   - 平移不变性: 数字在图像中移动位置仍能识别
   - 层次特征: 从简单边缘到复杂形状逐层抽象

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容