遗传算法helloworld级别的python实现(结果可视化)

问题描述:

用遗传算法求使得F(X)最大的X,问题来源:莫烦的python教程之遗传算法

最终效果:

population进化的过程.gif
import numpy as np
import matplotlib.pyplot as plt


DNA_SIZE = 10
POP_SIZE = 100
CROSS_RATE = 0.8
N_GENERATIONS = 400
X_BOUND = [0,5]
MUTATE_RATE = 0.003

def F(x) : return np.sin(10*x)*x + 2*np.cos(x)

def translateDNA(pop):
    return pop.dot(2  ** np.arange(DNA_SIZE)[::-1]) /  float(2**DNA_SIZE-1) * X_BOUND[1]


def mutate(child):
    for point in range(DNA_SIZE):
        if np.random.rand() < MUTATE_RATE :
            #注意三目运算符的写法
        
            child[point] = 0 if child[point] == 1 else 1
    return child

def crossover(parent,pop):
    if np.random.rand()<CROSS_RATE:
        #选一个随机下标
        index =np.random.randint(0,POP_SIZE,size=1)
        cross_points = np.random.randint(0,2,size=DNA_SIZE).astype(np.bool)
        parent[cross_points] = pop[index,cross_points]
    return parent
    
def select(pop,fitness):
    idx = np.random.choice(np.arange(POP_SIZE),size = POP_SIZE,replace = True,p=fitness / fitness.sum())
    return pop[idx]    

def getFitness(pred):
    return pred - np.min(pred) + 1e-3

pop = np.random.randint(0,2,(1,DNA_SIZE)).repeat(POP_SIZE,axis=0)

plt.ion()       # something about plotting
x = np.linspace(*X_BOUND, 200)
plt.plot(x, F(x))

for _ in range(N_GENERATIONS):
    F_values = F(translateDNA(pop))
    
    
    if 'sca' in globals(): sca.remove()
    sca = plt.scatter(translateDNA(pop), F_values, s=200, lw=0, c='red', alpha=0.5); plt.pause(0.1)
    
    fitness = getFitness(F_values)
    print("Most fitted DNA: ", pop[np.argmax(fitness), :])
    print("current pop size is : ",len(pop))
    
    
    pop = select(pop,fitness)
    pop_copy = pop.copy()
    for parent in pop:
        child = crossover(parent,pop)
        child = mutate(child)
        #表示parent从头到尾用child赋值
        parent[:] = child
plt.ioff(); plt.show()
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容