热力图展示：
用多元正态分布为每个中心点生成一张 2D heatmap，再将多个中心点的 heatmap 累加，最终输出包含所有中心点的热力图。
# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import random
from scipy.stats import multivariate_normal
# Candidate covariance values ("optional cov list"); not referenced by the
# functions defined below.
covs = [40, 50, 60]
def CenterLabelHeatMap(img_width, img_height, c_x, c_y, sigma):
    """
    Render a single unnormalised Gaussian blob centred on (c_x, c_y).

    The pixel grid is 1-based: column j (0-based index) lies at x = j + 1 and
    row i lies at y = i + 1, so the value 1.0 is reached at array index
    (c_y - 1, c_x - 1) when the centre coordinates are integers in range.

    :param img_width: number of columns of the output.
    :param img_height: number of rows of the output.
    :param c_x: x coordinate of the centre, on the 1-based grid.
    :param c_y: y coordinate of the centre, on the 1-based grid.
    :param sigma: standard deviation of the Gaussian, in pixels.
    :return: np.array of shape (img_height, img_width) holding
        exp(-d^2 / (2 * sigma^2)), with values in [0, 1].
    """
    # Signed distance of every column / row from the centre.
    col_off = np.linspace(1, img_width, img_width) - c_x
    row_off = np.linspace(1, img_height, img_height) - c_y
    # Broadcast a (1, W) row against an (H, 1) column instead of building a meshgrid.
    dist_sq = (col_off * col_off)[np.newaxis, :] + (row_off * row_off)[:, np.newaxis]
    return np.exp(-(dist_sq / (2.0 * sigma * sigma)))
def heatmap_n_point(points, size, scaler=30, cov_choices=(90, 80, 100), rng=None):
    """
    Build a heatmap by summing one isotropic 2-D Gaussian pdf per centre point
    (evaluated with ``scipy.stats.multivariate_normal``).

    For each point a base variance ``c`` is drawn at random from
    ``cov_choices``. The scalar ``scaler * c`` is used as the covariance, so
    each kernel has covariance ``(scaler * c) * I``. The densities are
    evaluated on the integer pixel grid and accumulated.

    :param points: sequence of centre coordinates ``[[x1, y1], [x2, y2], ...]``
        (lists, tuples or an ``(N, 2)`` array). ``x`` is the column index and
        ``y`` the row index.
    :param size: tuple ``(h, w)``, the shape of the returned heatmap.
    :param scaler: multiplier applied to the drawn variance (default 30);
        larger values give wider blobs.
    :param cov_choices: candidate base variances, one drawn per point
        (default ``(90, 80, 100)``).
    :param rng: optional object with a ``choice`` method (e.g.
        ``random.Random(seed)``) for reproducible draws; defaults to the
        module-level ``random.choice``.
    :return: ``np.ndarray`` of shape ``(h, w)``; all zeros if ``points`` is empty.
    :raises TypeError: if ``size`` is not an ``(h, w)`` tuple, or ``points``
        is not a sequence of numeric 2-D points.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
    if not isinstance(size, tuple) or len(size) != 2:
        raise TypeError("size 输入错误")
    try:
        centres = np.asarray(points, dtype=float)
    except (TypeError, ValueError) as err:
        raise TypeError("points input error") from err
    if centres.size == 0:
        # No centres: nothing to accumulate (previously crashed on points[0]).
        return np.zeros(size)
    if centres.ndim != 2 or centres.shape[1] != 2:
        raise TypeError("points input error")

    h, w = size
    # (x, y) pairs in row-major order, so the flat pdf reshapes back to (h, w).
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))
    grid = np.c_[xx.ravel(), yy.ravel()]

    choose = random.choice if rng is None else rng.choice
    kernel = np.zeros(h * w)
    for centre in centres:
        kernel += multivariate_normal(centre, scaler * choose(cov_choices)).pdf(grid)
    return kernel.reshape(size)
if __name__ == '__main__':
    import time

    # Demo: time heatmap generation for four (x, y) centres on an
    # 800-row x 1200-column canvas, then display it.
    # NOTE(review): (200, 800) has y = 800, i.e. on the bottom edge of an
    # 800-row canvas (rows 0..799) -- confirm that is intended.
    start = time.perf_counter()  # monotonic, high-resolution clock for elapsed time
    img = heatmap_n_point([[40, 30], [400, 70], [200, 800], [80, 300]], (800, 1200))
    print(time.perf_counter() - start)
    plt.imshow(img)
    plt.show()