星际探索
星空是无数人梦寐以求想了解的一个领域,远古的人们通过肉眼观察星空,并制定了太阴历,指导农业发展。随着现代科技发展,有了更先进的设备进行星空的探索。本实验获取了美国国家航空航天局(NASA)官网发布的地外行星数据,研究及可视化了地外行星各参数、寻找到了一颗类地行星并研究了天体参数的相关关系。
输入并执行魔法命令 %matplotlib inline, 设置全局字号,去除图例边框,去除右侧和顶部坐标轴。
import matplotlib.pyplot as plt
import warnings
%matplotlib inline
plt.rcParams['xtick.labelsize'] = 15
plt.rcParams['ytick.labelsize'] = 15
plt.rcParams['legend.fontsize'] = 15
plt.rcParams['axes.labelsize'] = 15
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['legend.frameon'] = False
warnings.filterwarnings("ignore") # 屏蔽报警
数据准备
本数据集来自 NASA,行星发现是 NASA 的重要工作之一,本数据集搜集了 NASA 官网发布的 4296 颗行星的数据,本数据集字段包括:
导入数据并查看前 5 行。
import pandas as pd
df = pd.read_excel(
'https://labfile.oss.aliyuncs.com/courses/3023/NASA_planets.xls')
df.head()
每年发现的行星数
截至 2020 年 10 月 22 日 全球共发现 4296 颗行星,按年聚合并绘制年度行星发现数,并在左上角绘制 NASA 的官方 LOGO 。
import matplotlib.image as imgplt
plt.rcParams['figure.figsize'] = (6, 4)
fig = plt.figure()
# 生成主图
ax = fig.add_axes([0, 0, 1, 1], aspect='auto')
# 生成主图内子图
ax_im = fig.add_axes([0.04, 0.3, 0.5, 0.5], aspect='auto')
# 关闭子图坐标系
ax_im.axis('off')
# 主图绘制年度行星发现数
data = df.groupby(['disc_year']).size()
ax.bar(data.index, data, color='tab:red')
# 子图绘制 NASA 图标
img = imgplt.imread(
fname='https://labfile.oss.aliyuncs.com/courses/3023/NASA.png') # 读入 NASA 图标
ax_im.imshow(img)
ax.set_xlabel('Year')
ax.set_ylabel('Discovery Planet Number')
ax.set_title('Discovery Planet Number by Year')
从运行结果可以看出,2005 年以前全球行星发现数是非常少的,经计算总计 173 颗,2014 和 2016 是行星发现成果最多的年份,2016 年度发现行星 1505 颗。
不同机构/项目/计划发现的行星数
对不同机构/项目/计划进行聚合并降序排列,绘制发现行星数目的前 20 。
plt.rcParams['figure.figsize'] = (6, 12)
top_num = 20
data = df.groupby(['disc_facility']).size().sort_values(ascending=False)
data = data[:top_num].sort_values()
plt.barh(data.index, data)
plt.title('Discovery Planet Number by Facility Top %d\n' % top_num, ha='right')
# ----------------- 此处演示了坐标轴微调的相关技巧 -------------------------#
ax = plt.gca()
# 将底部轴调节到顶部
ax.spines['bottom'].set_position(('data', len(data)))
# 左侧轴略微向左调节 90 个单位
ax.spines['left'].set_position(('data', -90))
# 将 x 轴刻度标签移动至坐标轴外侧
ax.tick_params(axis='x', pad=-20)
# 绘制各个机构发现的星星数目
for y, width in zip(range(len(data.index)), data):
plt.text(width+0.5, y, width, fontsize=15, va='center')
2009 年至 2013 年,开普勒太空望远镜成为有史以来最成功的系外行星发现者。在一片天空中至少找到了 1030 颗系外行星以及超过 4600 颗疑似行星。当机械故障剥夺了该探测器对于恒星的精确定位功能后,地球上的工程师们于 2014 年对其进行了彻底改造,并以 K2 计划命名,后者将在更短的时间内搜寻宇宙的另一片区域。
行星发现方式的占比
对发现行星的方式进行聚合并降序排列,绘制各种方法发现行星的比例,由于排名靠后的几种方式发现行星数较少,因此不显示其标签。
import numpy as np
plt.rcParams['figure.figsize'] = (6, 12)
data = df.groupby(['discoverymethod']).size().sort_values(ascending=False)
def autopct_fun(x):
if x < 1:
return None
else:
return '%.2f%%' % (2*x)
plt.pie(
x=0.5*data/np.sum(data),
labels=list(data.index[:3]) + ['']*(len(data)-3), # 只显示前 3 个
autopct=autopct_fun, # 匿名函数,设置百分比显示方法
pctdistance=0.7,
textprops=dict(fontsize=18)
)
plt.ylim(0.5, 1)
plt.title('Discovery Planet Number by Method\n')
行星在宇宙中并不会发光,因此无法直接观察,行星发现的方式多为间接方式。从输出结果可以看出,发现行星主要有以下 3 种方式,其原理如下:
地球质量与发现行星质量的对比
针对不同的行星质量,绘制比其质量大(或者小)的行星比例,由于行星质量量纲分布跨度较大,因此采用对数坐标。
plt.rcParams['figure.figsize'] = (12, 6)
def gt_Percent(x, array):
'''计算大于当前行星质量的比例'''
return 100*np.sum(array >= x)/len(array)
def lt_Percent(x, array):
'''计算小于当前行星质量的比例,取负值是为了将其映射到 y=0 轴下方'''
return -1*100*np.sum(array < x)/len(array)
data = df[['pl_bmasse']].copy()
data.dropna(inplace=True)
data['gt_pl_bmasse'] = data['pl_bmasse'].apply(
lambda x: gt_Percent(x, data['pl_bmasse']))
data['lt_pl_bmasse'] = data['pl_bmasse'].apply(
lambda x: lt_Percent(x, data['pl_bmasse']))
data = data.sort_values(['gt_pl_bmasse'], ascending=False)
# 分别绘制大于当前质量和小于当前质量的条形图
plt.bar(np.log(data['pl_bmasse']), data['gt_pl_bmasse'], color='tab:blue')
plt.bar(np.log(data['pl_bmasse']), data['lt_pl_bmasse'], color='tab:red')
# 将底部x轴移动到 0 轴
ax = plt.gca()
ax.spines['bottom'].set_position(('data', 0))
# 将对数坐标转化为十进制坐标,此处是通过 np.exp 接口逆向( np.log 的逆向)计算而来
plt.xticks(
[-4, -2, 0, 2, 5, 7, 10, 12],
['0.02e', '0.14e', 'e', '7.39e', '148e', '1.1ke', '2.2we','16.3we']
# 上述单位解释 k:1000 w:10000 e:mass of earth
)
plt.text(2.5, 70, "Planet Heavier Than Earth", fontsize=15, ha='left')
plt.ylabel('Percent Distribution /%')
# 添加地球位置的注释
plt.annotate(
s='critical point: 96.25%, the percent of heavier than earth ',
xy=[0, 96.255558], xytext=[3, 50], fontsize=15,
arrowprops=dict(
arrowstyle="->",
lw=3,
ls='--',
color='tab:orange',
connectionstyle="angle,angleA=0,angleB=90,rad=3"
)
)
plt.title('Percent Distribution of Mass of Planet')
从输出结果可以看出,在已发现的行星中,96.25% 行星的质量大于地球。(图中横坐标小于 e 的红色面积非常小)
全部行星的质量分布
通过 sns.distplot 接口绘制全部行星的质量分布图。
import seaborn as sns
plt.rcParams['figure.figsize'] = (12, 6)
data = df[['pl_bmasse']].copy()
data.dropna(inplace=True)
x = np.log(data['pl_bmasse'])
sns.distplot(x, color='tab:blue', hist_kws=dict(
alpha=1), kde_kws=dict(color='tab:red', lw=4))
plt.xlabel('log(Planet Mass)')
plt.ylabel('Density')
plt.title('Mass of Planet Distribution')
从输出结果可以看出,所有行星质量分布呈双峰分布,第一个峰在 1.8 左右(此处用了对数单位,表示大约 6 个地球质量),第二个峰在 6.2 左右(大概 493 个地球质量)。
不同发现方式观测到的行星公转周期与其质量的关系
针对不同发现方式发现的行星,绘制各行星的公转周期和质量的关系。
import matplotlib.colors as cs
plt.rcParams['figure.figsize'] = (12, 6)
data = df[['pl_orbper', 'pl_bmasse', 'discoverymethod', 'sy_dist']].copy()
data.dropna(inplace=True)
methods = ['Transit', 'Radial Velocity', 'Microlensing', 'Imaging',
'Transit Timing Variations', 'Eclipse Timing Variations',
'Pulsar Timing', 'Orbital Brightness Modulation',
'Pulsation Timing Variations', 'Disk Kinematics', 'Astrometry']
# 当我们不清楚系列有多少个时,可以尽可能多地增加颜色种类,这样在遍历时不会越界
colors = ['tab:blue', 'tab:red']+[value for value in cs.CSS4_COLORS.values()]
# 散点图支持的所有 marker
markers = 'sov^<>p*hH+xDd|.,1234'
for i, method in enumerate(methods):
pdata = data.loc[data['discoverymethod'] == method]
plt.scatter(
np.log(pdata['pl_orbper']),
np.log(pdata['pl_bmasse']),
alpha=0.8,
label=method,
s=50,
marker=markers[i],
edgecolor=colors[i],
facecolor='white',
color=colors[i])
plt.legend(loc="upper left", bbox_to_anchor=(1.0, 0.8))
plt.xlabel('log(Orbital Period [days])')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Orbital Period of Planet by Discoverymethod')
从输出结果可以看出:径向速度(Radial Velocity)方法发现的行星在公转周期和质量上分布更宽,而凌日(Transit)似乎只能发现公转周期相对较短的行星,这是因为两种方法的原理差异造成的。对于公转周期很长的行星,其运行到恒星和观察者之间的时间也较长,因此凌日发现此类行星会相对较少。而径向速度与其说是在发现行星,不如说是在观察恒星,由于恒星自身发光,因此其观察机会更多,发现各类行星的可能性更大。
不同发现方式观测到的行星距离与其质量的关系
针对不同发现方式发现的行星,绘制各行星的距离和质量的关系。
plt.rcParams['figure.figsize'] = (12, 6)
data = df[['pl_orbper', 'pl_bmasse', 'discoverymethod', 'sy_dist']].copy()
data.dropna(inplace=True)
methods = ['Transit', 'Radial Velocity', 'Microlensing', 'Imaging',
'Transit Timing Variations', 'Eclipse Timing Variations',
'Pulsar Timing', 'Orbital Brightness Modulation',
'Pulsation Timing Variations', 'Disk Kinematics', 'Astrometry']
colors = ['tab:blue', 'tab:red']+[value for value in cs.CSS4_COLORS.values()]
markers = 'sov^<>p*hH+xDd|.,1234'
for i, method in enumerate(methods):
pdata = data.loc[data['discoverymethod'] == method]
plt.scatter(
np.log(pdata['sy_dist']),
np.log(pdata['pl_bmasse']),
label=method,
s=50,
marker=markers[i],
edgecolors=colors[i],
color='white',
)
plt.legend(loc="upper left", bbox_to_anchor=(1.0, 0.8))
plt.xlabel('log(Distance)')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Distance of Planet by Discoverymethod')
从输出结果可以看出,凌日和径向速度对距离较为敏感,远距离的行星大多是通过凌日发现的,而近距离的行星大多数通过径向速度发现的。原因是:近距离的行星其引力对恒星造成的摆动更为明显,因此更容易观察;当距离较远时,引力作用变弱,摆动效应减弱,因此很难借助此方法观察到行星。同时,可以观察到当行星质量更大时,其距离分布相对较宽,这是因为虽然相对恒星的距离变长了,但是由于行星质量的增加,相对引力也同步增加,恒星摆动效应会变得明显。
行星质量与半径的关系
将所有行星的质量和半径对数化处理,绘制其分布并拟合其分布。
由于:
因此,从原理上质量对数与半径对数应该是线性关系,且斜率为定值 3 ,截距的大小与密度相关。
plt.rcParams['figure.figsize'] = (12, 6)
data = df[['pl_rade', 'pl_bmasse', 'discoverymethod', 'sy_dist']].copy()
data.dropna(inplace=True)
x, y = np.log(data['pl_rade']), np.log(data['pl_bmasse'])
plt.scatter(x, y,
color='white', edgecolors='tab:red', label='origin')
# 拟合原始数据质量与半径的关系
fit_xy = np.polyfit(x, y, 1)
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = np.polyval(fit_xy, x_fitvalue)
plt.plot(x_fitvalue, y_fitvalue, lw=3, label='fit-line')
print(fit_xy)
# 根据类似地球的密度的直线
# 由于数据集中的质量和半径都是相对地球的,因此如果行星密度与地球类似,则截距项中密度相对地球取值应该为1
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = 3*x_fitvalue+np.log(4/3*np.pi*1.0) # 类地星球,密度取 1
plt.plot(x_fitvalue, y_fitvalue, lw=3, ls='--', label='earth')
plt.legend()
plt.xlabel('log(Planet Radius)')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Radius of Planet Distribution')
从输出结果可以看出:行星质量和行星半径在对数变换下,具有较好的线性关系。输出 fix_xy 数值可知,其关系可以拟合出如下公式:
拟合出曲线对应的行星平均密度为:
说明大多数行星的密度都是低于地球的,大约为地球的 0.244179 。
恒星质量与半径的关系
同样的方式绘制恒星质量与半径的关系。
plt.rcParams['figure.figsize'] = (12, 6)
data = df[['st_rad', 'st_mass']].copy()
data.dropna(inplace=True)
x, y = np.log(data['st_rad']), np.log(data['st_mass'])
plt.scatter(x, y,
color='white', edgecolors='tab:red', label='origin')
fit_xy = np.polyfit(x, y, 2)
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = np.polyval(fit_xy, x_fitvalue)
plt.plot(x_fitvalue, y_fitvalue, lw=3, label='fit-line')
plt.legend()
plt.xlabel('log(Stellar Radius)')
plt.ylabel('log(Stellar Mass)')
plt.text(-3, -4, '$\log(Stellar Mass)=A \\times log(Stellar Radius)^2+B \\times log(Stellar Radius)+C$', fontsize=16)
plt.title('Mass and Radius of Stellar Distribution')
从输出结果可以看出,恒星与行星的规律不同,其质量与半径在对数下呈二次曲线关系,其关系符合以下公式:
恒星表面重力加速度与半径的关系
同样的方式研究恒星表面重力加速度与半径的关系。
plt.rcParams['figure.figsize'] = (12, 6)
data = df[['st_rad', 'st_logg']].copy()
data.dropna(inplace=True)
x, y = np.log(data['st_rad']), np.array(data['st_logg']) # 数据集中的重力加速度为对数处理后的
plt.scatter(x, y,
color='white', edgecolors='tab:red', label='origin')
# 拟合曲线
fit_xy = np.polyfit(x, y, 1)
x_fitvalue = np.linspace(x.min(), x.max(), 100)
y_fitvalue = np.polyval(fit_xy, x_fitvalue)
plt.plot(x_fitvalue, y_fitvalue, lw=3, label='fit-line')
plt.legend()
plt.xlabel('log(Stellar Radius)')
plt.ylabel('Stellar Surface Gravity')
plt.title('Stellar Surface Gravity and Radius of Stellar Distribution')
从输出结果可以看出,恒星表面对数重力加速度与其对数半径呈现较好的线性关系:
行星与恒星特征分布及相关性探索
以上我们分别探索了各变量的分布和部分变量的相关关系,当数据较多时,可以通过 pd.plotting.scatter_matrix 接口,直接绘制各变量的分布和任意两个变量的散点图分布,对于数据的初步探索,该接口可以让我们迅速对数据全貌有较为清晰的认识。
data = df[['pl_orbper', 'pl_rade', 'pl_bmasse',
'st_teff', 'st_rad', 'st_mass', 'st_logg']].copy()
data.dropna(inplace=True)
data = data.applymap(lambda x: np.log(x))
fig, ax = plt.subplots(1, 1, figsize=(15, 9))
axes = pd.plotting.scatter_matrix(
frame=data,
marker='.',
s=8,
ax=ax,
alpha=0.6,
edgecolor='tab:blue',
cmap=plt.cm.RdBu,
hist_kwds={'color': 'tab:red'}, # 直方图的参数
)
for ax in axes.ravel():
ax.spines['top'].set_visible(True)
ax.spines['right'].set_visible(True)
fig.suptitle(
'Feature Distribution and Relation of Planet and Stellar', fontsize=20)
fig.align_xlabels()
fig.align_ylabels()
寻找类地行星
通过行星的半径和质量,恒星的半径和质量,以及行星的公转周期等指标与地球的相似性,寻找诸多行星中最类似地球的行星。
fig, ax = plt.subplots(1, 1)
# 将子图背景设置为黑色
ax.set(facecolor='k')
data = df[['pl_orbper', 'pl_rade', 'pl_bmasse', 'st_rad', 'st_mass']].copy()
data.dropna(inplace=True)
# 对所有特征取对数
data = data.applymap(lambda x: np.log(x))
# 构造新特征,地球相似性。由于所有特征都取了对数,且行星直径和质量单位均为相对地球的,因此其对数值为 0, np.log(1)=0
# 将行星的距离和质量相乘,越接近 0 表示其与地球越相似
data['Earth Similar'] = data['pl_rade']*data['pl_bmasse']
# 同样恒星的单位为相对太阳,因此其对数值为 0
# 将恒星的距离和质量相乘,越接近 0 表示其与太阳越相似
data['Stellar Similar'] = data['st_rad']*data['st_mass']
# 公转周期调整函数,此函数描述了行星公转周期与地球的相似程度,越相似其值越大
def peroid_adjust(array, adjust_value):
# 所有数据与adjust_value做差,adjust_value 为地球的公转周期数
array = np.abs(array-adjust_value)
# 差距最小的则其倒数会最大
return 1.0/array
# 由于地球的公转周期为 365 天,且全部数据都进行了对数处理,因此 adjust_value 为 np.log(365)
data['peroid_similar'] = peroid_adjust(data['pl_orbper'], np.log(365))
plt.scatter(
x=data['Earth Similar'],
y=data['Stellar Similar'],
c=data['pl_orbper'], # marker 颜色
s=100*data['peroid_similar'], # marker 大小
alpha=0.6,
cmap=plt.cm.RdYlBu,
)
# 由于越接近 0 的值,表示其与地球及太阳越类似,因此将x y 轴范围缩小至 [-0.1 ,1 ] 区间
plt.xlim(-0.1, 1)
plt.ylim(-0.1, 1)
plt.xlabel('Earth Similar')
plt.ylabel('Stellar Similar')
plt.title('Find the Terrestrial Planet')
# 在 [-0.1 ,1 ] 区间寻找公转周期相似度最大的恒星索引号
target = data.loc[
(data['Earth Similar'] > -0.1) &
(data['Earth Similar'] < 1) &
(data['Stellar Similar'] > -0.1) &
(data['Stellar Similar'] < 1)
]
target_index = target.sort_values(['peroid_similar'], ascending=False).index[0]
# 打印该行星的名称
print(df.iloc[target_index]['pl_name'])
# 将行星的名称以注释的形式添加至图中
plt.annotate(
s=df.iloc[target_index]['pl_name'],
xy=[0.6, 0.04], # 可以先运行并生成散点图后再添加注释代码,xy的位置为观察散点图所得
xytext=[0.7, 0.4],
fontsize=15,
color='white',
arrowprops=dict(
arrowstyle="->",
lw=3,
ls='-',
color='white',
)
)
从输出结果可以看出,在 0.6 附近的位置出现了一个最大的圆圈,那就是我们找到的类地行星 Kepler - 452 b ,让我们了解一下这颗行星:
df.iloc[target_index]
数据显示,Kepler - 452 b 行星公转周期为 384.84 天,半径为 1.63 地球半径,质量为 3.29 地球质量;它的恒星为 Kepler - 452 半径为太阳的 1.11 倍,质量为 1.04 倍,恒星方面数据与太阳相似度极高。
以下内容来自百度百科。 开普勒452b(Kepler 452b),是美国国家航空航天局(NASA)发现的外行星, 直径是地球的 1.6 倍,地球相似指数( ESI )为 0.83,距离地球1400光年,位于为天鹅座。
2015 年 7 月 24 日 0:00,美国国家航空航天局 NASA 举办媒体电话会议宣称,他们在天鹅座发现了一颗与地球相似指数达到 0.98 的类地行星开普勒 - 452 b。这个类地行星距离地球 1400 光年,绕着一颗与太阳非常相似的恒星运行。开普勒 452 b 到恒星的距离,跟地球到太阳的距离相同。NASA 称,由于缺乏关键数据,现在不能说 Kepler - 452 b 究竟是不是“另外一个地球”,只能说它是“迄今最接近另外一个地球”的系外行星。
地球与类地行星的相对位置标记
在银河系经纬度坐标下绘制所有行星,并标记地球和 Kepler - 452 b 行星的位置。
fig, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.set(fc='black')
# Kepler - 452 b 行星数据索引号
Kepler_index = 3115
# 绘制所有行星
plt.scatter(
x='glon',
y='glat',
s=df['pl_bmasse']*0.002,
c='glon',
cmap=plt.cm.RdYlBu_r,
data=df
)
# 标记 Kepler - 452 b 行星位置
plt.scatter(
x='glon',
y='glat',
s=200,
c='tab:red',
data=df.iloc[Kepler_index]
)
# 标记地球位置,以距离地球最近的比邻星 b 的坐标(距离地球 4.22 光年),作为地球坐标 313.9399,-1.927165
plt.scatter(
x=313.9399,
y=-1.927165,
s=200,
c='tab:blue',
)
# 标注地球
plt.annotate(
s='Earth',
xy=[313.9399, -1.927165],
xytext=[210, 75],
fontsize=15,
color='white',
arrowprops=dict(
arrowstyle="->",
lw=3,
ls='-',
color='white',
)
)
# Kepler-452 b
plt.annotate(
s='Kepler-452 b',
xy=[df.iloc[Kepler_index]['glon'], df.iloc[Kepler_index]['glat']],
xytext=[0, 75],
fontsize=15,
color='white',
arrowprops=dict(
arrowstyle="->",
lw=3,
ls='-',
color='white',
)
)
plt.xlabel('Galactic Longitude /deg')
plt.ylabel('Galactic Latitude /deg')
plt.title('Where Does the Kepler-452 b Locate ?')
类地行星,是人类寄希望移民的第二故乡,但即使最近的 Kepler-452 b ,也与地球相聚 1400 光年。
行星聚类
以下通过行星的公转周期和质量两个特征将所有行星聚为两类,即通过训练获得两个簇心。
定义函数-计算距离
聚类距离采用欧式距离:
def distance_L(fea_1, fea_2, center_x, center_y, data):
x_x = np.power(data[fea_1]-center_x, 2) # (x-cx)^2
y_y = np.power(data[fea_2]-center_y, 2) # (y-cy)^2
return np.array(np.sqrt(x_x+y_y))
定义函数-训练簇心
训练簇心的原理是:根据上一次的簇心计算所有点与所有簇心的距离,任一点的分类以其距离最近的簇心确定。依此原理计算出所有点的分类后,对每个分类计算新的簇心。
def train_centers(fea_1, fea_2, centers, data):
'''函数输入训练数据和上一次训练获得的簇心,返回下一次簇心'''
# 上一个迭代次获得的簇心
(center1_x, center1_y), (center2_x, center2_y) = centers
# 分别计算与簇心1,和簇心2的欧式距离
distance_1 = distance_L(fea_1, fea_2, center1_x, center1_y, data)
distance_2 = distance_L(fea_1, fea_2, center2_x, center2_y, data)
# 与 center1 欧氏距离较近的数据点,他们的均值为新簇心
center1_x = data.loc[distance_1 < distance_2, fea_1].mean()
center1_y = data.loc[distance_1 < distance_2, fea_2].mean()
# 与 center2 欧氏距离较近的数据点,他们的均值为新簇心
center2_x = data.loc[distance_1 > distance_2, fea_1].mean()
center2_y = data.loc[distance_1 > distance_2, fea_2].mean()
return (center1_x, center1_y), (center2_x, center2_y)
# 测试,以随机生成的簇心为上一次输入,测试函数训练结果
train_centers('pl_bmasse', 'pl_orbper', centers=((3, 3), (8, 8)), data=data)
# ((3, 3), (8, 8)) 经一次训练后变为
# ((2.36814897560919, 2.3440864259971432),
# (6.895476126190156, 6.774695415955819))
定义函数预测分类
根据训练得到的簇心,预测输入新的数据特征的分类。
def pred_centers(x, y, centers):
(center1_x, center1_y), (center2_x, center2_y) = centers
# 需要预测的x,y与簇心的距离
d1 = np.sqrt(np.power(x-center1_x, 2)+np.power(y-center1_y, 2))
d2 = np.sqrt(np.power(x-center2_x, 2)+np.power(y-center2_y, 2))
# 若距离center1簇心的距离小于center2,则为0,否则为1,此处将bool直接转为int
pred_classes = np.array(d1 < d2, dtype='int')
return pred_classes
# 测试,假定簇心为 [[3, 6], [9, 5]] 即(3,9)和(6,5),测试 x,y是否能以此簇心为基础进行数据分类
x = np.array([1, 1, 1, 5, 5, 5, 10, 10, 10])
y = np.array([1, 5, 10, 1, 5, 10, 1, 5, 10])
pred_centers(x, y, [[3, 6], [9, 5]])
开始训练
随机生成一个簇心,并训练 15 次。
np.random.seed(122)
fea_1, fea_2 = 'pl_bmasse', 'pl_orbper'
epochs = 15
centers = np.random.randint(low=3, high=10, size=(2, 2)) # 随机初始化簇心
centers_all = [centers] # 保存所有迭代次数的簇心
for i in range(epochs):
centers = train_centers(fea_1, fea_2, centers, data=data)
centers_all.append(centers)
# 查看随机生成的簇心及前两轮训练得到的簇心
centers_all[:2]
绘制聚类结果
以最后一次训练得到的簇心为基础,进行行星的分类,并以等高面的形式绘制各类的边界。
plt.rcParams['figure.figsize'] = (10, 6)
epoch = epochs-1
# 根据簇心预测行星所属分类
classes = pred_centers(
data['pl_orbper'], data['pl_bmasse'], centers_all[epoch])
# 绘制各类别的数据图
plt.scatter('pl_orbper', 'pl_bmasse',
data=data.loc[classes == 0], color='white', edgecolors='tab:red', label='class_0')
plt.scatter('pl_orbper', 'pl_bmasse',
data=data.loc[classes == 1], color='white', edgecolors='tab:blue', label='class_1')
# 绘制簇心
(center1_x, center1_y), (center2_x, center2_y) = centers_all[epoch]
plt.scatter(center1_x, center1_y, marker='^', s=300, c='r')
plt.scatter(center2_x, center2_y, marker='v', s=300, c='r')
# 绘制分类边界
ax = plt.gca()
# 获得子图的 x y 轴范围
xmin, xmax = ax.get_xlim()
ymin, ymax = ax.get_ylim()
# 在 x y 轴范围区间中生成等长分布的数据点
x = np.linspace(xmin, xmax, 1000) # shape 1000*1
y = np.linspace(ymin, ymax, 1000) # shape 1000*1
# 通过 np.meshgrid 生成数据网格
X, Y = np.meshgrid(x, y) # shape 1000*1000
# 对网格中的每个点进行分类结果的计算
Z = pred_centers(X, Y, centers_all[epoch]) # shape 1000*1000
# 等高面绘图接口
plt.contourf(X, Y, Z, cmap=plt.cm.RdBu, alpha=0.2)
# 进一步调节 x 轴范围
plt.xlim(-2.5, 12)
plt.legend()
plt.xlabel('log(Orbital Period [days])')
plt.ylabel('log(Planet Mass)')
plt.title('Planet Cluster by Mass and Orbital Period after Train Epoch %d' %
epoch, fontsize=16)
从运行结果可以看出,所有行星被分成了两类。并通过上三角和下三角标注了每个类别的簇心位置。
聚类前
以下输出了聚类前原始数据绘制的图像。
plt.rcParams['figure.figsize'] = (10, 6)
data = df[['pl_bmasse', 'pl_orbper']].copy()
data.dropna(inplace=True)
data = data.apply(lambda x: np.log(x))
plt.scatter('pl_orbper', 'pl_bmasse', data=data,
color='white', edgecolors='tab:red')
plt.xlim(-2.5, 12)
plt.xlabel('log(Orbital Period [days])')
plt.ylabel('log(Planet Mass)')
plt.title('Mass and Orbital Period of Planet Distribution')