import numpy as np
import math
import random
import matplotlib.pyplot as plt
def compute_all_choices(x, y, *, max_iterations=2000, inlier_dist_thresh=0.25,
                        confidence=0.99, min_inlier_ratio=0.4):
    """Fit a line ``y = k * x + b`` to 2-D points with a RANSAC-style search.

    Two distinct points are drawn at random, the line through them is
    scored by how many points lie within ``inlier_dist_thresh`` (vertical
    distance) of it, and the best-scoring line is kept.

    Args:
        x: 1-D sequence/array of x coordinates.
        y: 1-D sequence/array of y coordinates, same length as ``x``.
        max_iterations: Upper bound on the number of random samples drawn.
        inlier_dist_thresh: Maximum ``|k * x + b - y|`` for a point to be
            counted as an inlier.
        confidence: Desired probability (strictly between 0 and 1) of having
            drawn at least one all-inlier sample; used to shrink the
            iteration budget as better models are found.
        min_inlier_ratio: The search stops early once a model has more than
            ``int(len(x) * min_inlier_ratio)`` inliers.

    Returns:
        Tuple ``(k, b)`` of floats for the best line found. ``(0.0, 0.0)``
        is returned if no usable sample was drawn (e.g. all x are equal).

    Raises:
        ValueError: If fewer than two points are given, ``x`` and ``y``
            differ in shape, or ``confidence`` is not in the open
            interval (0, 1).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape")
    p_count = x.shape[0]
    if p_count < 2:
        raise ValueError("at least two points are required to fit a line")
    if not 0 < confidence < 1:
        raise ValueError("confidence must be strictly between 0 and 1")

    inlier_num_thresh = int(p_count * min_inlier_ratio)
    log_failure = math.log(1 - confidence)

    iterations = max_iterations
    sample_count = 0
    best_inliers = 0
    best_k = 0.0
    best_b = 0.0

    while sample_count < iterations:
        # Count every draw so the loop is guaranteed to terminate.
        sample_count += 1

        # random.sample never returns the same index twice.
        first, second = random.sample(range(p_count), 2)
        dx = x[first] - x[second]
        if dx == 0:
            # Vertical pair: the slope is undefined, try another sample.
            continue

        k = (y[first] - y[second]) / dx
        b = y[first] - k * x[first]

        residuals = np.abs(k * x + b - y)
        total_inlier = int(np.count_nonzero(residuals < inlier_dist_thresh))
        if total_inlier <= best_inliers:
            continue

        best_inliers = total_inlier
        best_k = k
        best_b = b

        if total_inlier > inlier_num_thresh:
            break

        inlier_ratio_sq = (total_inlier / p_count) ** 2
        if inlier_ratio_sq >= 1:
            # Every point is an inlier; log(1 - 1) is undefined and no
            # better model can exist.
            break

        # Samples needed to hit an all-inlier pair with the requested
        # confidence, never exceeding the caller's budget. log1p keeps
        # precision when the inlier ratio is tiny.
        needed = log_failure / math.log1p(-inlier_ratio_sq)
        iterations = min(max_iterations, needed)

    return float(best_k), float(best_b)
def main():
    """Generate a noisy line with outliers, fit it, and plot the result.

    The synthetic data follows y = 2x + 5 with small uniform noise; 18 of
    the samples (even indices 2..36) are additionally shifted by a large
    uniform offset to act as outliers.
    """
    n_points = 60

    # Ground truth: y = 2x + 5
    X = np.array([random.uniform(1, 10) for _ in range(n_points)])
    y = 2 * X + 5

    noise = np.array([random.uniform(-0.3, 0.3) for _ in range(n_points)])
    y += noise

    # Turn every second sample from index 2 up to 36 into an outlier.
    for idx in range(2, 38, 2):
        y[idx] += random.uniform(-15, 15)

    slope, intercept = compute_all_choices(X, y)
    fitted = X * slope + intercept

    plt.title("demo")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.scatter(X, y)
    plt.plot(X, fitted)
    plt.show()


if __name__ == "__main__":
    main()
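# Principle
# Reference: https://blog.csdn.net/zhoucoolqi/article/details/105497572
# n - minimum number of data points required to fit the model.
# k - maximum number of iterations allowed by the algorithm.
# t - threshold for how well a data point matches the model: points within t are inliers, points outside t are outliers.
# d - minimum number of data points needed for the model to be considered a good fit.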