通过继承nn.Module
类来自定义深度学习模型是 PyTorch 中常见的做法。nn.Module
是所有神经网络模块的基类,提供了许多有用的方法和属性。自定义模型主要涉及以下几个步骤:
- 定义模型类,继承 nn.Module
- 在 `__init__` 方法中定义模型的层
- 在 forward 方法中定义前向传播的逻辑
- 实例化模型并使用
示例代码
- MLP
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """A simple three-layer fully connected network (multi-layer perceptron).

    Architecture: input -> fc1 -> ReLU -> fc2 -> ReLU -> fc3 -> logits.
    Returns raw (un-normalized) class scores.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Two equal-width hidden layers followed by an output projection.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Forward pass: two ReLU-activated hidden layers, then logits."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)
# Example usage
input_size = 784  # e.g. flattened 28x28 images
hidden_size = 128
output_size = 10  # 10-way classification
model = MLP(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
print(model)
- CNN
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """A small convolutional network for 1-channel 28x28 inputs (e.g. MNIST).

    Two conv + ReLU + max-pool stages (28x28 -> 14x14 -> 7x7 spatially),
    then two fully connected layers producing 10 class logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # sized for 28x28 spatial input
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Forward pass.

        Args:
            x: tensor of shape (batch, 1, 28, 28).

        Returns:
            Logits of shape (batch, 10).
        """
        x = self.pool(F.relu(self.conv1(x)))  # -> (batch, 32, 14, 14)
        x = self.pool(F.relu(self.conv2(x)))  # -> (batch, 64, 7, 7)
        # Fix: flatten(x, 1) keeps the batch dimension intact, unlike the
        # original view(-1, 64 * 7 * 7), which could silently fold a wrong
        # spatial size into the batch dimension instead of raising an error.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# Example usage
model = CNN()
print(model)
- Transformer
import torch
import torch.nn as nn
class TransformerModel(nn.Module):
    """Transformer-encoder classifier over token-id sequences.

    Embeds integer token ids, runs a stack of Transformer encoder layers,
    mean-pools over the sequence dimension, and projects to class logits.

    Args:
        input_dim: vocabulary size for the embedding table.
        model_dim: embedding / encoder hidden width (d_model).
        num_heads: number of attention heads (must divide model_dim).
        num_layers: number of stacked encoder layers.
        output_dim: number of output classes.
    """

    def __init__(self, input_dim, model_dim, num_heads, num_layers, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, model_dim)
        # batch_first=True makes the encoder consume (batch, seq, dim)
        # directly, so forward() needs no permute and the mean over dim=1
        # below really pools over the sequence axis.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(model_dim, output_dim)

    def forward(self, src):
        """Classify a batch of token-id sequences.

        Args:
            src: LongTensor of token ids, shape (batch_size, seq_len).

        Returns:
            Logits of shape (batch_size, output_dim).
        """
        # Bug fixed: the original permuted the batch-first embedding output
        # to (seq, batch, dim) for a seq-first encoder, after which
        # mean(dim=1) averaged over the *batch* dimension and produced one
        # row per time step instead of one per sample.
        src = self.embedding(src)               # (batch, seq, model_dim)
        output = self.transformer_encoder(src)  # (batch, seq, model_dim)
        output = output.mean(dim=1)             # pool over sequence positions
        return self.fc(output)
# Example usage
input_dim = 10000  # vocabulary size
model_dim = 512
num_heads = 8
num_layers = 6
output_dim = 10  # 10-way classification
model = TransformerModel(
    input_dim=input_dim,
    model_dim=model_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    output_dim=output_dim,
)
print(model)