机器学习笔记-基于梯度下降的曲线拟合

背景

7月份的时候导师布置了个作业,他给了一条用程序生成的曲线,然后让我们用代码实现一个梯度下降算法来拟合曲线。具体要求:

data.csv文件中包含两列用逗号分隔的数据。第一列是x,第二列是y。完成如下工作:
(1)在data.csv中随机选择80%的数据作为训练集,剩余20%作为测试集。
(2)构造模型,采用梯度下降算法训练模型。
(3)用测试集对训练的模型进行评估,将测试集中的x作为输入,用模型计算y,计算预测值与实际值的RMSE。
(4)绘制data.csv中的点,绘制x ∈ [0,1] 之间模型的对应曲线。

数据格式如下:

0.000000000000000000,0.000045401991009684
0.010010010010010010,0.000067487908347918
0.020020020020020020,0.000099516665248245
0.030030030030030030,0.000145574221405758
0.040040040040040040,0.000211247752152538
0.050050050050050046,0.000304101936049645
0.060060060060060060,0.000434277611628926
0.070070070070070073,0.000615236631426893
0.080080080080080079,0.000864687227990188
0.090090090090090086,0.001205760122738213
0.100100100100100092,0.001668621265042236

上面的csv文件一共有1000行数据,在xy平面上绘制出来的曲线如下:


思路

老师的意思是先猜这条曲线是什么函数的曲线(先确定函数的基本形式),一开始函数的具体参数是不知道的,需要猜几个初始值,那么猜出来的曲线一定和实际曲线有较大差异,再用最优化的方法找到使差异最小化的函数参数,从而实现曲线的拟合。这里要求实现梯度下降算法来求解最小值。

从曲线的图像来看原始数据应该是几个均值方差不同的高斯函数叠加而成的,图中有4个峰,因此可以假设曲线的模型为:f(x)=\alpha_1e^{-\frac{(x-\mu_1)^2}{2\sigma^2_1}}+\alpha_2e^{-\frac{(x-\mu_2)^2}{2\sigma^2_2}}+\alpha_3e^{-\frac{(x-\mu_3)^2}{2\sigma^2_3}}+\alpha_4e^{-\frac{(x-\mu_4)^2}{2\sigma^2_4}}
令误差函数为E=\sum\limits_{i=1}^{n} (f(x_i) - y_i)^2。则理想的模型参数:
(\alpha_1,\mu_1,\sigma_1,\alpha_2,...,\sigma_4)=\min\limits_{\alpha_1,...,\sigma_4}E

梯度下降算法每次求出函数(E)在某个点(当前参数)的梯度,因为梯度就是函数值增长最快的那个方向,所以让参数沿着梯度的负方向乘以一定的步长进行更新,就一定能抵达一个局部极小点。所以只要给定了这里的误差函数E(\alpha_1,\mu_1,\sigma_1,\alpha_2,\mu_2,\sigma_2,\alpha_3,\mu_3,\sigma_3,\alpha_4,\mu_4,\sigma_4),就可以通过梯度下降算法来找到使误差函数达到局部极小的12个参数。

为了便于计算,可以把\sigma^2当成一个整体,此时需要求出E在某个点的梯度的一般表示:(\frac{\partial E}{\partial \alpha_1},\frac{\partial E}{\partial \mu_1},\frac{\partial E}{\partial \sigma_1^2},\frac{\partial E}{\partial \alpha_2},\frac{\partial E}{\partial \mu_2},\frac{\partial E}{\partial \sigma_2^2},\frac{\partial E}{\partial \alpha_3},\frac{\partial E}{\partial \mu_3},\frac{\partial E}{\partial \sigma_3^2},\frac{\partial E}{\partial \alpha_4},\frac{\partial E}{\partial \mu_4},\frac{\partial E}{\partial \sigma_4^2},)。其中\frac{\partial E}{\partial \alpha_1}=2\sum\limits_{i=1}^{n}((f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \mu_1}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)}{\sigma_1^2}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}) \frac{\partial E}{\partial \sigma_1^2}=2\sum\limits_{i=1}^{n}(\frac{\alpha_1(x_i-\mu_1)^2}{2\sigma_1^4}(f(x_i)-y_i)e^{-\frac{(x_i-\mu_1)^2}{2\sigma_1^2}}),其余参数的偏导数以此类推。

设定一个迭代次数,每次求出误差函数的梯度后,设定步长\eta,让参数沿梯度的负方向更新,如:\alpha_1=\alpha_1-\eta\frac{\partial E}{\partial \alpha_1}\mu_1=\mu_1-\eta\frac{\partial E}{\partial \mu_1},然后重复这个步骤,直到达到一定迭代次数或者总误差小于一定阈值停止迭代。

程序

程序使用Java实现。(C++写起来麻烦而且没有合适的图表显示库,Python太慢,Java写起来最顺手)

一开始我面临的问题就是选择一个图表显示库,简单地调研了一下选了XChart,但是去了该项目的Github主页发现居然没有打包好的 jar 包,于是需要 clone 下来然后用 mvn package 命令把 jar 包打出来。

然后我定义了一个模型类 Model,这个模型类的成员变量是 double数组,用来放待调的参数,比如上文中的f(x)对应的参数数组长度就为12。Model类有一些待实现的方法如函数的求值(val)、梯度的求值(grad)等,其派生类GaussianModel就是上文中的模型。另外,因为梯度下降会抵达最近的极小点而不是全局最小点,最终的收敛点极大依赖于参数的初始值,我每次随机选取了一部分数据点来求梯度以跳出局部极小。

Java代码如下:

package com.company;

import org.knowm.xchart.QuickChart;
import org.knowm.xchart.SwingWrapper;
import org.knowm.xchart.XYChart;

import java.io.File;
import java.io.FileNotFoundException;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.lang.Math.E;
import static java.lang.Math.pow;
import static java.lang.Math.sqrt;
import static java.lang.System.exit;


public class Solver {

    private List<Point> rawData = new ArrayList<>();
    private List<Point> trainData = new ArrayList<>();
    private List<Point> testData = new ArrayList<>();
    private Model model = null;
    private Function<Model, Double> loss = null;

    public Solver(String csvPath) throws FileNotFoundException {
        Scanner scanner = new Scanner(new File(csvPath));
        while (scanner.hasNextLine()) {
            String[] xy = scanner.nextLine().split(",");
            rawData.add(new Point(Double.valueOf(xy[0]), Double.valueOf(xy[1])));
        }
    }

    private Function<Model, Double> mse = (m) -> {
        double lossSum = 0.0;
        for (Point p : trainData) {
            double diff = m.val(p.x) - p.y;
            lossSum += (diff * diff);
        }
        return lossSum / 2.0;
    };

    private void divide(float ratio4Train) {
        trainData.clear();
        testData.clear();
        if (ratio4Train <= 0) throw new IllegalArgumentException("Ratio <= 0");
        int testCount = (int) (rawData.size() * (1 - ratio4Train));
        Random rand = new Random(System.currentTimeMillis());
        Set<Integer> exclusiveIndices4Test = new HashSet<>();
        while (exclusiveIndices4Test.size() < testCount) {
            int index = rand.nextInt(rawData.size());
            if (! exclusiveIndices4Test.contains(index)) {
                testData.add(rawData.get(index));
                exclusiveIndices4Test.add(index);
            }
        }
        for (int i = 0; i < rawData.size(); i ++) {
            if (! exclusiveIndices4Test.contains(i)) {
                trainData.add(rawData.get(i));
            }
        }
    }

    private void train() {
        System.out.println("Train data size: " + trainData.size());
        System.out.println("Test data size: " + testData.size());
//        model = new PolyModel(4);
        model = new GaussianModel(5);
        loss = mse;
        // ==========================================================
        for (int i = 0; i < 10000; i ++) {
            double lossVal = loss.apply(model);
            double[] gradVal = model.grad(trainData);
            System.out.println(String.format("Iter: %d, loss: %f ", i, lossVal));
            System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
            System.out.println(String.format("Grad: %f, %f, %f\n", gradVal[0], gradVal[1], gradVal[2]));
            if (Double.isNaN(lossVal)) {
                model.randomize(); i = 0;
                continue;
            }
            for (int j = 0; j < gradVal.length; j ++) {
                double delta = model.rate(j) * gradVal[j];
                model.theta[j] -= delta;
            }
//            if (lossVal < 1.06) break;
        }
        System.out.println(String.format("Theta: %f, %f, %f", model.theta[0], model.theta[1], model.theta[2]));
    }

    private void validate() {
        double RMSE = 0.0;
        for (Point p : testData) {
            double diff = model.val(p.x) - p.y;
            RMSE += (diff * diff);
        }
        RMSE /= testData.size();
        RMSE = sqrt(RMSE);
        System.out.println("RMSE: " + RMSE);
    }

    private void plot() {
        XYChart chart = QuickChart.getChart(
                "Result", "X", "Y", "y(x)",
                trainData.stream().map(point -> point.x).collect(Collectors.toList()),
                trainData.stream().map(point -> point.y).collect(Collectors.toList()));

        double[] xPoints = new double[150];
        double[] yPoints = new double[150];
        for (int i = 0; i < 150; i ++) {
            xPoints[i] = i * 10.0 / 150;
            yPoints[i] = model.val(xPoints[i]);
        }
        chart.addSeries("model", xPoints, yPoints);

        new SwingWrapper<XYChart>(chart).displayChart();
    }

    public void solve() {
        divide(0.8f);
        train();
        validate();
        plot();
    }

    public static void main(String[] args) throws FileNotFoundException {
    // write your code here
        if (args.length < 1) {
            System.out.println("Usage: java -jar GradientDesent.jar data.csv");
            exit(0);
        }
        new Solver(args[0]).solve();
    }

    private static class Point {
        double x;
        double y;
        public Point(double x, double y) {this.x = x; this.y = y;}

    }

    private static abstract class Model {
        double theta[] = null;
        abstract double val(double x);
        abstract double[] grad(List<Point> trainData);
        abstract void randomize();
        abstract double rate(int i);
    }

    private static class PolyModel extends Model{

        public PolyModel(int n) {
            if (n < 2) throw new IllegalArgumentException("n MUST be larger than 2.");
            theta = new double[n];
            randomize();
        }

        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length; i ++) {
                result += theta[i] * pow(x, i);
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double []gradVec = new double[theta.length];
            for (int i = 0; i < gradVec.length; i ++) {
                gradVec[i] = 0.0;
                Random r = new Random();
                List<Point> data = new ArrayList<>();
                for (int k = 0; k < 50; k ++)
                    data.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : data) {
                    double diff = val(p.x) - p.y;
                    gradVec[i] += (diff * pow(p.x, i));
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length; i ++) {
                theta[i] = rand.nextDouble() ;
            }
        }

        @Override
        double rate(int i) {
            return 0.00000002;
        }
    }

    private static class GaussianModel extends Model{

        /**
         * f(x) = a * e ^ (- (x - μ)^2 / σ^2)
         * (a, μ, σ2) <<----
         * @param n number of gaussian function
         */
        public GaussianModel(int n) {
            if (n < 1) throw new IllegalArgumentException("n MUST be larger than 1.");
            theta = new double[n * 3];
            randomize();
        }

        @Override
        double val(double x) {
            double result = 0.0;
            for (int i = 0; i < theta.length / 3; i ++) {
                double alpha = theta[i * 3 + 0];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                result += (alpha * pow(E, - pow((x - miu), 2) / sigma2 / 2));
            }
            return result;
        }

        @Override
        double[] grad(List<Point> trainData) {
            double[] gradVec = new double[theta.length];
            for (int i = 0; i < theta.length / 3; i ++) {
                gradVec[i * 3 + 0] = 0;
                gradVec[i * 3 + 1] = 0;
                gradVec[i * 3 + 2] = 0;
                double alpha = theta[i * 3 + 0];
                double miu = theta[i * 3 + 1];
                double sigma2 = theta[i * 3 + 2];
                Random r = new Random();
                List<Point> stochasticData = new ArrayList<>();
                for (int k = 0; k < 30; k ++)
                    stochasticData.add(trainData.get(r.nextInt(trainData.size())));
                for (Point p : stochasticData) {
                    double val = val(p.x);
                    gradVec[i * 3 + 0] += 2
                            * (val - p.y)
                            * (pow(E, - pow((p.x - miu), 2) / sigma2 / 2));
                    gradVec[i * 3 + 1] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * ((p.x - miu) / sigma2));
                    gradVec[i * 3 + 2] += (2
                            * alpha
                            * (val - p.y)
                            * pow(E, - pow((p.x - miu), 2) / sigma2 / 2)
                            * (pow((p.x - miu), 2) / pow(sigma2, 2) / 2)); //把sigma平方当成了一个整体
                }
            }
            return gradVec;
        }

        @Override
        void randomize() {
            Random rand = new Random(System.currentTimeMillis());
            for (int i = 0; i < theta.length / 3; i ++) {
                theta[i * 3 + 0] = rand.nextDouble();
                theta[i * 3 + 1] = rand.nextDouble() * 5;
                theta[i * 3 + 2] = rand.nextDouble();
            }
        }

        @Override
        double rate(int i) {
            if (i % 3 == 0) {
                return 0.0005;
            } else if (i % 3 == 1) { // miu
                return 0.0005;
            } else {
                return 0.00005;
            }
        }

        public String toString() {
            StringBuilder builder = new StringBuilder("Theta: ");
            for (double t : theta) {
                builder.append(t);
                builder.append(", ");
            }
            builder.append("\nGrad: ");
            return builder.toString();
        }
    }
}

最后的结果还是比较看人品的,并不是每次都能拟合地比较好,贴一个结果的图:


结果

数据和代码我放到了我的Github:https://github.com/Jimmie00x0000/gradient_desent_demo

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,313评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,369评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 159,916评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,333评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,425评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,481评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,491评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,268评论 0 269
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,719评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,004评论 2 328
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,179评论 1 342
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,832评论 4 337
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,510评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,153评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,402评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,045评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,071评论 2 352

推荐阅读更多精彩内容