机器学习贝叶斯网络分类水果

水果部分数据


捕获.PNG

代码

import numpy as np
import math
import csv
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pylab as pl
import random
from matplotlib import cm
from sklearn.model_selection import train_test_split

# 求平均值
def mean(numbers):
    return sum(numbers)/float(len(numbers))

# 求平均差
def stdev(numbers):
  avg = mean(numbers)
  variance = sum([pow(x-avg,2) for x in numbers])/float(len(numbers)-1)
  return math.sqrt(variance)

# 求各列的平均值和方差--提取数据特征
def summarize(dataset):
    parameter = [(mean(attribute), stdev(attribute)) for attribute in zip(*dataset)]
    #parameter = [(mean(dataset.iloc[:,i]),stdev(dataset.iloc[:,i])) for i in range(dataset.shape[1]) ]
    del parameter[-1]
    return parameter

# 进行分类
def separatedByClass(dataset):
    separated = {}
    #创建字典
    for i in range(len(dataset)):
        vector = dataset[i]
        if (vector[-1] not in separated):
            #根据最后一个元素,随后一个元素为1,2,3,4,代表着水果的种类,作为键值key
            separated[vector[-1]] = []
        separated[vector[-1]].append(vector)
    return separated

# 类别属性提取特征,即每一类四种特征总的均值和方差
def summarizeByClass(dataset):
    separated = separatedByClass(dataset)
    summaries = { }
    #创建字典
    for classValue, instances in separated.items():
        summaries[classValue] = summarize(instances)
    return summaries


# 求出高斯概率密度函数
def calculateProbability(x, mean, stdev):
    exponent = math.exp(-(math.pow(x - mean, 2) / (2 * math.pow(stdev, 2))))
    return (1 / (math.sqrt(2 * math.pi) * stdev)) * exponent

#所属类的概率
def calculateClassProbabilities(summaries, inputVector):
    probabilities = {}
    #字典
    for classValue, classSummaries in summaries.items():
        probabilities[classValue] = 1
        for i in range(len(classSummaries)):
            mean, stdev = classSummaries[i]
            x = inputVector[i]
            probabilities[classValue] *= calculateProbability(x, mean, stdev)
            #求出总的高斯密度的乘积
    return probabilities

# 对数据单一预测
# 每组测试数据最有可能的情况
def predict(summaries, inputVector):
    probabilities = calculateClassProbabilities(summaries, inputVector)
    bestLabel, bestProb = None, -1
    for classValue, probability in probabilities.items():
        if bestLabel is None or probability > bestProb:

            bestProb = probability
            bestLabel = classValue
    return bestLabel

#进行多重预测
def getPredictions(summaries, testSet):
    predictions = []        #来存储结果
    for i in range(len(testSet)):
        result = predict(summaries, testSet[i])
        predictions.append(result)
    return predictions     # 最终返回输出结果

#输出结果计算准确率
def getAccuracy(testSet, predictions):
    correct = 0
    print("结果:")
    for x in range(len(testSet)):
        print("预测的结果:", predictions[x], "----", testSet[x][-1], ":正确的结果")
        if testSet[x][-1] == predictions[x]:
            correct += 1
    return (correct / float(len(testSet))) * 100.0

def main():

    fruits = pd.read_table('E:/fruit.txt') #fruit.txt所在位置,我将它放在E盘。
    feature_names = ['fruit_label', 'mass', 'width', 'height', 'color_score']
    X = fruits[['mass', 'width', 'height', 'color_score', 'fruit_label']]
    Y = fruits['fruit_label']
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25, random_state=0)  #通过pandas取出数据,再随机生成X_train和X_test 训练和测试数据
    Traindataset = X_train.values
    Testdataset = X_test.values

    '''
    代码原因将数据转换成一下格式,目的是为了去掉pandas中dataframe的index,如mass,width 等特征值
       mass  width  height  color_score  fruit_label
   42   154    7.2     7.2         0.82            3
   48   174    7.3    10.1         0.72            4
   变成
  [[154.     7.2    7.2    0.82   3.  ]
   [174.     7.3   10.1    0.72   4.  ]
   [ 76.     5.8    4.     0.81   2.  ]]   
   '''
    summaries = summarizeByClass(Traindataset)            #根据测试数据进行提取数据特征, 分类,求方差,均值,然后对每类进行特征值提取
    print("特征的提取:",summaries)                      #输出贝叶斯整理的结果
    predictions = getPredictions(summaries, Testdataset)  #输入测试数据
    accuracy = getAccuracy(Testdataset, predictions)
    print("准确率:",accuracy,'%')

if __name__ == "__main__":
    main()

运行结果


捕获1.PNG
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 后期整理字体以及排版问题,修订不适合的翻译 “A wealth of information. Smart, ye...
    iamzzz阅读 759评论 0 0
  • 3.1. 介绍 现在,您已经安装了Wireshark并有可能热衷于开始捕捉您的第一个数据包。在接下来的章节中,我们...
    wwyyzz阅读 1,400评论 0 1
  • ¥开启¥ 【iAPP实现进入界面执行逐一显】 〖2017-08-25 15:22:14〗 《//首先开一个线程,因...
    小菜c阅读 6,497评论 0 17
  • 第一部分 HTML&CSS整理答案 1. 什么是HTML5? 答:HTML5是最新的HTML标准。 注意:讲述HT...
    kismetajun阅读 27,588评论 1 45
  • 有一个人经常喜欢指责别人。王阳明对他说:“学习应该多反省自己,如果只是看到别人的不对,责怪别人,就不会看到自己的不...
    六爸啦啦啦阅读 415评论 0 1