优化算法应用(五)优化支持向量机(SVM)

一. 目标描述

支持向量机(Support Vector Machine,SVM)是一种有监督的机器学习算法。其原理简单,但是通用的公式推到及求解过程异常的复杂,这里将直接用优化算法来对支持向量机进行求解。
(该篇没有对偶问题转化推导,没有核函数,如需要了解自行搜索。)

二. 支持向量机简介

支持向量机是二分类算法,它的主要目标是找到两个数据集的支持向量及分割平面,使分割平面间的距离最大化。


  如上图,红蓝点分别代表两个数据集,其中红色线和蓝色线表示由支持向量决定的两个平行的线(超平面),超平面之间的距离为2d。我们的目标就是找到最佳的斜率和支持向量,让d的值最大。

1. 硬间隔

硬间隔表示所有的数据点都需要在超平面的一侧。


  如上图,确定L的斜率后,蓝色数据集可以确定4个超平面L1,L2,L3,L4。由于硬间隔的要求,蓝色数据集只能在超平面的一侧,故只有L1能被选作超平面。
  选定超平面后就可以计算两个超平面之间的距离。
  设蓝色数据集为(x1,y1), (x2,y2), (x3,y3), (x4,y4),红色数据集(m1,n1), (m2,n2), (m3,n3), (m4,n4)。
  假设(x1,y1)和(m1,n1)为两个支持向量,则超平面由下式确定


  点(a,b)到直线Ax+By+C=0的距离计算公式如下:


  则可计算两个超平面间间隔2d如下


2. 软间隔

当数据集中有一定异常数据,数据无法被两个超平面分为两类时,将引入松弛参数,并重新计算数据集之间的最大距离。


  如图,若蓝色有两个分类错误的数据点时,分别计算其到对应超平面的距离,整个数据集间的距离如下:


  其中k为松弛参数,k为非负数,当k取值为0时,则表示可以允许所有数据点分类错误,当k取值为正无穷时,表示不允许任何数据点分类错误,此时,等同于硬间隔。

三. 支持向量机适应度函数

1. 适应度函数设计

从支持向量机的定义知,这是一个最大化问题。其中的变量为超平面,而超平面由支持向量决定。所以适应度函数的输入应该是超平面的各个维度的斜率,输出则是两个数据集之间的距离。其流程如下:



  以上面的例子来说计算步骤如下
  设蓝色数据集为(x1,y1), (x2,y2), (x3,y3), (x4,y4),红色数据集(m1,n1), (m2,n2), (m3,n3), (m4,n4)。
  确定一组斜率(A,B),可以计算出8个超平面,分别由C1,C2,C3,C4,C5,C6,C7,C8。

数据点 截距
(x1,y1) C1
(x2,y2) C2
(x3,y3) C3
(x4,y4) C4
(m1,n1) C5
(m2,n2) C6
(m3,n3) C7
(m4,n4) C8

然后计算各个超平面组合之间的距离,即在下表中选出最大的D作为适应度值返回。

C5 C6 C7 C8
C1 D15 D16 D17 D18
C2 D25 D26 D27 D28
C3 D35 D36 D37 D38
C4 D45 D46 D47 D48

2. 适应度函数优化

从适应度函数可以看出,如果数据集1有i个数据,数据集2有j个数据,那么为了选出最佳的支持向量,我们需要计算i*j次数据集之间的距离,当数据集中的数据量很大时,这将是会耗费非常多的时间。
  下面,将适应度函数优化一下,每次只计算单次数据集间距离。
  可以看到,原适应度函数中,若数据集中数据的维度为2时,输入的维度也为2,输入中的每一维都可以认为是该维度斜率(如A,B相除得到斜率),然后会根据斜率和数据点计算出该超平面的截距(如C1,C2),然后选出最佳的截距组合来确定超平面之间的距离。
  这里优化一下,将适应度函数的输入确定为两个超平面,如此每次只需计算这两个超平面之间的距离即可。
  具体如下:
  若数据集中数据的维度为2,那么适应度函数输入的维度为2+2(如(A,B,C1,C2)), 若数据集中数据的维度为N时,适应度函数输入的维度为N+2,其中最后两维代表这两个超平面的截距。
这样一来,适应度函数输入的就是两个确定的超平面,而不是超平面的截距,但是此时,该超平面不一定会穿过数据集中的数据点,或者可以说该超平面不存在支持向量。我愿称之为广义支持向量机。


  如图,优化后的适应度函数计算出的超平面不一定经过数据点。

四. 代码实现

1 .原SVM模型

文件路径:\optimization algorithm\application_svm\SVM_Model.m

classdef SVM_Model < handle
    properties
        % 数据集m行*n列的矩阵,n = dim
        dataset1 = [];
        dataset2 = [];
        % 惩罚系数
        c = 5;
    end
    
    methods
        % 构造函数
        function self = SVM_Model(dataset1,dataset2)
            self.dataset1 = dataset1;
            self.dataset2 = dataset2;
        end
        
        % 输入为向量,输出为适应度值
        % 输入的x为角度
        function value = fit_function(self,x) 
            % 计算哪个数据集在上方
            is_up = self.get_is_up(x);
            
            % 计算结果及最佳支持向量截距
            [value,intercept1,intercept2] = self.get_value_and_intercept(x,is_up);
        end
        
        % 绘图,num为图片编号
        function draw(self,input,num)      
            bound_max = max(abs([self.dataset1(1,:),self.dataset1(2,:),self.dataset2(1,:),self.dataset2(2,:)]));
            set(gca,'XLim',[-bound_max bound_max]);
            set(gca,'YLim',[-bound_max bound_max]);
            % 绘制左下右上两点,保持图像区域不变
            scatter(-bound_max,-bound_max,1,'w');
            hold on;
            scatter(bound_max,bound_max,1,'w');
            hold on;
            axis square;
            
            for i = 1:length(self.dataset1)
                scatter(self.dataset1(i,1),self.dataset1(i,2),10,'b','filled');
                hold on;
            end
            for i = 1:length(self.dataset2)
                scatter(self.dataset2(i,1),self.dataset2(i,2),10,'r','filled');
                hold on;
            end
            
            % 有数据则绘制分割线
            if ~isempty(input)
                is_up = self.get_is_up(input);
                [value,intercept1,intercept2] = self.get_value_and_intercept(input,is_up);
                x =-bound_max:0.1:bound_max;
                k = -(input(1))/(input(2));
                x1=x;
                y1 = k*x1+intercept1/-(input(2));
                % 剔除直线上超出边界的点
                I = y1>bound_max;
                y1(I) = [];
                x1(I) = [];
                J = y1<-bound_max;
                y1(J) = [];
                x1(J) = [];
                plot(x1,y1,'b');
                hold on;
                
                y2 = k*x+intercept2/-(input(2));
                x2=x;
                % 剔除直线上超出边界的点
                I = y2>bound_max;
                y2(I) = [];
                x2(I) = [];
                J = y2<-bound_max;
                y2(J) = [];
                x2(J) = [];
                plot(x2,y2,'r');
                hold on;
            end
            text(bound_max*0.9+-bound_max*0.1,bound_max*0.9-bound_max*0.1,num2str(num),'FontSize',20);
        end
        
    end
    
    % 受保护的方法,继承用
    methods (Access = protected)
        
        % 计算该斜率下哪个数据集在上哪个数据集在下,
        % 若数据集1 在上则返回true
        function value = get_is_up(self,x)
            % 计算两个数据集的平均位置
            mean_data1 = sum(self.dataset1)/length(self.dataset1);
            mean_data2 = sum(self.dataset2)/length(self.dataset2);
            
            % 根据数据集的平均位置来计算该斜率下的截距,用来判断哪个数据集在上,哪个数据集在下
            mean_intercept1 = self.get_intercept(mean_data1,x);
            mean_intercept2 = self.get_intercept(mean_data2,x);
            value = mean_intercept1 < mean_intercept2;
        end
        
        % 计算当前斜率下的最佳截距和结果
        % 输入的gradient为斜率
        function [value,intercept1,intercept2] = get_value_and_intercept(self,x,is_up)
            value = -realmax('double');
            % 遍历数据集,找出最佳的支持向量组合
            % 数据集之间距离越大越优
            for i = 1:length(self.dataset1)
                for j = 1:length(self.dataset2)
                    tmp_intercept1 = self.get_intercept(self.dataset1(i,:),x);
                    tmp_intercept2 = self.get_intercept(self.dataset2(j,:),x);
                    svm_value = self.get_svm_value(x,tmp_intercept1,tmp_intercept2,is_up);
                    % 记录着两个支持向量对应的截距
                    if value < svm_value
                        intercept1 = tmp_intercept1;
                        intercept2 = tmp_intercept2;
                        value = svm_value;
                    end
                end
            end
        end
        
        % 根据数据和超平面斜率,计算超平面截距
        % a*x+b*y+c = 0 --> c = -(a*x+b*y)
        function value = get_intercept(self,data,x)
            value = -sum(data.*x);
        end
        
        % 计算一个整个数据集在该超平面上的距离
        function value = get_svm_value(self,x,intercept1,intercept2,is_up)
            % 直线的分母
            deno = sqrt(sum(x.^2));
            if is_up
                % 如果数据集1在数据集2的下方
                % 则计算数据集1中在支持向量上方的数据的距离
                value = -(intercept1-intercept2)/deno;
                for i = 1:length(self.dataset1)
                    temp_intercept1 = self.get_intercept(self.dataset1(i,:),x);
                    if  temp_intercept1 > intercept1
                        value = value - self.c*abs(temp_intercept1-intercept1)/deno;
                    end
                end
                % 如果数据集2在数据集1的上方
                % 则计算数据集2中在支持向量下方的数据的距离
                for i = 1:length(self.dataset2)
                    temp_intercept2 = self.get_intercept(self.dataset2(i,:),x);
                    if temp_intercept2 < intercept2
                        value = value - self.c*abs(temp_intercept2-intercept2)/deno;
                    end
                end
            else
                % 如果数据集1在数据集2的上方
                % 则计算数据集1中在支持向量下方的数据的距离
                value = (intercept1-intercept2)/deno;
                for i = 1:length(self.dataset1)
                    temp_intercept1 = self.get_intercept(self.dataset1(i,:),x);
                    if  temp_intercept1 < intercept1
                        value = value - self.c*abs(temp_intercept1-intercept1)/deno;
                    end
                end
                % 如果数据集2在数据集1的下方
                % 则计算数据集2中在支持向量上方的数据的距离
                for i = 1:length(self.dataset2)
                    temp_intercept2 = self.get_intercept(self.dataset2(i,:),x);
                    if temp_intercept2 > intercept2
                        value = value - self.c*abs(temp_intercept2-intercept2)/deno;
                    end
                end
            end

        end

    end
    
end

测试代码
文件路径:\optimization algorithm\application_svm\Test.m

%% 清理之前的数据
% 清除所有数据
clear all;
% 清除窗口输出
clc;

a = unifrnd(-pi,pi,20,1);
ra = unifrnd(0,5,20,1);
b = unifrnd(-pi,pi,20,1);
rb = unifrnd(0,5,20,1);

data1 = [sin(a(:)).*ra(:)+5,cos(a(:)).*ra(:)+5;];
data2 = [sin(b(:)).*rb(:)-5,cos(b(:)).*rb(:)-5;];

% 数据维度为2
data_dim = 2;
model = SVM_Model(data1,data2);

range_max = ones(1,data_dim);
range_min = ones(1,data_dim)*-1;

%% 添加目录
% 将上级目录中的frame文件夹加入路径
addpath('../frame')
% 引入差分进化算法
addpath('../algorithm_differential_evolution')
%% 算法实例
dim = data_dim;
% 种群数量
size = 10;
% 最大迭代次数
iter_max = 50;
% 取值范围上界
range_max_list = range_max;
% 取值范围下界
range_min_list = range_min;
% 实例化差分进化算法类
base = DE_Impl(dim,size,iter_max,range_min_list,range_max_list);
base.is_cal_max = true;
% 确定适应度函数
base.fitfunction = @model.fit_function;
% 运行
base.run();
disp(['复杂度',num2str(base.cal_fit_num)]);

disp(model.fit_function(base.position_best));

%% 下面绘制动态图
% 绘制每一代的路径
for i = 1:length(base.position_best_history)
    model.draw(base.position_best_history(i,:),i);
    % 每0.01绘制一次
    pause = 0.01;
    %下面是保存为GIF的程序
    frame=getframe(gcf);
    % 返回单帧颜色图像
    imind=frame2im(frame);
    % 颜色转换
    [imind,cm] = rgb2ind(imind,256);
    filename = ['svm.gif'];
    if i==1
         imwrite(imind,cm,filename,'gif', 'Loopcount',inf,'DelayTime',1e-4);
    else
         imwrite(imind,cm,filename,'gif','WriteMode','append','DelayTime',pause);
    end

    if i <length(base.position_best_history)
        % 如果不是最后一张图就清除窗口
        clf;
    end
end

运行结果:

从图中可以看出,每一代的超平面都会进过数据点。

2. 优化后SVM模型

文件路径:\optimization algorithm\application_svm\SVM_Model_Broad.m

% 广义的svm,超平面不一定会经过数据点
% 集成至svm
classdef SVM_Model_Broad < SVM_Model

    properties
    end
    
    methods
        % 构造函数
        function self = SVM_Model_Broad(dataset1,dataset2)
             % 调用父类构造函数
            self@SVM_Model(dataset1,dataset2);
        end
        
        % 输入为向量,输出为适应度值
        % 输入的x为角度,x维度 =数据维度+2,最后两维为截距
        function value = fit_function(self,x) 
            
            % 取出除后2维的其他维
            k = x(1:length(x)-2);
            
             % 计算哪个数据集在上方
            is_up = self.get_is_up(k);
            
            % 数据集1的截距
            intercept1 = x(length(x)-1);
            % 数据集2的截距
            intercept2 = x(length(x));
            
            % 计算结果
            value = self.get_svm_value(k,intercept1,intercept2,is_up);
        end
        
        % 绘图,num为图片编号,输入的input为角度
        function draw(self,input,num)      
            bound_max = max(abs([self.dataset1(1,:),self.dataset1(2,:),self.dataset2(1,:),self.dataset2(2,:)]));
            set(gca,'XLim',[-bound_max bound_max]);
            set(gca,'YLim',[-bound_max bound_max]);
            % 绘制左下右上两点,保持图像区域不变
            scatter(-bound_max,-bound_max,1,'w');
            hold on;
            scatter(bound_max,bound_max,1,'w');
            hold on;
            axis square;
            
            for i = 1:length(self.dataset1)
                scatter(self.dataset1(i,1),self.dataset1(i,2),10,'b','filled');
                hold on;
            end
            for i = 1:length(self.dataset2)
                scatter(self.dataset2(i,1),self.dataset2(i,2),10,'r','filled');
                hold on;
            end
            
            % 有数据则绘制分割线
            if ~isempty(input)
                % 获取两个数据集的截距
                intercept1 = input(length(input)-1);
                intercept2 = input(length(input));
                
                % 获取斜率
                k = -(input(1))/(input(2));
                x =-bound_max:0.1:bound_max;
                x1=x;
                y1 = k*x1+intercept1/-(input(2));
                % 剔除直线上超出边界的点
                I = y1>bound_max;
                y1(I) = [];
                x1(I) = [];
                J = y1<-bound_max;
                y1(J) = [];
                x1(J) = [];
                plot(x1,y1,'b');
                hold on;
                
                x2=x;
                y2 = k*x2+intercept2/-(input(2));
                % 剔除直线上超出边界的点
                I = y2>bound_max;
                y2(I) = [];
                x2(I) = [];
                J = y2<-bound_max;
                y2(J) = [];
                x2(J) = [];
                plot(x2,y2,'r');
                hold on;
            end
            text(bound_max*0.9+-bound_max*0.1,bound_max*0.9-bound_max*0.1,num2str(num),'FontSize',20);
        end
        
    end
    
end

测试代码
文件路径:\optimization algorithm\application_svm\Test_Broad.m

%% 清理之前的数据
% 清除所有数据
clear all;
% 清除窗口输出
clc;

a = unifrnd(-pi,pi,20,1);
ra = unifrnd(0,5,20,1);
b = unifrnd(-pi,pi,20,1);
rb = unifrnd(0,5,20,1);

data1 = [sin(a(:)).*ra(:)+5,cos(a(:)).*ra(:)+5;];
data2 = [sin(b(:)).*rb(:)-5,cos(b(:)).*rb(:)-5;];

data_dim = 2;
model = SVM_Model_Broad(data1,data2);
model.c = 10;
range_max = ones(1,data_dim+2);
range_min = -ones(1,data_dim+2);

%% 添加目录
% 将上级目录中的frame文件夹加入路径
addpath('../frame')
% 引入差分进化算法
addpath('../algorithm_differential_evolution')
%% 算法实例
dim = data_dim+2;
% 种群数量
size = 40;
% 最大迭代次数
iter_max = 200;
% 取值范围上界
range_max_list = range_max;
% 取值范围下界
range_min_list = range_min;
% 实例化差分进化算法类
base = DE_Impl(dim,size,iter_max,range_min_list,range_max_list);
base.is_cal_max = true;
% 确定适应度函数
base.fitfunction = @model.fit_function;
% 运行
base.run();
disp(['复杂度',num2str(base.cal_fit_num)]);

disp(model.fit_function(base.position_best));

%% 下面绘制动态图
% 绘制每一代的路径
for i = 1:length(base.position_best_history)
    model.draw(base.position_best_history(i,:),i);
    % 每0.01绘制一次
    pause = 0.01;
    %下面是保存为GIF的程序
    frame=getframe(gcf);
    % 返回单帧颜色图像
    imind=frame2im(frame);
    % 颜色转换
    [imind,cm] = rgb2ind(imind,256);
    filename = ['svm_broad.gif'];
    if i==1
         imwrite(imind,cm,filename,'gif', 'Loopcount',inf,'DelayTime',1e-4);
    else
         imwrite(imind,cm,filename,'gif','WriteMode','append','DelayTime',pause);
    end

    if i <length(base.position_best_history)
        % 如果不是最后一张图就清除窗口
        clf;
    end
end

  出图中可以看出,超平面不一定会经过数据点。

五. 总结

这次介绍了如何使用优化算法来优化支持向量机。直接使用了支持向量机的定义作为适应度函数模型,避免了大量的对偶问题转换。同时为了减少计算量,使用了广义的支持向量机,让超平面不必一定经过支持向量,当数据集中数据较多时效果会非常明显。
  文中使用的差分进化算法实现可以看优化算法matlab实现(七)差分进化算法matlab实现。如果想使用其他优化算法,则引入相关的优化算法路径后,实例化即可。

文件目录如下:

\optimization algorithm\application_svm\SVM_Model.m
\optimization algorithm\application_svm\Test.m
\optimization algorithm\application_svm\SVM_Model_Broad.m
\optimization algorithm\application_svm\Test_Broad.m
\optimization_algorithm\frame\Unit.py
\optimization_algorithm\frame\Algorithm_Impl.py
\optimization_algorithm\algorithm_differential_evolution\DE_Unit.py
\optimization_algorithm\algorithm_differential_evolution\DE_Base.py
\optimization_algorithm\algorithm_differential_evolution\DE_Impl.py
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,125评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,293评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,054评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,077评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,096评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,062评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,988评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,817评论 0 273
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,266评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,486评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,646评论 1 347
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,375评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,974评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,621评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,796评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,642评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,538评论 2 352

推荐阅读更多精彩内容