# Flatten step: pool5_flat = pool5.view(pool5.size(0), -1)
# Shape change: [1518, 512, 7, 7] -> [1518, 25088]
# -*- coding: utf-8 -*-  Dimension-matching problem
import torch
import torch.nn as nn
# Random input with 4 rows of 25088 values each, sampled uniformly from [0, 1).
a = torch.rand((4, 25088))
class Net(nn.Module):
    """Two-layer fully connected head: ``Linear -> ReLU -> Linear -> ReLU``.

    ``nn.Linear`` operates on the last dimension of its input, so the input
    must already be flattened to ``(..., in_features)`` before it is passed
    in. With the default sizes the last dimension goes from 25088 to 10.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Output width of the first linear layer.
        num_classes: Size of the last dimension of the output.
    """

    def __init__(
        self,
        in_features: int = 25088,
        hidden_features: int = 4096,
        num_classes: int = 10,
    ) -> None:
        # The dunder spelling matters: a method named plain ``init`` is never
        # called by ``Net()``, which would leave ``self.layer`` undefined.
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(True),
            nn.Linear(hidden_features, num_classes),
            # The output also goes through ReLU, so it is never negative.
            nn.ReLU(True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the result."""
        return self.layer(x)
# Instantiate the model, feed `a` through it, and print the output shape.
net = Net()
b = net(a)
print(b.shape)