一、参数共享含义
参数共享(Parameter Sharing)是模型压缩与加速中的一种重要技术。通过参数共享,多个神经元或层可以共享相同的权重参数,而不是每个神经元或层都有独立的参数。
二、一个超级简单的示例
定义一个简单的卷积神经网络,一共二层,第二层共享第一层的参数
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 第一个卷积层,使用32个3x3的卷积核
self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
# 第二个卷积层,使用32个3x3的卷积核,但我们将共享第一个卷积层的参数
self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
# 共享参数:将conv2的参数设置为conv1的参数
self.conv2.weight = self.conv1.weight
self.conv2.bias = self.conv1.bias
# 全连接层
self.fc = nn.Linear(32 * 28 * 28, 10)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.relu(self.conv2(x))
x = x.view(x.size(0), -1) # 展平
x = self.fc(x)
return x
# 创建模型实例
model = SimpleCNN()
# 打印模型结构
print(model)
# 打印模型的参数
for name, param in model.named_parameters():
print(name, param.size())
# 示例输入
input_data = torch.randn(1, 1, 28, 28) # 1个样本,1个通道,28x28的图像
output = model(input_data)
print(output)
二、指定共享某一模块
假设我们有以下两个模型:
class ANN1(nn.Module):
def __init__(self,features):
super(ANN1, self).__init__()
self.features = features
self.nn_same = torch.nn.Sequential(
nn.Linear(features, 128),
torch.nn.ReLU(),
)
self.nn_diff = torch.nn.Sequential(
nn.Linear(128, 1)
)
def forward(self, x):
# x(batch_size, features)
x = self.nn_same(x)
x = self.nn_diff(x)
return x
class ANN2(nn.Module):
def __init__(self,features):
super(ANN2, self).__init__()
self.features = features
self.nn_same = torch.nn.Sequential(
nn.Linear(features, 128),
torch.nn.ReLU(),
)
self.nn_diff = torch.nn.Sequential(
nn.Linear(128, 1)
)
def forward(self, x):
# x(batch_size, features)
x = self.nn_same(x)
x = self.nn_diff(x)
return x
model1 = ANN1(10)
model2 = ANN2(10)
print(model1)
print(model2)
ANN1(
(nn_same): Sequential(
(0): Linear(in_features=10, out_features=128, bias=True)
(1): ReLU()
)
(nn_diff): Sequential(
(0): Linear(in_features=128, out_features=1, bias=True)
)
)
ANN2(
(nn_same): Sequential(
(0): Linear(in_features=10, out_features=128, bias=True)
(1): ReLU()
)
(nn_diff): Sequential(
(0): Linear(in_features=128, out_features=1, bias=True)
)
)
其中 nn_same 代表要共享参数的模块,模块名称可以不相同,但是模块结构必须完全相同。因为模型初始化时参数是随机初始化的,所以两个模型的参数肯定不相同。
下面我们开始进行参数共享:
print("****************迁移前*****************")
for param_tensor in model2.nn_same.state_dict():#输出迁移前的参数
print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
model_nn_same = model1.nn_same.state_dict() ##获取model的nn_same部分的参数
model2.nn_same.load_state_dict(model_nn_same,strict=True) #更新model2 nn_same部分的参数,#更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
print("****************迁移后*****************")
for param_tensor in model2.nn_same.state_dict():#输出迁移后的参数
print(param_tensor, "\t", model2.nn_same.state_dict()[param_tensor])
#此时nn_same参数更新,nn_diff2参数不变
三、 共享所有相同名称的模块
只需要修改这两句即可
model_all = model1.state_dict() #获取model1的所有的参数
model2.load_state_dict(model_all,strict=False) #更新model2所有的参数,False表示跳过名称不同的层,True表示必须全部匹配(默认)
strict=False,表示两个模型的模块名不需要完全匹配,只会更新名称相同的模块。如果两个模型的模块名不完全相同但是strict=True那么就会报错。
本文部分参考了《Pytorch中模型之间的参数共享》原文链接:https://blog.csdn.net/cyj972628089/article/details/127325735
如有侵权,请原作者联系删除。