声明一个关于网络的类
import torch.nn as nn
class NetName(nn.Module):
def __init__(self):
super(NetName, self).__init__()
nn.module1 = ...
nn.module2 = ...
nn.module3 = ...
def forward(self,x):
x = self.module1(x)
x = self.module1(x)
x = self.module2(x)
x = self.module3(x)
return x
其中在构造函数__init__
中构造这个NN中需要使用的各种模块(module),比如:参数完全相同的maxpooling声明为一个模块,或者例如在CV任务中,把feature_extraction的网络和classification的网络分别声明。
forward
函数用于声明各个模块间的关系。即,连接整个网络。
net = NetName().to(device) # 创建网络,并放入指定的device
网络创建后,可以通过以下方式遍历模块信息:
for name, module in net._modules.items():
print(name) # name就是__init__中的各个模块名
print(module) # module就是各个模块内具体的层
示例:AlexNet
注释中的tensor大小变化是基于cifar10的图片----(channel=3, height=32, width=32)
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), # (3,32,32) -> (64,8,8)
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # (64,8,8) -> (64,4,4)
nn.Conv2d(64, 192, kernel_size=5, padding=2), # (64,4,4) -> (192,4,4)
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # (192,4,4) -> (192,2,2)
nn.Conv2d(192, 384, kernel_size=3, padding=1), # (192,2,2) -> (384,2,2)
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1), # (384,2,2) -> (256,2,2)
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1), # (256,2,2) -> (256,2,2)
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # (256,2,2) -> (256,1,1)
)
self.classifier = nn.Linear(256, 10) # (batch_size,256) -> (batch_size,10)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1) # flatten to (batch_size, 256*1*1)
x = self.classifier(x)
return x