利用神经网进行西瓜分类(上)

主要内容:使用pytorch搭建一个简单的神经网络对西瓜数据集进行分类。总结交叉熵损失函数和均方差损失函数的使用范围,简单介绍pytorch的数据输入方式。
我本以为使用深度学习框架搭建一个两层的小型网络,就像sklearn封装的MLPClassifier一样非常简单,十几二十行代码就搞定了。结果折腾了一天以后我发现我洋洋洒洒写了一百多行。主要问题是定义网络确实较为简单,但是喂数据较为麻烦。因此我主要记录网络的构建和数据喂给网络的方法。

1,网络结构

这个网络由两个线性层和两个激活函数组成。激活函数暂时使用最经典的sigmoid函数。损失函数选用均方误差函数。
这里专门记录一下均方误差函数和交叉熵损失函数的选择问题。一般观点认为均方误差更适于回归任务,交叉熵模型搭配Softmax更适用于分类任务。原因是交叉熵损失函数衡量的是真实分布于模型预测分布的差异,主要从概率和信息量的方面考虑这种差异。而均方误差函数则直接从距离上衡量这种差异。我们知道sigmoid被嫌弃的原因是值过大或者过小,其梯度会变得很小。这样会造成训练变慢。所以交叉熵模型主要针对概率进行求解,概率的取值范围固定为0-1,效果相对较好。这个问题其他一些博客叙述更清楚一些,具体可以参考这里
这里的网络使用均方误差损失函数,原因是其导数好求一些,后续的文章我想计算前向传播和反向传播的计算量差异,需要自己求导,所以这里偷懒了。

2,pytorch的Dataset和DataLoader

不知道是我本人的原因还是怎么回事,每次做相关任务,数据读取或者预处理都会消耗掉我大部分精力。网络上的资料也有相同的问题,线性回归的例子都是使用已有数据集,方便快捷。介绍如何读取自己数据的教程很少。因此为了加强这部分练习,我尝试跟随pytorch的官方教程进行联系。Data Loading and Processing Tutorial.PS为什么我更喜欢pytorch,其中一个并不重要的原因是官网没被墙。
详细步骤其实就是两步。

2.1重新构建自己数据的Dataset类。

换句话说就是替换掉def init,def len(self),def getitem(self, idx)三个函数。为了方便处理,我把自己的预处理程序(连续化详细可以查看我的前一篇博客)也放到了这个类里。针对从csv中读取的文件,def init更改为读取文件的相关内容。def len(self)是更换为数据的总长度。def getitem(self, idx)是返回每个数据和标签的本体。(这是不是一个迭代器?我不知道这种说法准不准)。有的教程还将数据变换(Transforms)放在里面。西瓜数据集不用,所以删除相关代码。
整体代码如下。

class WatermelonDataset(Dataset):
    """西瓜数据集."""

    def __init__(self, csv_file):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
        """
        self.df = pd.read_csv(csv_file, encoding='gb18030')
        self.sort_file_name = r"attribute_sort.csv"
        self.label_file_name = r"label_sort.csv"

    def __len__(self):
        return len(self.df)
    
    def Attribute_2_csv(self):
            
    def Attribute_read_sort(self):

    def Label_2_csv(self):

    def Label_read_sort(self):


    def __getitem__(self, idx):
        self.Attribute_2_csv()
        data = torch.from_numpy(self.Attribute_read_sort()).float()
        self.Label_2_csv()
        label = torch.from_numpy(self.Label_read_sort()).float()
        data = data[idx,:]
        label = label[idx,0]
        return data,label

2.2使用DataLoader

DataLoader本质上是一个迭代器,很多用法可以参考list,是一个比较方便的数据输入方法。以我现在的应用来看可以用来处理batch的输入,打乱顺序等常用操作(认识好浅薄。。。)。使用很简单。

dataloader = DataLoader.DataLoader(watermelon_dataset,batch_size= 2, shuffle = True, num_workers= 0)
for data in dataloader:

就可以了。注意num_workers处可能会报错,如果数据量不大建议取0。
全部代码如下:

# -*- coding: utf-8 -*-
"""
Created on Sun Jul 28 09:20:21 2019

@author: BHN
"""
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data.dataset import Dataset
import pandas as pd
import os
import torch.utils.data.dataloader as DataLoader
import torchvision.transforms as transforms
batch_size = 4

class WatermelonDataset(Dataset):
    """西瓜数据集."""

    def __init__(self, csv_file):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
        """
        self.df = pd.read_csv(csv_file, encoding='gb18030')
        self.sort_file_name = r"attribute_sort.csv"
        self.label_file_name = r"label_sort.csv"
        self.Attribute_2_csv()
#        self.data = torch.from_numpy(self.Attribute_read_sort()).float()
        self.data = self.Attribute_read_sort()
        self.data = self.transform(self.data)
        self.data = torch.from_numpy(self.data).float()
        self.Label_2_csv()
        self.label = torch.from_numpy(self.Label_read_sort()).float()

    def __len__(self):
        return len(self.df)
    
    def transform(self,sample):
        sample = sample.astype(float)
        attribute_min = np.amin(sample, axis=0)
        attribute_max = np.amax(sample, axis=0)
        sample = sample - attribute_min[None,:]
        sample = sample / attribute_max[None,:]
        return sample
    
    def Attribute_2_csv(self):
#        主要功能去重每个属性,并存到csv文件中,
        if os.path.exists(self.sort_file_name):#如果文件存在就不需要重新排序
            return
        answer = {}
        for attribute_name in list(self.df.columns[1:-1]): #删除第一列的序数和最后一列的决策属性
            attribute = pd.Series(list(set(self.df[attribute_name].values)))#DataFrame对应转成numpy的是to_numpy()
#            answer.append(attribute)
            answer[attribute_name] = attribute
        answer = pd.DataFrame(answer)
        answer.to_csv(self.sort_file_name,index=False)
        print(answer)
            
    def Attribute_read_sort(self):
#        读取手工拍好的顺序属性
        df=pd.read_csv(self.sort_file_name)#这里的编码经过了转换
        continuing_table = {}
        for attribute_name in list(df.columns):#保存的时候删除了序号,所以从0开始
#            print(df[attribute_name].dropna().values)
            continuing_table[attribute_name]=dict(zip(\
                           df[attribute_name].dropna().values,\
                           range(len(df[attribute_name].dropna().values))))
#        print(self.df.values)
        answer = []
        for attribute_name in list(self.df.columns[1:-1]):
            answer_attribute = []
            for single_attribute in self.df[attribute_name]:
                answer_attribute.append(continuing_table[attribute_name][single_attribute])
#                print(answer_attribute)
            answer.append(answer_attribute)
        answer = np.array(answer).T
        return answer
            
    def Label_2_csv(self):
#        label属性其实不需要排列,主要防止有多个label,方便数字化
        if os.path.exists(self.label_file_name):#如果文件存在就不需要重新排序
            return
        answer = {}
        for attribute_name in list([self.df.columns[-1]]):
            attribute = pd.Series(list(set(self.df[attribute_name].values)))#查看最后一列,也就是label
            answer[attribute_name] = attribute
        answer = pd.DataFrame(answer)
        answer.to_csv(self.label_file_name,index=False)
        print(answer)
        
    def Label_read_sort(self):
#        读取离散label的顺序,label属性其实不需要排列,但是因为要数字化,所以干脆一起做了
        df=pd.read_csv(self.label_file_name)#这里的编码经过了转换
        continuing_table = {}
        answer = []
        for attribute_name in list(df.columns):
            continuing_table=dict(zip(\
                                      df[attribute_name].dropna().values,\
                                      range(len(df[attribute_name].dropna().values))))
            answer_attribute = []

            for single_attribute in self.df[attribute_name]:
                answer_attribute.append(continuing_table[single_attribute])
#                print(answer_attribute)
            answer.append(answer_attribute)
        answer = np.array(answer).T
#        answer = np.array(answer)
        return answer


    def __getitem__(self, idx):
        data = self.data[idx,:]
        label = self.label[idx,0]
        return data,label

class Linear(nn.Module):
    def __init__(self):
        super(Linear, self).__init__()
        self.hidden_num = 9
        self.linear_1 = torch.nn.Linear(6, self.hidden_num)
        self.linear_2 = torch.nn.Linear(self.hidden_num, 2)
        
    def forward(self, x):
        hidden = self.linear_1(x)
        hidden = torch.sigmoid(hidden)
        output = self.linear_2(hidden)
        output = torch.sigmoid(output)        
        out = output
        return out
    
        
if __name__ =='__main__':
    watermelon_dataset = WatermelonDataset(csv_file='data//watermelon.csv')
    net = Linear()
    criterion = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
    

    for epoch in range(10000):  # loop over the dataset multiple times
        dataloader = DataLoader.DataLoader(watermelon_dataset,batch_size= batch_size, shuffle = True, num_workers= 0)
        for data in dataloader:
            # get the inputs; data is a list of [inputs, labels]
            inputs,labels = data
            
            labels = labels.long()
            labels_onehot = torch.FloatTensor(len(labels), 2)
            labels_onehot.zero_()
            labels_onehot.scatter_(1, labels.view(len(labels),1), 1)
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
#            print(outputs, labels)
            loss = criterion(outputs, labels_onehot)
            loss.backward()
            optimizer.step()

        if epoch%100 == 0:
            correct = 0
            total = 0
            print(loss.item())
            with torch.no_grad():
                dataloader = DataLoader.DataLoader(watermelon_dataset,batch_size= 2, shuffle = True, num_workers= 0)
                for data in dataloader:
                    # get the inputs; data is a list of [inputs, labels]
                    inputs,labels = data
                
                    
                    outputs = net(inputs)
#                    print(outputs)
                    _, predicted = torch.max(outputs, 0)
#                   print(predicted.float()==labels)
                    correct += (predicted.float() == labels).sum().item()
                print('Accuracy of : %d %%' % (100 * correct / len(watermelon_dataset)))

可以发现正确率很低。产生的原因还需要分析。希望下次能解决这个问题。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,001评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,210评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,874评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,001评论 1 291
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,022评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,005评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,929评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,742评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,193评论 1 309
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,427评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,583评论 1 346
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,305评论 5 342
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,911评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,564评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,731评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,581评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,478评论 2 352

推荐阅读更多精彩内容