121 11 个案例掌握 Python 数据可视化--星际探索

星际探索

星空是无数人梦寐以求想了解的一个领域,远古的人们通过肉眼观察星空,并制定了太阴历,指导农业发展。随着现代科技发展,有了更先进的设备进行星空的探索。本实验获取了美国国家航空航天局(NASA)官网发布的地外行星数据,研究及可视化了地外行星各参数、寻找到了一颗类地行星并研究了天体参数的相关关系。
输入并执行魔法命令 %matplotlib inline, 设置全局字号,去除图例边框,去除右侧和顶部坐标轴。

import matplotlib.pyplot as plt
import warnings
%matplotlib inline

plt.rcParams['xtick.labelsize'] = 15
plt.rcParams['ytick.labelsize'] = 15
plt.rcParams['legend.fontsize'] = 15
plt.rcParams['axes.labelsize'] = 15
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['legend.frameon'] = False
warnings.filterwarnings("ignore") # 屏蔽报警

数据准备

本数据集来自 NASA,行星发现是 NASA 的重要工作之一,本数据集搜集了 NASA 官网发布的 4296 颗行星的数据,本数据集字段包括:


image.png

导入数据并查看前 5 行。

import pandas as pd
df = pd.read_excel(
    'https://labfile.oss.aliyuncs.com/courses/3023/NASA_planets.xls')
df.head()

每年发现的行星数

截至 2020 年 10 月 22 日 全球共发现 4296 颗行星,按年聚合并绘制年度行星发现数,并在左上角绘制 NASA 的官方 LOGO 。

import matplotlib.image as imgplt
plt.rcParams['figure.figsize'] = (6, 4)

fig = plt.figure()
# 生成主图
ax = fig.add_axes([0, 0, 1, 1], aspect='auto')
# 生成主图内子图
ax_im = fig.add_axes([0.04, 0.3, 0.5, 0.5], aspect='auto')
# 关闭子图坐标系
ax_im.axis('off')

# 主图绘制年度行星发现数
data = df.groupby(['disc_year']).size()
ax.bar(data.index, data, color='tab:red')

# 子图绘制 NASA 图标
img = imgplt.imread(
    fname='https://labfile.oss.aliyuncs.com/courses/3023/NASA.png')  # 读入 NASA 图标
ax_im.imshow(img)

ax.set_xlabel('Year')
ax.set_ylabel('Discovery Planet Number')
ax.set_title('Discovery Planet Number by Year')

从运行结果可以看出,2005 年以前全球行星发现数是非常少的,经计算总计 173 颗,2014 和 2016 是行星发现成果最多的年份,2016 年度发现行星 1505 颗。

不同机构/项目/计划发现的行星数

对不同机构/项目/计划进行聚合并降序排列,绘制发现行星数目的前 20 。

plt.rcParams['figure.figsize'] = (6, 12)
top_num = 20
data = df.groupby(['disc_facility']).size().sort_values(ascending=False)
data = data[:top_num].sort_values()
plt.barh(data.index, data)
plt.title('Discovery Planet Number by Facility Top %d\n' % top_num, ha='right')

# ----------------- 此处演示了坐标轴微调的相关技巧 -------------------------#
ax = plt.gca()
# 将底部轴调节到顶部
ax.spines['bottom'].set_position(('data', len(data)))
# 左侧轴略微向左调节 90 个单位
ax.spines['left'].set_position(('data', -90))
# 将 x 轴刻度标签移动至坐标轴外侧
ax.tick_params(axis='x', pad=-20)
# 绘制各个机构发现的星星数目
for y, width in zip(range(len(data.index)), data):
    plt.text(width+0.5, y, width, fontsize=15, va='center')

2009 年至 2013 年,开普勒太空望远镜成为有史以来最成功的系外行星发现者。在一片天空中至少找到了 1030 颗系外行星以及超过 4600 颗疑似行星。当机械故障剥夺了该探测器对于恒星的精确定位功能后,地球上的工程师们于 2014 年对其进行了彻底改造,并以 K2 计划命名,后者将在更短的时间内搜寻宇宙的另一片区域。


image.png

行星发现方式的占比

对发现行星的方式进行聚合并降序排列,绘制各种方法发现行星的比例,由于排名靠后的几种方式发现行星数较少,因此不显示其标签。

import numpy as np
plt.rcParams['figure.figsize'] = (6, 12)
data = df.groupby(['discoverymethod']).size().sort_values(ascending=False)

def autopct_fun(x):
    if x < 1:
        return None
    else:
        return '%.2f%%' % (2*x)
    
plt.pie(
    x=0.5*data/np.sum(data),
    labels=list(data.index[:3]) + ['']*(len(data)-3),  # 只显示前 3 个
    autopct=autopct_fun,  # 匿名函数,设置百分比显示方法
    pctdistance=0.7,
    textprops=dict(fontsize=18)
)

plt.ylim(0.5, 1)
plt.title('Discovery Planet Number by Method\n')

行星在宇宙中并不会发光,因此无法直接观察,行星发现的方式多为间接方式。从输出结果可以看出,发现行星主要有以下 3 种方式,其原理如下:


image.png

地球质量与发现行星质量的对比

针对不同的行星质量,绘制比其质量大(或者小)的行星比例,由于行星质量量纲分布跨度较大,因此采用对数坐标。

plt.rcParams['figure.figsize'] = (12, 6)


def gt_Percent(x, array):
    '''计算大于当前行星质量的比例'''
    return 100*np.sum(array >= x)/len(array)


def lt_Percent(x, array):
    '''计算小于当前行星质量的比例,取负值是为了将其映射到 y=0 轴下方'''
    return -1*100*np.sum(array < x)/len(array)


data = df[['pl_bmasse']].copy()
data.dropna(inplace=True)
data['gt_pl_bmasse'] = data['pl_bmasse'].apply(
    lambda x: gt_Percent(x, data['pl_bmasse']))
data['lt_pl_bmasse'] = data['pl_bmasse'].apply(
    lambda x: lt_Percent(x, data['pl_bmasse']))
data = data.sort_values(['gt_pl_bmasse'], ascending=False)

# 分别绘制大于当前质量和小于当前质量的条形图
plt.bar(np.log(data['pl_bmasse']), data['gt_pl_bmasse'], color='tab:blue')
plt.bar(np.log(data['pl_bmasse']), data['lt_pl_bmasse'], color='tab:red')

# 将底部x轴移动到 0 轴
ax = plt.gca()
ax.spines['bottom'].set_position(('data', 0))

# 将对数坐标转化为十进制坐标,此处是通过 np.exp 接口逆向( np.log 的逆向)计算而来
plt.xticks(
    [-4,      -2,        0,       2,      5,       7,      10,     12],
    ['0.02e', '0.14e', 'e', '7.39e', '148e', '1.1ke', '2.2we','16.3we']   
    # 上述单位解释   k:1000  w:10000  e:mass of earth
)

plt.text(2.5, 70, "Planet Heavier Than Earth", fontsize=15, ha='left')
plt.ylabel('Percent Distribution /%')

# 添加地球位置的注释
plt.annotate(
    s='critical point: 96.25%, the percent of heavier than earth ',
    xy=[0, 96.255558], xytext=[3, 50], fontsize=15,
    arrowprops=dict(
        arrowstyle="->",
        lw=3,
        ls='--',
        color='tab:orange',
        connectionstyle="angle,angleA=0,angleB=90,rad=3"
    )
)

plt.title('Percent Distribution of Mass of Planet')

从输出结果可以看出,在已发现的行星中,96.25% 行星的质量大于地球。(图中横坐标小于 e 的红色面积非常小)

全部行星的质量分布

通过 sns.distplot 接口绘制全部行星的质量分布图。

import seaborn as sns
plt.rcParams['figure.figsize'] = (12, 6)

data = df[['pl_bmasse']].copy()
data.dropna(inplace=True)

x = np.log(data['pl_bmasse'])

sns.distplot(x, color='tab:blue', hist_kws=dict(
    alpha=1), kde_kws=dict(color='tab:red', lw=4))
plt.xlabel('log(Planet Mass)')
plt.ylabel('Density')
plt.title('Mass of Planet Distribution')

从输出结果可以看出,所有行星质量分布呈双峰分布,第一个峰在 1.8 左右(此处用了对数单位,表示大约 6 个地球质量),第二个峰在 6.2 左右(大概 493 个地球质量)。

不同发现方式观测到的行星公转周期与其质量的关系

针对不同发现方式发现的行星,绘制各行星的公转周期和质量的关系。

import matplotlib.colors as cs
plt.rcParams['figure.figsize'] = (12, 6)

data = df[['pl_orbper', 'pl_bmasse', 'discoverymethod', 'sy_dist']].copy()
data.dropna(inplace=True)

methods = ['Transit', 'Radial Velocity', 'Microlensing', 'Imaging',
           'Transit Timing Variations', 'Eclipse Timing Variations',
           'Pulsar Timing', 'Orbital Brightness Modulation',
           'Pulsation Timing Variations', 'Disk Kinematics', 'Astrometry']

# 当我们不清楚系列有多少个时,可以尽可能多地增加颜色种类,这样在遍历时不会越界
colors = ['tab:blue', 'tab:red']+[value for value in cs.CSS4_COLORS.values()]
# 散点图支持的所有 marker
markers = 'sov^<>p*hH+xDd|.,1234'
for i, method in enumerate(methods):
    pdata = data.loc[data['discoverymethod'] == method]
    plt.scatter(
        np.log(pdata['pl_orbper']),
        np.log(pdata['pl_bmasse']),
        alpha=0.8,
        label=method,
        s=50,
        marker=markers[i],
        edgecolor=colors[i],
        facecolor='white',
        color=colors[i])

plt.legend(loc="upper left", bbox_to_anchor=(1.0, 0.8))
plt.xlabel('log(Orbital Period [days])')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Orbital Period of Planet by Discoverymethod')

从输出结果可以看出:径向速度(Radial Velocity)方法发现的行星在公转周期和质量上分布更宽,而凌日(Transit)似乎只能发现公转周期相对较短的行星,这是因为两种方法的原理差异造成的。对于公转周期很长的行星,其运行到恒星和观察者之间的时间也较长,因此凌日发现此类行星会相对较少。而径向速度与其说是在发现行星,不如说是在观察恒星,由于恒星自身发光,因此其观察机会更多,发现各类行星的可能性更大。

不同发现方式观测到的行星距离与其质量的关系

针对不同发现方式发现的行星,绘制各行星的距离和质量的关系。

plt.rcParams['figure.figsize'] = (12, 6)

data = df[['pl_orbper', 'pl_bmasse', 'discoverymethod', 'sy_dist']].copy()
data.dropna(inplace=True)

methods = ['Transit', 'Radial Velocity', 'Microlensing', 'Imaging',
           'Transit Timing Variations', 'Eclipse Timing Variations',
           'Pulsar Timing', 'Orbital Brightness Modulation',
           'Pulsation Timing Variations', 'Disk Kinematics', 'Astrometry']

colors = ['tab:blue', 'tab:red']+[value for value in cs.CSS4_COLORS.values()]
markers = 'sov^<>p*hH+xDd|.,1234'
for i, method in enumerate(methods):
    pdata = data.loc[data['discoverymethod'] == method]
    plt.scatter(
        np.log(pdata['sy_dist']),
        np.log(pdata['pl_bmasse']),
        label=method,
        s=50,
        marker=markers[i],
        edgecolors=colors[i],
        color='white',
    )

plt.legend(loc="upper left", bbox_to_anchor=(1.0, 0.8))
plt.xlabel('log(Distance)')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Distance of Planet by Discoverymethod')

从输出结果可以看出,凌日和径向速度对距离较为敏感,远距离的行星大多是通过凌日发现的,而近距离的行星大多数通过径向速度发现的。原因是:近距离的行星其引力对恒星造成的摆动更为明显,因此更容易观察;当距离较远时,引力作用变弱,摆动效应减弱,因此很难借助此方法观察到行星。同时,可以观察到当行星质量更大时,其距离分布相对较宽,这是因为虽然相对恒星的距离变长了,但是由于行星质量的增加,相对引力也同步增加,恒星摆动效应会变得明显。

行星质量与半径的关系

将所有行星的质量和半径对数化处理,绘制其分布并拟合其分布。
由于:


image.png

因此,从原理上质量对数与半径对数应该是线性关系,且斜率为定值 3 ,截距的大小与密度相关。

plt.rcParams['figure.figsize'] = (12, 6)

data = df[['pl_rade', 'pl_bmasse', 'discoverymethod', 'sy_dist']].copy()
data.dropna(inplace=True)

x, y = np.log(data['pl_rade']), np.log(data['pl_bmasse'])


plt.scatter(x, y,
            color='white', edgecolors='tab:red', label='origin')

# 拟合原始数据质量与半径的关系
fit_xy = np.polyfit(x, y, 1)
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = np.polyval(fit_xy, x_fitvalue)
plt.plot(x_fitvalue, y_fitvalue, lw=3, label='fit-line')
print(fit_xy)
# 根据类似地球的密度的直线
# 由于数据集中的质量和半径都是相对地球的,因此如果行星密度与地球类似,则截距项中密度相对地球取值应该为1
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = 3*x_fitvalue+np.log(4/3*np.pi*1.0)   # 类地星球,密度取 1
plt.plot(x_fitvalue, y_fitvalue, lw=3, ls='--', label='earth')


plt.legend()

plt.xlabel('log(Planet Radius)')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Radius of Planet Distribution')

从输出结果可以看出:行星质量和行星半径在对数变换下,具有较好的线性关系。输出 fix_xy 数值可知,其关系可以拟合出如下公式:


image.png

拟合出曲线对应的行星平均密度为:


image.png

说明大多数行星的密度都是低于地球的,大约为地球的 0.244179 。

恒星质量与半径的关系

同样的方式绘制恒星质量与半径的关系。

plt.rcParams['figure.figsize'] = (12, 6)

data = df[['st_rad', 'st_mass']].copy()
data.dropna(inplace=True)

x, y = np.log(data['st_rad']), np.log(data['st_mass'])

plt.scatter(x, y,
            color='white', edgecolors='tab:red', label='origin')


fit_xy = np.polyfit(x, y, 2)
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = np.polyval(fit_xy, x_fitvalue)
plt.plot(x_fitvalue, y_fitvalue, lw=3, label='fit-line')


plt.legend()
plt.xlabel('log(Stellar Radius)')
plt.ylabel('log(Stellar Mass)')
plt.text(-3, -4, '$\log(Stellar Mass)=A \\times log(Stellar Radius)^2+B \\times log(Stellar Radius)+C$', fontsize=16)
plt.title('Mass and Radius of Stellar Distribution')

从输出结果可以看出,恒星与行星的规律不同,其质量与半径在对数下呈二次曲线关系,其关系符合以下公式:


image.png

恒星表面重力加速度与半径的关系

同样的方式研究恒星表面重力加速度与半径的关系。

plt.rcParams['figure.figsize'] = (12, 6)

data = df[['st_rad', 'st_logg']].copy()
data.dropna(inplace=True)

x, y = np.log(data['st_rad']), np.array(data['st_logg'])   # 数据集中的重力加速度为对数处理后的

plt.scatter(x, y,
            color='white', edgecolors='tab:red', label='origin')

# 拟合曲线
fit_xy = np.polyfit(x, y, 1)
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = np.polyval(fit_xy, x_fitvalue)
plt.plot(x_fitvalue, y_fitvalue, lw=3, label='fit-line')

plt.legend()
plt.xlabel('log(Stellar Radius)')
plt.ylabel('Stellar Surface Gravity')

plt.title('Stellar Surface Gravity and Radius of Stellar Distribution')

从输出结果可以看出,恒星表面对数重力加速度与其对数半径呈现较好的线性关系:


image.png

行星与恒星特征分布及相关性探索

以上我们分别探索了各变量的分布和部分变量的相关关系,当数据较多时,可以通过 pd.plotting.scatter_matrix 接口,直接绘制各变量的分布和任意两个变量的散点图分布,对于数据的初步探索,该接口可以让我们迅速对数据全貌有较为清晰的认识。

data = df[['pl_orbper', 'pl_rade', 'pl_bmasse',
           'st_teff', 'st_rad', 'st_mass', 'st_logg']].copy()
data.dropna(inplace=True)

data = data.applymap(lambda x: np.log(x))

fig, ax = plt.subplots(1, 1, figsize=(15, 9))

axes = pd.plotting.scatter_matrix(
    frame=data,
    marker='.',
    s=8,
    ax=ax,
    alpha=0.6,
    edgecolor='tab:blue',
    cmap=plt.cm.RdBu,
    hist_kwds={'color': 'tab:red'},  # 直方图的参数
)
for ax in axes.ravel():
    ax.spines['top'].set_visible(True)
    ax.spines['right'].set_visible(True)

fig.suptitle(
    'Feature Distribution and Relation of Planet and Stellar', fontsize=20)

fig.align_xlabels()
fig.align_ylabels()

寻找类地行星

通过行星的半径和质量,恒星的半径和质量,以及行星的公转周期等指标与地球的相似性,寻找诸多行星中最类似地球的行星。

fig, ax = plt.subplots(1, 1)
# 将子图背景设置为黑色
ax.set(facecolor='k')

data = df[['pl_orbper', 'pl_rade', 'pl_bmasse', 'st_rad', 'st_mass']].copy()
data.dropna(inplace=True)

# 对所有特征取对数
data = data.applymap(lambda x: np.log(x))

# 构造新特征,地球相似性。由于所有特征都取了对数,且行星直径和质量单位均为相对地球的,因此其对数值为 0,   np.log(1)=0
# 将行星的距离和质量相乘,越接近 0 表示其与地球越相似
data['Earth Similar'] = data['pl_rade']*data['pl_bmasse']
# 同样恒星的单位为相对太阳,因此其对数值为 0
# 将恒星的距离和质量相乘,越接近 0 表示其与太阳越相似
data['Stellar Similar'] = data['st_rad']*data['st_mass']

# 公转周期调整函数,此函数描述了行星公转周期与地球的相似程度,越相似其值越大


def peroid_adjust(array, adjust_value):
    # 所有数据与adjust_value做差,adjust_value 为地球的公转周期数
    array = np.abs(array-adjust_value)
    # 差距最小的则其倒数会最大
    return 1.0/array


# 由于地球的公转周期为 365 天,且全部数据都进行了对数处理,因此 adjust_value 为 np.log(365)
data['peroid_similar'] = peroid_adjust(data['pl_orbper'], np.log(365))

plt.scatter(
    x=data['Earth Similar'],
    y=data['Stellar Similar'],
    c=data['pl_orbper'],  # marker 颜色
    s=100*data['peroid_similar'],  # marker 大小
    alpha=0.6,
    cmap=plt.cm.RdYlBu,
)

# 由于越接近 0 的值,表示其与地球及太阳越类似,因此将x y 轴范围缩小至 [-0.1 ,1 ] 区间
plt.xlim(-0.1, 1)
plt.ylim(-0.1, 1)
plt.xlabel('Earth Similar')
plt.ylabel('Stellar Similar')
plt.title('Find the Terrestrial Planet')

# 在 [-0.1 ,1 ] 区间寻找公转周期相似度最大的恒星索引号
target = data.loc[
    (data['Earth Similar'] > -0.1) &
    (data['Earth Similar'] < 1) &
    (data['Stellar Similar'] > -0.1) &
    (data['Stellar Similar'] < 1)
]
target_index = target.sort_values(['peroid_similar'], ascending=False).index[0]
# 打印该行星的名称
print(df.iloc[target_index]['pl_name'])

# 将行星的名称以注释的形式添加至图中
plt.annotate(
    s=df.iloc[target_index]['pl_name'],
    xy=[0.6, 0.04],     # 可以先运行并生成散点图后再添加注释代码,xy的位置为观察散点图所得
    xytext=[0.7, 0.4],
    fontsize=15,
    color='white',
    arrowprops=dict(
        arrowstyle="->",
        lw=3,
        ls='-',
        color='white',
    )
)

从输出结果可以看出,在 0.6 附近的位置出现了一个最大的圆圈,那就是我们找到的类地行星 Kepler - 452 b ,让我们了解一下这颗行星:

df.iloc[target_index]

数据显示,Kepler - 452 b 行星公转周期为 384.84 天,半径为 1.63 地球半径,质量为 3.29 地球质量;它的恒星为 Kepler - 452 半径为太阳的 1.11 倍,质量为 1.04 倍,恒星方面数据与太阳相似度极高。
以下内容来自百度百科。 开普勒452b(Kepler 452b),是美国国家航空航天局(NASA)发现的外行星, 直径是地球的 1.6 倍,地球相似指数( ESI )为 0.83,距离地球1400光年,位于为天鹅座。
2015 年 7 月 24 日 0:00,美国国家航空航天局 NASA 举办媒体电话会议宣称,他们在天鹅座发现了一颗与地球相似指数达到 0.98 的类地行星开普勒 - 452 b。这个类地行星距离地球 1400 光年,绕着一颗与太阳非常相似的恒星运行。开普勒 452 b 到恒星的距离,跟地球到太阳的距离相同。NASA 称,由于缺乏关键数据,现在不能说 Kepler - 452 b 究竟是不是“另外一个地球”,只能说它是“迄今最接近另外一个地球”的系外行星。

地球与类地行星的相对位置标记

在银河系经纬度坐标下绘制所有行星,并标记地球和 Kepler - 452 b 行星的位置。

fig, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.set(fc='black')

# Kepler - 452 b 行星数据索引号
Kepler_index = 3115

# 绘制所有行星
plt.scatter(
    x='glon',
    y='glat',
    s=df['pl_bmasse']*0.002,
    c='glon',
    cmap=plt.cm.RdYlBu_r,
    data=df
)
# 标记 Kepler - 452 b 行星位置
plt.scatter(
    x='glon',
    y='glat',
    s=200,
    c='tab:red',
    data=df.iloc[Kepler_index]
)
# 标记地球位置,以距离地球最近的比邻星 b 的坐标(距离地球 4.22 光年),作为地球坐标  313.9399,-1.927165

plt.scatter(
    x=313.9399,
    y=-1.927165,
    s=200,
    c='tab:blue',
)

# 标注地球
plt.annotate(
    s='Earth',
    xy=[313.9399, -1.927165],
    xytext=[210, 75],
    fontsize=15,
    color='white',
    arrowprops=dict(
        arrowstyle="->",
        lw=3,
        ls='-',
        color='white',
    )
)

# Kepler-452 b
plt.annotate(
    s='Kepler-452 b',
    xy=[df.iloc[Kepler_index]['glon'], df.iloc[Kepler_index]['glat']],
    xytext=[0, 75],
    fontsize=15,
    color='white',
    arrowprops=dict(
        arrowstyle="->",
        lw=3,
        ls='-',
        color='white',
    )
)

plt.xlabel('Galactic Longitude /deg')
plt.ylabel('Galactic Latitude /deg')
plt.title('Where Does the Kepler-452 b Locate ?')

类地行星,是人类寄希望移民的第二故乡,但即使最近的 Kepler-452 b ,也与地球相聚 1400 光年。

行星聚类

以下通过行星的公转周期和质量两个特征将所有行星聚为两类,即通过训练获得两个簇心。
定义函数-计算距离
聚类距离采用欧式距离:

image.png

def distance_L(fea_1, fea_2, center_x, center_y, data):
    x_x = np.power(data[fea_1]-center_x, 2)  # (x-cx)^2
    y_y = np.power(data[fea_2]-center_y, 2)  # (y-cy)^2
    return np.array(np.sqrt(x_x+y_y))

定义函数-训练簇心
训练簇心的原理是:根据上一次的簇心计算所有点与所有簇心的距离,任一点的分类以其距离最近的簇心确定。依此原理计算出所有点的分类后,对每个分类计算新的簇心。

def train_centers(fea_1, fea_2, centers, data):
    '''函数输入训练数据和上一次训练获得的簇心,返回下一次簇心'''

    # 上一个迭代次获得的簇心
    (center1_x, center1_y), (center2_x, center2_y) = centers

    # 分别计算与簇心1,和簇心2的欧式距离
    distance_1 = distance_L(fea_1, fea_2, center1_x, center1_y, data)
    distance_2 = distance_L(fea_1, fea_2, center2_x, center2_y, data)

    # 与 center1 欧氏距离较近的数据点,他们的均值为新簇心
    center1_x = data.loc[distance_1 < distance_2, fea_1].mean()
    center1_y = data.loc[distance_1 < distance_2, fea_2].mean()

    # 与 center2 欧氏距离较近的数据点,他们的均值为新簇心
    center2_x = data.loc[distance_1 > distance_2, fea_1].mean()
    center2_y = data.loc[distance_1 > distance_2, fea_2].mean()

    return (center1_x, center1_y), (center2_x, center2_y)


# 测试,以随机生成的簇心为上一次输入,测试函数训练结果
train_centers('pl_bmasse', 'pl_orbper', centers=((3, 3), (8, 8)), data=data)
# ((3, 3), (8, 8)) 经一次训练后变为
# ((2.36814897560919, 2.3440864259971432),
# (6.895476126190156, 6.774695415955819))

定义函数预测分类
根据训练得到的簇心,预测输入新的数据特征的分类。

def pred_centers(x, y, centers):
    (center1_x, center1_y), (center2_x, center2_y) = centers
    # 需要预测的x,y与簇心的距离
    d1 = np.sqrt(np.power(x-center1_x, 2)+np.power(y-center1_y, 2))
    d2 = np.sqrt(np.power(x-center2_x, 2)+np.power(y-center2_y, 2))
    # 若距离center1簇心的距离小于center2,则为0,否则为1,此处将bool直接转为int
    pred_classes = np.array(d1 < d2, dtype='int')
    return pred_classes


# 测试,假定簇心为 [[3, 6], [9, 5]] 即(3,9)和(6,5),测试 x,y是否能以此簇心为基础进行数据分类
x = np.array([1, 1, 1, 5, 5, 5, 10, 10, 10])
y = np.array([1, 5, 10, 1, 5, 10, 1, 5, 10])
pred_centers(x, y, [[3, 6], [9, 5]])

开始训练
随机生成一个簇心,并训练 15 次。

np.random.seed(122)
fea_1, fea_2 = 'pl_bmasse', 'pl_orbper'
epochs = 15
centers = np.random.randint(low=3, high=10, size=(2, 2))  # 随机初始化簇心
centers_all = [centers]  # 保存所有迭代次数的簇心
for i in range(epochs):
    centers = train_centers(fea_1, fea_2, centers, data=data)
    centers_all.append(centers)

# 查看随机生成的簇心及前两轮训练得到的簇心
centers_all[:2]

绘制聚类结果
以最后一次训练得到的簇心为基础,进行行星的分类,并以等高面的形式绘制各类的边界。

plt.rcParams['figure.figsize'] = (10, 6)
epoch = epochs-1

# 根据簇心预测行星所属分类
classes = pred_centers(
    data['pl_orbper'], data['pl_bmasse'], centers_all[epoch])

# 绘制各类别的数据图
plt.scatter('pl_orbper', 'pl_bmasse',
            data=data.loc[classes == 0], color='white', edgecolors='tab:red', label='class_0')
plt.scatter('pl_orbper', 'pl_bmasse',
            data=data.loc[classes == 1], color='white', edgecolors='tab:blue', label='class_1')

# 绘制簇心
(center1_x, center1_y), (center2_x, center2_y) = centers_all[epoch]
plt.scatter(center1_x, center1_y, marker='^', s=300, c='r')
plt.scatter(center2_x, center2_y, marker='v', s=300, c='r')

# 绘制分类边界
ax = plt.gca()
# 获得子图的 x y 轴范围
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
# 在 x y 轴范围区间中生成等长分布的数据点
x = np.linspace(xmin, xmax, 1000)   # shape 1000*1
y = np.linspace(ymin, ymax, 1000)   # shape 1000*1
# 通过 np.meshgrid 生成数据网格
X, Y = np.meshgrid(x, y)    # shape 1000*1000
# 对网格中的每个点进行分类结果的计算
Z = pred_centers(X, Y, centers_all[epoch])    # shape 1000*1000

# 等高面绘图接口
plt.contourf(X, Y, Z, cmap=plt.cm.RdBu, alpha=0.2)

# 进一步调节 x 轴范围
plt.xlim(-2.5, 12)

plt.legend()
plt.xlabel('log(Orbital Period [days])')
plt.ylabel('log(Planet Mass)')
plt.title('Planet Cluster by Mass and Orbital Period after Train Epoch %d' %
          epoch, fontsize=16)

从运行结果可以看出,所有行星被分成了两类。并通过上三角和下三角标注了每个类别的簇心位置。
聚类前
以下输出了聚类前原始数据绘制的图像。

plt.rcParams['figure.figsize'] = (10, 6)

data = df[['pl_bmasse', 'pl_orbper']].copy()
data.dropna(inplace=True)

data = data.apply(lambda x: np.log(x))

plt.scatter('pl_orbper', 'pl_bmasse', data=data,
            color='white', edgecolors='tab:red')
plt.xlim(-2.5, 12)
plt.xlabel('log(Orbital Period [days])')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Orbital Period of Planet Distribution')
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,525评论 6 507
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,203评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,862评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,728评论 1 294
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,743评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,590评论 1 305
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,330评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,244评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,693评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,885评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,001评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,723评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,343评论 3 330
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,919评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,042评论 1 270
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,191评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,955评论 2 355

推荐阅读更多精彩内容