问题描述：
用遗传算法求使 F(x) 取得最大值的 x（问题来源：莫烦 Python 教程之遗传算法）。
最终效果:
import numpy as np
import matplotlib.pyplot as plt
# --- Genetic-algorithm hyper-parameters ---
DNA_SIZE, POP_SIZE = 10, 100           # genes (bits) per individual; individuals per generation
CROSS_RATE, MUTATE_RATE = 0.8, 0.003   # chance a parent undergoes crossover; chance each gene flips
N_GENERATIONS = 400                    # number of evolution steps to run
X_BOUND = [0, 5]                       # x interval that is searched and plotted
def F(x):
    """Objective to maximise: F(x) = x*sin(10x) + 2*cos(x).

    Built from NumPy ufuncs, so it accepts a scalar or an array of x values.
    """
    wiggle = x * np.sin(10 * x)
    return wiggle + 2 * np.cos(x)
def translateDNA(pop, x_bound=None):
    """Decode binary DNA into real-valued x inside the search interval.

    Each row of ``pop`` is read as an unsigned binary number (first gene is
    the most significant bit), normalised to [0, 1], then mapped linearly
    onto ``[x_bound[0], x_bound[1]]``.

    Args:
        pop: Array of 0/1 genes, one individual per row. The number of bits
            is taken from the last axis, so it need not equal DNA_SIZE.
        x_bound: ``(lower, upper)`` interval; defaults to the module-level
            X_BOUND.

    Returns:
        Array of decoded x values, one per individual.
    """
    if x_bound is None:
        x_bound = X_BOUND
    lower, upper = x_bound
    n_bits = np.shape(pop)[-1]
    # Bit weights 2^(n-1), ..., 2^0: the first gene is the most significant bit.
    weights = 2 ** np.arange(n_bits)[::-1]
    normalised = np.dot(pop, weights) / float(2 ** n_bits - 1)
    # Shift as well as scale. Scaling by the upper bound alone ignored a
    # non-zero lower bound; for lower == 0 the result is unchanged.
    return lower + normalised * (upper - lower)
def mutate(child):
    """Randomly flip genes of ``child`` in place and return it.

    One uniform random draw is made per gene position (DNA_SIZE draws, in
    order). When a draw falls below MUTATE_RATE, that gene is toggled: a 1
    becomes 0 and any other value becomes 1.
    """
    for locus in range(DNA_SIZE):
        flip = np.random.rand() < MUTATE_RATE
        if not flip:
            continue
        # Toggle: 1 -> 0, anything else -> 1.
        child[locus] = int(child[locus] != 1)
    return child
def crossover(parent, pop, cross_rate=None):
    """Uniform crossover of ``parent`` with a randomly chosen member of ``pop``.

    With probability ``cross_rate``, a mate is drawn uniformly from ``pop``
    and a random boolean mask decides which genes of ``parent`` are
    overwritten with the mate's genes. ``parent`` is modified in place.

    Args:
        parent: 1-D gene array to modify.
        pop: 2-D population array (one individual per row) to draw the mate from.
        cross_rate: Crossover probability; defaults to the module-level CROSS_RATE.

    Returns:
        ``parent`` itself (possibly modified).
    """
    if cross_rate is None:
        cross_rate = CROSS_RATE
    if np.random.rand() < cross_rate:
        # Scalar row index for the mate. The old ``size=1`` produced a
        # length-1 array that only worked through broadcasting.
        mate = np.random.randint(0, len(pop))
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        cross_points = np.random.randint(0, 2, size=len(parent)).astype(bool)
        parent[cross_points] = pop[mate, cross_points]
    return parent
def select(pop, fitness):
    """Roulette-wheel selection: resample the population by fitness.

    Draws ``len(pop)`` individuals with replacement; each row is picked with
    probability proportional to its fitness.

    Args:
        pop: 2-D population array, one individual per row.
        fitness: 1-D non-negative scores, one per row of ``pop``, with a
            positive sum (getFitness in this file always returns values
            >= 1e-3).

    Returns:
        A new array containing the selected rows (same shape as ``pop``).
    """
    n = len(pop)
    # Size the draw from the population itself rather than the POP_SIZE
    # global, so any population size works. choice(n) samples exactly as
    # choice(np.arange(n)) would.
    idx = np.random.choice(n, size=n, replace=True, p=fitness / fitness.sum())
    return pop[idx]
def getFitness(pred):
    """Turn raw objective values into strictly positive selection weights.

    Every value is shifted so that the minimum lands at 1e-3. Higher F(x)
    still means higher fitness, and no weight is zero, so the weights can be
    normalised into selection probabilities.
    """
    floor = np.min(pred)
    shifted = pred - floor
    return shifted + 1e-3
# Initial population: POP_SIZE independent random bit-strings. The original
# repeated ONE random individual POP_SIZE times. Every parent was then
# identical, crossover could not change anything, and diversity came only
# from the tiny per-gene MUTATE_RATE.
pop = np.random.randint(0, 2, size=(POP_SIZE, DNA_SIZE))

plt.ion()  # interactive mode so the scatter can be redrawn each generation
x = np.linspace(*X_BOUND, 200)
plt.plot(x, F(x))  # the objective curve being maximised

sca = None  # scatter artist of the current generation; None before the first draw
for _ in range(N_GENERATIONS):
    xs = translateDNA(pop)
    F_values = F(xs)

    # Replace the previous generation's points with the current ones.
    if sca is not None:
        sca.remove()
    sca = plt.scatter(xs, F_values, s=200, lw=0, c='red', alpha=0.5)
    plt.pause(0.1)

    fitness = getFitness(F_values)
    print("Most fitted DNA: ", pop[np.argmax(fitness), :])
    print("current pop size is : ", len(pop))

    pop = select(pop, fitness)
    # Draw mates from a frozen snapshot of this generation. The original made
    # this copy but passed the live `pop`, so later parents could mate with
    # rows already overwritten earlier in the same generation.
    pop_copy = pop.copy()
    for parent in pop:
        child = crossover(parent, pop_copy)
        child = mutate(child)
        # crossover/mutate edit `parent` in place and return it; this
        # explicit write-back into the row of `pop` keeps the intent visible.
        parent[:] = child

plt.ioff()
plt.show()