有个很好的例子能看出 log_prob(action) 做了什么事:
import torch
import torch.nn.functional as F
action_logits = torch.rand(5)
action_probs = F.softmax(action_logits, dim=-1)
dist = torch.distributions.Categorical(action_probs)
action = dist.sample()
print(dist.log_prob(action), torch.log(action_probs[action]))
会发现输出的值相等。说明 dist.log_prob(action) 同 torch.log(action_probs[action]) 等价。
总的来说,会输出这个 action 概率的 log 值。