python实现梯度下降法

梯度下降法

梯度定义

梯度的本意是一个向量(矢量),表示某一函数在该点处的方向导数沿着该方向取得最大值,即函数在该点处沿着该方向(此梯度的方向)变化最快,变化率最大(为该梯度的模)。
<p align="right">--------百度百科</p>

对于f(x)=2x来说,其梯度\nabla f(x)为:
\nabla f(x)=\dfrac{df(x)}{dx}=2
对于f(x,y)=x^2+2y来说,其梯度\nabla f(x,y)为:
\nabla f(x,y)=\left(\dfrac{\partial f}{\partial x},\dfrac{\partial f}{\partial y}\right )=(2x,2)

梯度下降法思路

因为梯度是函数上升最快的方向,所以如果我们要寻找函数的最小值,只需沿着梯度的反方向寻找即可。这里以f(x)=2x为例,简述梯度下降法实现的大体步骤:

  1. 确定变量的初始点x_0,从初始点开始一步步向函数最小值逼近。
  2. 求函数梯度,然后求梯度的反向,将变量的初始点代入,确定变量变化的方向:-\nabla f(x_0);用求得的梯度向量(变量变化的方向)乘以学习率\alpha (变量变化的步长)得到一个新的向量;变量的初始点加上求得的新向量,到达下一个点。
    x = x_0 - \alpha \nabla f(x_0)
  3. 判断此时函数值的变化量是否满足精度要求。定义一个我们认为满足要求的精度p_0;用上一个点的函数值减去当前点的函数值,得到此时函数值变化量的精度值p(可以近似认为p为损失函数);判断p<p_0是否成立。不成立则反复执行步骤2、3。 \begin{cases}p = f(x)-f(x_0)\\p < p_0\end{cases}

但是梯度下降法对初始点的选取要求比较高,选取不当容易陷入极小值(局部最优解)。

梯度下降法的简单应用

梯度下降法求二维曲线的最小值

下图为梯度下降法求曲线y=x^2+2x+5最小值的结果图,左图红色的点为求解过程中的过程点,右图为求解过程中精度的变化(损失函数值的变化),代码见附录。

梯度下降法求二维曲线的最小值

梯度下降法求三维曲面的最小值

下图为梯度下降法求曲面z=\sqrt{x^2+y^2}最小值的结果图,图中红色的点为求解过程中的过程点,代码见附录。

梯度下降法求三维曲面的最小值

代码附录

# -*- encoding=utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as aplt
from mpl_toolkits.mplot3d.axes3d import Axes3D
import sympy 

class gradientDescent(object):
    def init2D(self,vector:float,precision:float,startPoint:float):
    """
    vector:学习率
    precision:精度
    startPoint:起始点
    """
        self.vector = vector
        self.precision = precision
        self.startPoint = startPoint
        self.startPrecision = precision + 1

    def init3D(self,vector:float,precision:float,startVar1Point:float,startVar2Point:float):
    """
    vector:学习率
    precision:精度
    startVar1Point:变量1的起始位置
    startVar2Point:变量2的起始位置
    """
        self.vector = vector
        self.precision = precision
        self.startVar1Point = startVar1Point
        self.startVar2Point = startVar2Point
        self.startPrecision = precision + 1

    def singleVar2D(self, func:str, var:str):
        grad = sympy.diff(func, var)
        grad = str(grad)
        xpoint = []
        ypoint = []
        errors = []
        x = self.startPoint
        while self.startPrecision > self.precision:
            y = eval(func)
            xpoint.append(x)
            ypoint.append(y)
            x1 = x - self.vector*eval(grad)
            x = x1
            y1 = eval(func)
            self.startPrecision = y - y1
            errors.append(self.startPrecision)
        xpoint.append(x)
        ypoint.append(y)
        xlen = len(xpoint)
        return [xpoint,ypoint,errors,xlen]
        
    def doubleVar3D(self, func:str, var1:str, var2:str):
        var1Grad = sympy.diff(func, var1)
        var1Grad = str(var1Grad)
        var1Grad = var1Grad.replace("sqrt","np.sqrt")
        var2Grad = sympy.diff(func, var2)
        var2Grad = str(var2Grad)
        var2Grad = var2Grad.replace("sqrt","np.sqrt")
        func = func.replace("sqrt","np.sqrt")
        xpoint = []
        ypoint = []
        zpoint = []
        errors = []
        x = self.startVar1Point
        y = self.startVar2Point
        while self.startPrecision > self.precision:
            z = eval(func)
            xpoint.append(x)
            ypoint.append(y)
            zpoint.append(z)
            x1 = x - self.vector*eval(var1Grad)
            y1 = y - self.vector*eval(var2Grad)
            x = x1
            y = y1
            z1 = eval(func)
            self.startPrecision = z - z1
            errors.append(self.startPrecision)
        xpoint.append(x)
        ypoint.append(y)
        zpoint.append(z)
        xlen = len(xpoint)
        return [xpoint,ypoint,zpoint,errors,xlen]


if __name__ == '__main__':
            xData = np.arange(-100,100,0.1)
            yData = xData**2 + 2*xData + 5
            vector=0.2
            precision=10e-6
            startPoint=-100
            x = sympy.symbols("x")
            func = "x**2+2*x+5"
            gradient_descent = gradientDescent()
            gradient_descent.init2D(vector,precision,startPoint)
            [xpoint,ypoint,errors,xlen] = gradient_descent.singleVar2D(func,x)        
            fig,ax = plt.subplots(figsize=(12,8),ncols=2,nrows=1)
            for i in range(xlen):
                ax[0].cla()
                ax[0].plot(xData,yData,color="green",label="$y=x^2+2x+5$")
                ax[0].scatter(xpoint[i],ypoint[i],color="red",label="process point")
                plt.pause(0.1)
            ax[0].legend(loc = "best")
            ax[1].plot(errors,label="Loss curve")
            ax[1].legend(loc = "best")
            plt.pause(0.1)
            plt.show()
            # =======================================================================
            xData = np.arange(-100,100,0.1)
            yData = np.arange(-100,100,0.1)
            X,Y = np.meshgrid(xData,yData)
            # z = sqrt(x^2+y^2)
            Z = np.sqrt(X**2+Y**2)
            x = sympy.symbols("x")
            y = sympy.symbols("y")
            func = "sqrt(x**2+y**2)"
            vector=0.2
            precision=10e-6
            startVar1Point=100
            startVar2Point=-100           
            gradient_descent = gradientDescent()
            gradient_descent.init3D(vector, precision, startVar1Point, startVar2Point)
            [xpoint,ypoint,zpoint,errors,xlen] = gradient_descent.doubleVar3D(func,x,y)        
            fig = plt.figure()
            ax = Axes3D(fig)
            surf = ax.plot_surface(X,Y,Z,label="$z=\sqrt{x^2+y^2}$")
            ax.scatter(xpoint,ypoint,zpoint,color="red",label="process point")
            # 解决标签报错,不显示问题
            surf._facecolors2d=surf._facecolors3d
            surf._edgecolors2d=surf._edgecolors3d
            ax.legend()
            plt.show()
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,222评论 6 493
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,455评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 157,720评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,568评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,696评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 49,879评论 1 290
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,028评论 3 409
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,773评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,220评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,550评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,697评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,360评论 4 332
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,002评论 3 315
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,782评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,010评论 1 266
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,433评论 2 360
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,587评论 2 350

推荐阅读更多精彩内容