首先计算类别频率,D是整个数据集种各类别的像素总个数。cityscapes数据集中如下图:
image.png
然后计算权重
代码如下:
import torch
D = [2.01e+9,2.98e+8,9.96e+8,3.39e+7,4.50e+7,6.54e+7,
9.57e+7,2.62e+7,7.21e+8,5.92e+7,1.45e+8,8.21e+7,
1.00e+7,4.13e+8,1.45e+7,1.28e+7,1.45e+7,5.64e+6,
2.57e+7]
class_freq = torch.log(torch.FloatTensor(D))
weights = 1 / class_freq # 或者1 / torch.log1p(class_freq)以便处理接近零的小数值
weights = 19 * weights / torch.sum(weights)
参考链接:weight use for loss function · Issue #14 · openseg-group/OCNet.pytorch