"Spatio-temporal Data Processing and Organization Course Practicum" Lab Report
Topic: Lab 5, Decision Tree Classification
Date: June 13
Environment: Python 3.6, Windows, WSL2 (Ubuntu 20.04)
Preface
For better formatting, it is recommended to read this article on Jianshu (Jianshu link for this article).
The code files involved this time are day5.py and day5_pre_process.py; the corresponding readme is day3.md.
All the code for this practicum is stored in my GitHub repository; you are welcome to browse it online, or run git clone https://github.com/uiharuayako/geoDataWork.git to get my latest progress!
All of the code was written by me or adapted from the instructor's materials. Please use it to learn and exchange ideas rather than to copy and paste, thank you!
Experiment contents and completion status:
The program implements every item in the lab requirements and completes all of the required functionality.
The comments in the code are already quite detailed, so I recommend reading the code directly; here I mainly describe my approach and analyze the key code fragments.
Code analysis
1. Load the data from the files and convert it into a DataFrame.
2. Train a decision tree model to predict whether a resident's income exceeds 50K.
3. Validate the model on the test dataset and output its accuracy.
Let's look at the code directly. It is short this time, but it packs in a lot.
While I was working on this, people in the group chat reported several problems with the data, so I took a look. In adult.test, before the fix, the label 50K on every line was written as 50K. with a trailing period; this problem really does exist. Other classmates also said the data itself had errors and gaps, so I wrote a preprocessing function to check and repair it.
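For the label fix specifically, a couple of lines of pandas are enough. This is only a minimal sketch, not my submitted script: the output file name is made up, column 14 is the label column, and skiprows=1 assumes the original UCI test file, which starts with a "|1x3 Cross validator" line.
import pandas as pd

# read adult.test; skiprows=1 skips the '|1x3 Cross validator' header line of the original file
test = pd.read_csv('adult/adult.test', header=None, sep=', ', engine='python', skiprows=1)
# column 14 holds the label; strip the trailing period so '>50K.' becomes '>50K'
test[14] = test[14].str.rstrip('.')
# write the repaired file with plain commas (re-read it later with sep=',')
test.to_csv('adult/adult_test_fixed.csv', header=False, index=False)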
The preprocessing function uses the pandas module. I have to say that the pandas DataFrame has much better Python support than Spark's: pandas runs directly in the Python process, whereas PySpark has to hand the work across to the JVM, and the difference in efficiency is obvious.
To be honest, even I find my nested ifs painful to look at, but written this way the roughly 40,000 rows are still checked in a few seconds, which is fast enough. PySpark, by contrast, could not even read this data, as described below.
The code of the preprocessing function is as follows:
import pandas as pd
data = pd.read_csv('adult/adult.data', header=None, sep=', ', engine='python')
print(data.shape)
# Step 1: find rows that contain null values
null_lines = data.isnull().T.any()
for index, value in null_lines.items():
    if value:
        print("row {} contains null values".format(index + 1))
# drop the null rows (dropna returns a new DataFrame, so the result must be reassigned)
data = data.dropna(axis=0, how='any')
# Step 2: find invalid values
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        pass
    try:
        import unicodedata
        unicodedata.numeric(s)
        return True
    except (TypeError, ValueError):
        pass
    return False
work_type = {'Private': 1,
'Self-emp-not-inc': 2,
'Self-emp-inc': 3,
'Federal-gov': 4,
'Local-gov': 5,
'State-gov': 6,
'Without-pay': 7,
'Never-worked': 8,
'?': -1}
education = {'Bachelors': 1,
'Some-college': 2,
'11th': 3,
'HS-grad': 4,
'Prof-school': 5,
'Assoc-acdm': 6,
'Assoc-voc': 7,
'9th': 8,
'7th-8th': 9,
'12th': 10,
'Masters': 11,
'1st-4th': 12,
'10th': 13,
'Doctorate': 14,
'5th-6th': 15,
'Preschool': 16,
'?': -1}
marital_status = {'Married-civ-spouse': 1,
'Divorced': 2,
'Never-married': 3,
'Separated': 4,
'Widowed': 5,
'Married-spouse-absent': 6,
'Married-AF-spouse': 7,
'?': -1}
occupation = {'Tech-support': 1,
'Craft-repair': 2,
'Other-service': 3,
'Sales': 4,
'Exec-managerial': 5,
'Prof-specialty': 6,
'Handlers-cleaners': 7,
'Machine-op-inspct': 8,
'Adm-clerical': 9,
'Farming-fishing': 10,
'Transport-moving': 11,
'Priv-house-serv': 12,
'Protective-serv': 13,
'Armed-Forces': 14,
'?': -1}
relationship = {'Wife': 1,
'Own-child': 2,
'Husband': 3,
'Not-in-family': 4,
'Other-relative': 5,
'Unmarried': 6,
'?': -1}
race = {'White': 1,
'Asian-Pac-Islander': 2,
'Amer-Indian-Eskimo': 3,
'Other': 4,
'Black': 5,
'?': -1}
sex = {'Female': 1,
'Male': 2,
'?': -1}
native_country = {'United-States': 1,
'Cambodia': 2,
'England': 3,
'Puerto-Rico': 4,
'Canada': 5,
'Germany': 6,
'Outlying-US(Guam-USVI-etc)': 7,
'India': 8,
'Japan': 9,
'Greece': 10,
'South': 11,
'China': 12,
'Cuba': 13,
'Iran': 14,
'Honduras': 15,
'Philippines': 16,
'Italy': 17,
'Poland': 18,
'Jamaica': 19,
'Vietnam': 20,
'Mexico': 21,
'Portugal': 22,
'Ireland': 23,
'France': 24,
'Dominican-Republic': 25,
'Laos': 26,
'Ecuador': 27,
'Taiwan': 28,
'Haiti': 29,
'Columbia': 30,
'Hungary': 31,
'Guatemala': 32,
'Nicaragua': 33,
'Scotland': 34,
'Thailand': 35,
'Yugoslavia': 36,
'El-Salvador': 37,
'Trinadad&Tobago': 38,
'Peru': 39,
'Hong': 40,
'Holand-Netherlands': 41,
'?': -1}
for index, row in data.iterrows():
    if is_number(row[0]):
        if row[1] in work_type:
            if is_number(row[2]):
                if row[3] in education:
                    if is_number(row[4]):
                        if row[5] in marital_status:
                            if row[6] in occupation:
                                if row[7] in relationship:
                                    if row[8] in race:
                                        if row[9] in sex:
                                            if is_number(row[10]):
                                                if is_number(row[11]):
                                                    if is_number(row[12]):
                                                        if row[13] in native_country:
                                                            continue
    print("row {} has an invalid value".format(index + 1))
Two kinds of validation are done here. The first checks whether any value is empty (a '?' is treated as non-empty), and rows with empty values are dropped. The second is the very long chain of ifs: for every column I check that the value is valid, using a numeric check for numeric columns (i.e., verifying that the string is a pure number) and a dictionary lookup for categorical columns (checking whether the value exists as a key). The result: neither adult.data nor adult.test contains any invalid rows (after the 50K. labels in the .test file were corrected, of course).
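Just as an aside, the same checks could be written without the deep nesting by keeping one validator per column and looping over them. This is only a sketch of that idea, reusing the is_number function and the dictionaries defined above; it is not the code I actually submitted:
# one validator per column, in column order; each returns True when the value is acceptable
validators = [
    is_number,                    # age
    work_type.__contains__,       # workclass
    is_number,                    # fnlwgt
    education.__contains__,       # education
    is_number,                    # education-num
    marital_status.__contains__,  # marital-status
    occupation.__contains__,      # occupation
    relationship.__contains__,    # relationship
    race.__contains__,            # race
    sex.__contains__,             # sex
    is_number,                    # capital-gain
    is_number,                    # capital-loss
    is_number,                    # hours-per-week
    native_country.__contains__,  # native-country
]
for index, row in data.iterrows():
    if not all(check(row[i]) for i, check in enumerate(validators)):
        print("row {} has an invalid value".format(index + 1))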
Then I moved straight on to training, using the Spark ML pipeline concept the instructor covered: given a training set and a test set, you only need to configure the stages and parameters before running. Spark ML is fairly pleasant to use.
The key step is translating the string-valued attributes into numbers. I chose a dictionary approach: with the dictionaries defined in advance, each attribute value is used as a key and maps directly to a number. This is fast and works well; the only annoyance is editing these dictionaries by hand up front, but overall it is not a problem.
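To make the translation concrete, this is what the lookups produce for a single record in the adult.data format (the values here are just for illustration):
sample = ['39', 'State-gov', '77516', 'Bachelors', '13', 'Never-married',
          'Adm-clerical', 'Not-in-family', 'White', 'Male', '2174', '0', '40', 'United-States']
# numeric fields go through float(), categorical fields go through the dictionaries
print(float(sample[0]))            # 39.0
print(work_type[sample[1]])        # 6, the code for State-gov
print(education[sample[3]])        # 1, the code for Bachelors
print(native_country[sample[13]])  # 1, the code for United-States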
The training code is below:
import findspark
findspark.init()
from pyspark.ml.classification import DecisionTreeClassificationModel
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vector, Vectors
from pyspark.sql import Row, SQLContext
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
from pyspark import SparkContext, SparkConf
# dictionaries mapping each categorical string value to a numeric code
work_type = {'Private': 1,
'Self-emp-not-inc': 2,
'Self-emp-inc': 3,
'Federal-gov': 4,
'Local-gov': 5,
'State-gov': 6,
'Without-pay': 7,
'Never-worked': 8,
'?': -1}
education = {'Bachelors': 1,
'Some-college': 2,
'11th': 3,
'HS-grad': 4,
'Prof-school': 5,
'Assoc-acdm': 6,
'Assoc-voc': 7,
'9th': 8,
'7th-8th': 9,
'12th': 10,
'Masters': 11,
'1st-4th': 12,
'10th': 13,
'Doctorate': 14,
'5th-6th': 15,
'Preschool': 16,
'?': -1}
marital_status = {'Married-civ-spouse': 1,
'Divorced': 2,
'Never-married': 3,
'Separated': 4,
'Widowed': 5,
'Married-spouse-absent': 6,
'Married-AF-spouse': 7,
'?': -1}
occupation = {'Tech-support': 1,
'Craft-repair': 2,
'Other-service': 3,
'Sales': 4,
'Exec-managerial': 5,
'Prof-specialty': 6,
'Handlers-cleaners': 7,
'Machine-op-inspct': 8,
'Adm-clerical': 9,
'Farming-fishing': 10,
'Transport-moving': 11,
'Priv-house-serv': 12,
'Protective-serv': 13,
'Armed-Forces': 14,
'?': -1}
relationship = {'Wife': 1,
'Own-child': 2,
'Husband': 3,
'Not-in-family': 4,
'Other-relative': 5,
'Unmarried': 6,
'?': -1}
race = {'White': 1,
'Asian-Pac-Islander': 2,
'Amer-Indian-Eskimo': 3,
'Other': 4,
'Black': 5,
'?': -1}
sex = {'Female': 1,
'Male': 2,
'?': -1}
native_country = {'United-States': 1,
'Cambodia': 2,
'England': 3,
'Puerto-Rico': 4,
'Canada': 5,
'Germany': 6,
'Outlying-US(Guam-USVI-etc)': 7,
'India': 8,
'Japan': 9,
'Greece': 10,
'South': 11,
'China': 12,
'Cuba': 13,
'Iran': 14,
'Honduras': 15,
'Philippines': 16,
'Italy': 17,
'Poland': 18,
'Jamaica': 19,
'Vietnam': 20,
'Mexico': 21,
'Portugal': 22,
'Ireland': 23,
'France': 24,
'Dominican-Republic': 25,
'Laos': 26,
'Ecuador': 27,
'Taiwan': 28,
'Haiti': 29,
'Columbia': 30,
'Hungary': 31,
'Guatemala': 32,
'Nicaragua': 33,
'Scotland': 34,
'Thailand': 35,
'Yugoslavia': 36,
'El-Salvador': 37,
'Trinadad&Tobago': 38,
'Peru': 39,
'Hong': 40,
'Holand-Netherlands': 41,
'?': -1}
def f(x):
    rel = {
        'features': Vectors.dense(float(x[0]),
                                  float(work_type[x[1]]),
                                  float(x[2]),
                                  float(education[x[3]]),
                                  float(x[4]),
                                  float(marital_status[x[5]]),
                                  float(occupation[x[6]]),
                                  float(relationship[x[7]]),
                                  float(race[x[8]]),
                                  float(sex[x[9]]),
                                  float(x[10]),
                                  float(x[11]),
                                  float(x[12]),
                                  float(native_country[x[13]])
                                  ),
        'label': str(x[14])}
    return rel
# Spark initialization
conf = SparkConf().setMaster("local").setAppName("ml")
sc = SparkContext(conf=conf)  # create the SparkContext
# fixes: AttributeError: 'PipelinedRDD' object has no attribute 'toDF'
sqlContext = SQLContext(sc)
data = sc.textFile("adult/adult.data").map(lambda line: line.split(', ')).map(
lambda p: Row(**f(p))).toDF()
labelIndexer = StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data)
featureIndexer = VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data)
labelConverter = IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
trainingData = data
testData = sc.textFile("adult/adult.test").map(lambda line: line.split(', ')).map(
lambda p: Row(**f(p))).toDF()
dtClassifier = DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")
dtPipeline = Pipeline().setStages([labelIndexer, featureIndexer, dtClassifier, labelConverter])
dtPipelineModel = dtPipeline.fit(trainingData)
dtPredictions = dtPipelineModel.transform(testData)
dtPredictions.select("predictedLabel", "label", "features").show(20)
# note: MulticlassClassificationEvaluator defaults to the F1 metric;
# add .setMetricName("accuracy") if plain accuracy is wanted instead
evaluator = MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
dtAccuracy = evaluator.evaluate(dtPredictions)
print(dtAccuracy)
Then... it crashed. Python itself crashed, with no error message; all I knew was that the statement
data = sc.textFile("adult/adult.data").map(lambda line: line.split(', ')).map(
lambda p: Row(**f(p))).toDF()
failed in its toDF call, which seemed absurd, since the dataset should be fine. I then guessed it might be running out of memory, or that the number of rows being read exceeded some limit, so I tried something. Under Linux I ran
sudo head -20000 adult.data >test.txt
This test file still could not be read. So I ran
sudo head -10000 adult.data >test2.txt
This test2, miraculously, could be read??? And the rest of the code ran correctly. After several more attempts I settled on 16,000 as a reasonable row count and used the first 16,000 rows as the training set.
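Incidentally, the same split can also be done in plain Python without head, which works on Windows too (a small sketch; the output file name is arbitrary):
# take the first 16000 lines of adult.data, equivalent to: head -16000 adult.data > train_16000.txt
with open('adult/adult.data', 'r') as src, open('adult/train_16000.txt', 'w') as dst:
    for i, line in enumerate(src):
        if i >= 16000:
            break
        dst.write(line)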
At this point I still suspected that my validation function was at fault, so I also extracted the last 16,000 rows and tested them, and they did not fail either.
So is it really such a coincidence that, out of these 32,651 rows, every single error happens to be concentrated in the middle 651 rows, and happens to be exactly the kind my checks cannot detect?
Under Linux I used sed -n '16001,24000p' adult.data > fin_data_mid.txt
to extract the 8,000 rows starting at line 16,001, and there were still no errors.
From this I established two facts:
- adult.data passes my validation function, and whenever I make a random change to the file and save it, the check script day5_pre_process always points out exactly the position I changed, so I believe the validation function itself is correct.
- The first 16,000 rows, the last 16,000 rows, and the middle 8,000 rows of adult.data can all be trained on, which shows that each of those slices of the data is fine.
My conclusion: adult.data itself is not the problem; the problem is PySpark.
In the end, after discussing it with the instructor, I trained on part of the data and obtained the final results. That does not really count as solving the problem, but the reasoning process was genuine.
After searching around, I found that people in the Spark community have raised this issue before, but it has never been resolved. It also seems to vary from person to person, and sometimes the same file can be read and sometimes it cannot, which I find very strange.
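One workaround I considered but have not verified would be to let pandas do the reading and hand the rows to Spark with createDataFrame, bypassing textFile().map(...).toDF() entirely. This is only a sketch that reuses the sqlContext and the f() function from the training script above:
import pandas as pd
from pyspark.sql import Row

# read on the driver with pandas, build the same Row objects as before,
# and let sqlContext infer the schema from them
pdf = pd.read_csv('adult/adult.data', header=None, sep=', ', engine='python')
rows = [Row(**f(list(r))) for r in pdf.itertuples(index=False)]
data = sqlContext.createDataFrame(rows)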
Console output
Using the first 16,000 rows of adult.data as the training set and adult.test as the test set, running the program gives:
D:\ProgramData\Anaconda3\envs\py36\python.exe D:/code/geoDataWork/day5.py
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
+--------------+-----+--------------------+
|predictedLabel|label| features|
+--------------+-----+--------------------+
| <=50K|<=50K|[25.0,1.0,226802....|
| <=50K|<=50K|[38.0,1.0,89814.0...|
| <=50K| >50K|[28.0,5.0,336951....|
| >50K| >50K|[44.0,1.0,160323....|
| <=50K|<=50K|[18.0,-1.0,103497...|
| <=50K|<=50K|[34.0,1.0,198693....|
| <=50K|<=50K|[29.0,-1.0,227026...|
| >50K| >50K|[63.0,2.0,104626....|
| <=50K|<=50K|[24.0,1.0,369667....|
| <=50K|<=50K|[55.0,1.0,104996....|
| <=50K| >50K|[65.0,1.0,184454....|
| >50K|<=50K|[36.0,4.0,212465....|
| <=50K|<=50K|[26.0,1.0,82091.0...|
| <=50K|<=50K|[58.0,-1.0,299831...|
| <=50K| >50K|[48.0,1.0,279724....|
| >50K| >50K|[43.0,1.0,346189....|
| <=50K|<=50K|[20.0,6.0,444554....|
| <=50K|<=50K|[43.0,1.0,128354....|
| <=50K|<=50K|[37.0,1.0,60548.0...|
| >50K| >50K|[40.0,1.0,85019.0...|
+--------------+-----+--------------------+
only showing top 20 rows
0.8234580616672087
Process finished with exit code 0
Using the last 16,000 rows as the training set gives:
D:\ProgramData\Anaconda3\envs\py36\python.exe D:/code/geoDataWork/day5.py
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
+--------------+-----+--------------------+
|predictedLabel|label| features|
+--------------+-----+--------------------+
| <=50K|<=50K|[25.0,1.0,226802....|
| <=50K|<=50K|[38.0,1.0,89814.0...|
| <=50K| >50K|[28.0,5.0,336951....|
| >50K| >50K|[44.0,1.0,160323....|
| <=50K|<=50K|[18.0,-1.0,103497...|
| <=50K|<=50K|[34.0,1.0,198693....|
| <=50K|<=50K|[29.0,-1.0,227026...|
| >50K| >50K|[63.0,2.0,104626....|
| <=50K|<=50K|[24.0,1.0,369667....|
| <=50K|<=50K|[55.0,1.0,104996....|
| >50K| >50K|[65.0,1.0,184454....|
| >50K|<=50K|[36.0,4.0,212465....|
| <=50K|<=50K|[26.0,1.0,82091.0...|
| <=50K|<=50K|[58.0,-1.0,299831...|
| <=50K| >50K|[48.0,1.0,279724....|
| >50K| >50K|[43.0,1.0,346189....|
| <=50K|<=50K|[20.0,6.0,444554....|
| <=50K|<=50K|[43.0,1.0,128354....|
| <=50K|<=50K|[37.0,1.0,60548.0...|
| >50K| >50K|[40.0,1.0,85019.0...|
+--------------+-----+--------------------+
only showing top 20 rows
0.8309072931788737
Process finished with exit code 0
This runs too, and the result is even slightly better. Now the result for the middle 8,000 rows:
D:\ProgramData\Anaconda3\envs\py36\python.exe D:/code/geoDataWork/day5.py
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
+--------------+-----+--------------------+
|predictedLabel|label| features|
+--------------+-----+--------------------+
| <=50K|<=50K|[25.0,1.0,226802....|
| <=50K|<=50K|[38.0,1.0,89814.0...|
| <=50K| >50K|[28.0,5.0,336951....|
| >50K| >50K|[44.0,1.0,160323....|
| <=50K|<=50K|[18.0,-1.0,103497...|
| <=50K|<=50K|[34.0,1.0,198693....|
| <=50K|<=50K|[29.0,-1.0,227026...|
| >50K| >50K|[63.0,2.0,104626....|
| <=50K|<=50K|[24.0,1.0,369667....|
| <=50K|<=50K|[55.0,1.0,104996....|
| <=50K| >50K|[65.0,1.0,184454....|
| >50K|<=50K|[36.0,4.0,212465....|
| <=50K|<=50K|[26.0,1.0,82091.0...|
| <=50K|<=50K|[58.0,-1.0,299831...|
| <=50K| >50K|[48.0,1.0,279724....|
| >50K| >50K|[43.0,1.0,346189....|
| <=50K|<=50K|[20.0,6.0,444554....|
| <=50K|<=50K|[43.0,1.0,128354....|
| <=50K|<=50K|[37.0,1.0,60548.0...|
| >50K| >50K|[40.0,1.0,85019.0...|
+--------------+-----+--------------------+
only showing top 20 rows
0.8308268997562765
Process finished with exit code 0
Still no problems.