这里对 LSTM 的解释挺好:https://zhuanlan.zhihu.com/p/32085405

https://blog.csdn.net/weixin_42769131/article/details/104728842


class ConvLSTMCell(nn.Module):
    """Convolutional LSTM cell.

    A single 3x3 convolution (stride 1, padding 1, so the spatial size is
    preserved) maps the channel-wise concatenation of the input and the
    previous hidden state to the four LSTM gate pre-activations.

    Args:
        input_size: number of channels of the input tensor.
        hidden_size: number of channels of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # One convolution produces all four gates at once; the output is
        # split into 4 chunks of `hidden_size` channels in forward().
        self.Gates = nn.Conv2d(
            input_size + hidden_size,
            4 * hidden_size,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, input_, prev_state=None):
        """Run one ConvLSTM step.

        Args:
            input_: tensor laid out as [batch, channel, height, width] with
                `input_size` channels.
            prev_state: ``(hidden, cell)`` tuple from the previous step, each
                of size [batch, hidden_size, height, width], or ``None`` to
                start from an all-zero state.

        Returns:
            ``(hidden, cell)`` tuple holding the new hidden and cell states.
        """
        # get batch and spatial sizes
        batch_size = input_.shape[0]
        spatial_size = input_.shape[2:]

        # generate an empty prev_state if None is provided
        if prev_state is None:
            state_size = [batch_size, self.hidden_size] + list(spatial_size)
            # Allocate on the input's device and with the input's dtype so the
            # cell works on CPU, on any GPU, and with non-float32 models
            # (a hard-coded .cuda() fails on CPU-only hosts).
            prev_state = (
                torch.zeros(state_size, dtype=input_.dtype, device=input_.device),
                torch.zeros(state_size, dtype=input_.dtype, device=input_.device),
            )
        prev_hidden, prev_cell = prev_state

        # data size is [batch, channel, height, width]
        stacked_inputs = torch.cat((input_, prev_hidden), 1)
        gates = self.Gates(stacked_inputs)

        # chunk across channel dimension
        in_gate, remember_gate, out_gate, cell_gate = gates.chunk(4, 1)

        # cell_gate: candidate values computed from the current input x_t and
        #   the previous hidden state; tanh squashes it into (-1, 1). This is
        #   input data, not a gating signal.
        # Forget stage: remember_gate (forget gate) controls how much of the
        #   previous cell state is kept.
        # Selective memory: in_gate decides how much of the candidate
        #   information from the current input is written to the cell.
        # Output stage: out_gate decides what is exposed as the current
        #   hidden state.

        # apply sigmoid non-linearity (torch.sigmoid replaces the deprecated
        # F.sigmoid)
        in_gate = torch.sigmoid(in_gate)
        remember_gate = torch.sigmoid(remember_gate)
        out_gate = torch.sigmoid(out_gate)

        # apply tanh non-linearity (torch.tanh replaces the deprecated F.tanh)
        cell_gate = torch.tanh(cell_gate)

        # compute current cell and hidden state
        cell = (remember_gate * prev_cell) + (in_gate * cell_gate)
        hidden = out_gate * torch.tanh(cell)
        return hidden, cell