为什么要评估?
当训练或部署神经网络时,了解模型的准确性是有用的。在DL4J中,评估类和评估类的变体可用于评估模型的性能。
分类评估
评估类用于评估二分类和多类分类器(包括时间序列分类器)的性能。本节介绍了评估类的基本用法。
给定一个DataSetIterator形式的数据集,执行评估的最简单方法是使用MultiLayerNetwork和ComutationGraph上的内置评估方法:
DataSetIterator myTestData = ...
Evaluation eval = model.evaluate(myTestData);
然而,也可以对单个小批量进行评估。这里是一个例子,从我们的示例项目中数据实例/CSV示例中获得。
CSV的例子有3类花的CSV数据,建立了一个简单的前馈神经网络用于对基于4个测量值的花的分类。
Evaluation eval = new Evaluation(3);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
log.info(eval.stats());
第一行创建一个具有3个类的评估对象。第二行从模型中获取我们测试数据集的标签。第三行使用eval方法将来自testdata的标签数组与从模型生成的标签进行比较。第四行将评估数据记录到控制台。
输出
Examples labeled as 0 classified by model as 0: 24 times
Examples labeled as 1 classified by model as 1: 11 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 2 classified by model as 2: 17 times
==========================Scores========================================
# of classes: 3
Accuracy: 0.9811
Precision: 0.9815
Recall: 0.9722
F1 Score: 0.9760
Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes)
========================================================================
默认情况下,.stats() 方法显示混淆矩阵条目(每行一个)、准确度、精度、召回率和F1分数。此外,评估类还可以计算并返回以下值:
- 混淆矩阵
- 假阳性/阴性率
- 真阳性/阴性
- 类别计数
- F-beta, G-measure, Matthews 关系数及更多, 查看 Evaluation JavaDoc
显示混淆矩阵。
System.out.println(eval.confusionToString());
显示
Predicted: 0 1 2
Actual:
0 0 | 16 0 0
1 1 | 0 19 0
2 2 | 0 0 18
此外,可以直接访问混淆矩阵,使用CSV或HTML转换。
eval.getConfusionMatrix() ;
eval.getConfusionMatrix().toHTML();
eval.getConfusionMatrix().toCSV();
回归评估
为了评估执行回归的网络,使用回归评估类。
带着评估类,一个DataSetIterator上的回归评估可以执行如下:
DataSetIterator myTestData = ...
RegressionEvaluation eval = model.evaluateRegression(myTestData);
这里有一个单列的代码片段,在这种情况下,神经网络是根据测量值来预测自己的年龄。
RegressionEvaluation eval = new RegressionEvaluation(1);
打印评估的统计数据。
System.out.println(eval.stats());
返回
Column MSE MAE RMSE RSE R^2
col_0 7.98925e+00 2.00648e+00 2.82653e+00 5.01481e-01 7.25783e-01
列是均方误差、均方绝对误差、均方根误差、相对平方误差和R^2决定系数。
查看 回归评估JavaDoc
同时进行多个评估
当执行多种类型的评估时(例如,在同一网络和数据集上执行评估和ROC),在数据集的一次传递中执行以下操作更有效:
DataSetIterator testData = ...
Evaluation eval = new Evaluation();
ROC roc = new ROC();
model.doEvaluation(testdata, eval, roc);
时间序列评估
时间序列评估与上述评估方法非常相似。DL4J中的评估对所有(非掩码的)时间步分别执行——例如,长度为10的时间序列将为评估对象贡献10个预测/标签。与时间序列的一个不同之处在于掩码数组是(可选的),这些掩码数组用于将一些时间步标记为丢失或不存在。请参阅使用RNNS掩码以获得更多关于掩码的细节。
对于大多数用户来说,仅仅使用 MultiLayerNetwork.evaluate(DataSetIterator)
或 MultiLayerNetwork.evaluateRegression(DataSetIterator)
和类似的方法就足够了。如果掩码数组存在,这些方法将正确地处理掩码。
二分类器评估
EvaluationBinary用于评估具有二分类输出的网络——这些网络通常具有Sigmoid激活函数和XENT损失函数。为每个输出计算典型的分类度量,例如准确度、精度、召回率、F1得分等。
EvaluationBinary eval = new EvaluationBinary(int size)
ROC
ROC(接收者操作特征)是另一种常用的评估分类器的评估指标。DL4J中存在三个ROC变体:
- ROC -用“一对全部”的方法评估非二分类器
- ROCBinary - 用于单二分类标签(作为单列概率,或两列的softmax概率分布)
- ROCMultiClass - 用于多二分类标签
这些类具有通过calculateAUC()和calculateAUPRC()方法计算ROC曲线下面积(AUROC)和精确度-召回曲线下面积(AUPRC)的能力。此外,可以使用getRocCurve()
和getPrecisionRecallCurve()
获得ROC和精确度-召回曲线。
ROC和精确度-召回曲线可以导出到HTML以便查看,使用:“EvaluationTools.exportRocChartsToHtmlFile(ROC,File)”,该文件将导出具有ROC和精确度-召回曲线的HTML文件,可以在浏览器中查看。
注意,所有三种支持两种操作/计算模式。
- 阈值(近似AUROC/AUPRC计算,无内存问题)
- 精确(精确的AUROC/AUPRC计算,但是对于非常大的数据集(即具有数百万个示例的数据集)可能需要大量的内存
可以使用构造函数设置容器的数量。可以使用默认构造函数new ROC()
来精确设置,或者显式地使用new ROC(0)
。
参见ROCBinary JavaDoc用于评估二元分类器。
评估分类器校准
DL4J还具有评估校准类,它被设计用于分析分类器的校准。它提供了许多的工具用于如下目的:
- 每个类别的标签数量和预测的计数
- 可靠性图(或可靠性曲线)
- 残差图(直方图)
- 概率直方图,包括每个类的概率
使用评估校准的分类器评估方式与其它评估类相似。可以使用EvaluationTools.exportevaluationCalibrationToHtmlFile(EvaluationCalibration, File)
将各种绘图/直方图导出到HTML以便查看。
Spark网络的分布式评估
SparkDl4jMultiLayer 和 SparkComputationGraph 都有相似的评估方法:
Evaluation eval = SparkDl4jMultiLayer.evaluate(JavaRDD<DataSet>);
//一次传递多次评估:
SparkDl4jMultiLayer.doEvaluation(JavaRDD<DataSet>, IEvaluation...);
多任务网络评估
多任务网络是经过训练以产生多个输出的网络。例如,可以对给定音频样本的网络进行训练,以预测说话者的语言和说话人的性别。这里简要描述了多任务配置。
适用于多任务网络的评估类
可用的评估
Evaluation
评估指标:
- 精度,召回率,F1,FBeta,准确度,马休斯相关系数,gMeasure
argmax / 0.5) 注意:在使用用于二分类度量(如F1、精确度、召回等)的评估类时应小心。有许多案例需要考虑:-
对于二分类(1或2个网络输出)
c)在两个类上使用宏平均度量进行二分类(不常见且通常不可取),如上(b)所示,指定“null”作为参数(而不是0或1)
将报告宏平均(一个对全部)二分类度量。请注意,可以指定微vs宏平均
注意,设置自定义二进制决策阈值仅对于二分类情况(1或2个输出)是可能的,并且如果类的数量超过2,则不能使用它。概率>阈值的预测被认为是类1,否则被认为是类0。
成本数组(行向量,大小等于输出数量)修改评估过程:我们不是简单地执行predictedClass = argMax(probabilities),而是执行predictedClass = argMax(cost probabilities)。因此,所有1s的数组(或者实际上任何相等值的数组)将导致与无成本数组相同的性能;非相等值将偏离对某些类的预测。
-
Evaluation
public Evaluation(int numClasses)
评估中要考虑的分类数
- 参数 numClasses 评估中要考虑的分类数
Evaluation
public Evaluation(int numClasses, Integer binaryPositiveClass)
构造函数,用于指定类的数目,并且可选地用于二分类的正类。有关二分类情况下的评估的详细信息,请参见评估JavaDoc
- 参数 numClasses 评估的分类数。必须是2,如果binaryPositiveClass是非空的
- 参数 binaryPositiveClass 如果非空,则为正类(0或1)。
eval
public void eval(INDArray trueLabels, INDArray input, ComputationGraph network)
对 使用给定的true标签的输出、计算图网络输入和用于评估的计算图网络 进行评估
- 参数 trueLabels 使用的标签
- 参数 input 用于评估的网络输入
- 参数 network 用于输出的网络
eval
public void eval(INDArray trueLabels, INDArray input, MultiLayerNetwork network)
对 使用给定的true标签的输出、多层网络输入和用于评估的多层网络 进行评估
- 参数 trueLabels 使用的标签
- 参数 input 用于评估的网络输入
- 参数 network 用于输出的网络
eval
public void eval(INDArray realOutcomes, INDArray guesses)
收集关于真实结果和猜测的统计数据。这是逻辑的结果矩阵。
请注意,如果传递的两个矩阵中长度不相同,则会抛出IllegalArgumentException。
- 参数 realOutcomes 真实的结果(标签-通常是二分类的)
- 参数 guesses 猜测/预测 (通常是概率向量)
eval
public void eval(final INDArray realOutcomes, final INDArray guesses,
final List<? extends Serializable> recordMetaData)
用可选元数据评估网络
- 参数 realOutcomes 数据标签
- 参数 guesses 网络预测
- 参数 recordMetaData 可选的;可以是空的。如果不是NULL,则其大小应该等于结果/猜测的数量。
eval
public void eval(int predictedIdx, int actualIdx)
评估单一预测(一次一个预测)
- 参数 predictedIdx 网络预测类索引
- 参数 actualIdx 实际类索引
stats
public String stats()
以字符串形式报告统计信息
- 返回分类统计信息
stats
public String stats(boolean suppressWarnings)
以字符串形式获取分类报告的方法。
参数 suppressWarnings 是否输出与评估结果相关的警告
返回(多行)字符串的准确性、精确性、召回、F1得分等
stats
public String stats(boolean suppressWarnings, boolean includeConfusion)
以字符串形式获取分类报告的方法。
- 参数 suppressWarnings 是否输出与评估结果相关的警告
- 参数 includeConfusion 混淆矩阵是否应包含在返回的统计数据中
- 返回(多行)字符串的准确性、精确性、召回、F1得分等
confusionMatrix
public String confusionMatrix()
将混淆矩阵作为字符串获取
- 作为字符串返回混淆矩阵
precision
public double precision(Integer classLabel)
返回给定类标签的精度
- 参数 classLabel 标签
- 返回标签的精度
precision
public double precision(Integer classLabel, double edgeCase)
返回给定类标签的精度
- 参数 classLabel 标签
- 参数 edgeCase 在0/0情况时的输出
- 返回标签的精度
precision
public double precision()
迄今为止,基于猜测的精确性。
注意:返回的值将根据类的数量和设置而不同。
- 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
- 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均精度,相当于precision(EvaluationAveraging.Macro)。
- 基于猜测返回总精度
precision
public double precision(EvaluationAveraging averaging)
计算所有类的平均精度。可以指定是应该使用宏平均还是微观平均。注意:如果任何类具有tp=0和fp=0,(精度=0/0),则这些类被排除在平均值之外。
- 参数 averaging 平均法-宏或微
- 返回平均精度
averagePrecisionNumClassesExcluded
public int averagePrecisionNumClassesExcluded()
在计算(宏)平均精度时,由于没有预测平均中排除了多少类——即,精度是0/0的边缘情况。
- 返回从平均精度排除的类数
averageRecallNumClassesExcluded
public int averageRecallNumClassesExcluded()
在计算(宏)平均召回时,由于没有预测平均中排除了多少类——即,召回是0/0的边缘情况。
- 返回从平均召回排除的类数
averageF1NumClassesExcluded
public int averageF1NumClassesExcluded()
在计算(宏)平均F1时,由于没有预测,从平均值中排除了多少类——即,F1将根据0/0的精度或召回率来计算。
- 返回从平均F1排除的类数
averageFBetaNumClassesExcluded
public int averageFBetaNumClassesExcluded()
在计算(宏)平均FBeta时,由于没有预测,从平均值中排除了多少类——即,FBeta将根据0/0的精度或召回率来计算。
- 返回从平均FBeta排除的类数
recall
public double recall(int classLabel)
返回给定标签的召回率
- 参数 classLabel 标签
- 返回double类型的召回率
recall
public double recall(int classLabel, double edgeCase)
返回给定标签的召回率
- 参数 classLabel 标签
- 参数 edgeCase 在0/0的情况下的输出
- 返回double类型的召回率
recall
public double recall()
迄今为止基于猜测的召回
注意:返回的值将根据类的数量和设置而不同。
- 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
- 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均召回,相当于recall(EvaluationAveraging.Macro)。
- 为结果返回召回
recall
public double recall(EvaluationAveraging averaging)
计算所有类的平均召回-可以指定是使用宏平均还是微观平均。注意:如果任何类都具有TP=0和fn=0,(召回=0/0),这些都是从平均值中排除的。
- 参数 averaging 平均方法-宏或微
- 返回平均召回率
falsePositiveRate
public double falsePositiveRate(int classLabel)
返回给定标签的假阳性率
- 参数 classLabel 标签
- 返回double类型的假阳性率
falsePositiveRate
public double falsePositiveRate(int classLabel, double edgeCase)
返回给定标签的假阳性率
- 参数 classLabel 标签
- 参数 edgeCase 0/0时的输出
- 返回double类型的假阳性率
falsePositiveRate
public double falsePositiveRate()
迄今为止基于猜测的假阳性率 注意:返回的值将根据类的数量和设置而不同。
- 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
- 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均假阳性率,相当于falsePositiveRate(EvaluationAveraging.Macro)。
- 返回输出假阳性率
falsePositiveRate
public double falsePositiveRate(EvaluationAveraging averaging)
计算所有类别的平均假阳性率。可以指定是应该使用宏平均还是微观平均
- 参数 averaging 平均方法.宏观或微观
- 返回平均假阳性率
falseNegativeRate
public double falseNegativeRate(Integer classLabel)
返回给定标签的假阴性率
- 参数 classLabel 标签
- 返回double类型的假阴性率
falseNegativeRate
public double falseNegativeRate(Integer classLabel, double edgeCase)
返回给定标签的假阴性率
- 参数 classLabel 标签
- 参数 edgeCase 在0/0的情况下的输出
- 返回double类型的假阴性率
falseNegativeRate
public double falseNegativeRate()
迄今为止基于猜测的假阴性率 注意:返回的值将根据类的数量和设置而不同。
- 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
- 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均假阴性率,相当于falseNegativeRate(EvaluationAveraging.Macro)。
- 返回输出假阳性率
falseNegativeRate
public double falseNegativeRate(EvaluationAveraging averaging)
计算所有类别的平均假阴性率。可以指定是应该使用宏平均还是微观平均
- 参数 averaging 平均方法.宏观或微观
- 返回平均假阴性率
falseAlarmRate
public double falseAlarmRate()
误报率反映了对分类记录的错误分类率。http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw 注意:返回的值将根据类的数量和设置而不同。
- 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
- 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均误报率。
- 返回输出误报率
f1
public double f1(int classLabel)
计算给定分类的F1分数
- 参数 classLabel 计算F1的标签
- 返回给定标签的F1分数
fBeta
public double fBeta(double beta, int classLabel)
计算给定类的FBeta,其中FBeta定义为:
(1 +beta^ 2)(精确召回)/(beta^ 2精度+召回)。
F1是FBeta的一个特例,具有beta=1。
- 参数 beta 使用的Beta值
- 参数 classLabel 分类标签
- 返回 FBeta
fBeta
public double fBeta(double beta, int classLabel, double defaultValue)
计算给定类的FBeta,其中FBeta定义为:
(1 +beta^ 2)(精确召回)/(beta^ 2精度+召回)。
F1是FBeta的一个特例,具有beta=1。
- 参数 beta 使用的Beta值
- 参数 classLabel 分类标签
- 参数 defaultValue 精度或召回未定义(精度或召回为0/0)时的缺省值
- 返回 FBeta
f1
public double f1()
计算F1得分
F1得分定义为:
TP:真阳性
FP:假阳性
FN:假阴性
F1得分:2 TP/(2TP+FP+FN)
注意:返回的值将根据类的数量和设置而不同。
- 对于二分类,如果设置了正类(通过默认值为1、通过构造函数或通过setBinaryPositiveClass(Integer),则返回的值将仅用于指定的正类。
- 对于多分类的情况,或者当getBinaryPositiveClass()为NULL时,返回的值是跨所有类的宏平均值。即,宏平均 f1,相当于 f1(EvaluationAveraging.Macro)。
- 返回基于当前猜测的f1分数或精度与召回的调和平均
f1
public double f1(EvaluationAveraging averaging)
计算所有类别的F1得分。可以指定是应该使用宏平均还是微观平均
- 参数 averaging 平均方法.宏观或微观
fBeta
public double fBeta(double beta, EvaluationAveraging averaging)
计算所有类别的F_beta得分。可以指定是应该使用宏平均还是微观平均
- 参数 beta 使用的Beta值
- 参数 averaging 平均方法.宏观或微观
gMeasure
public double gMeasure(int output)
计算给定输出的G-measure
- 参数 output 指定输出
- 返回指定输出的G-measure
gMeasure
public double gMeasure(EvaluationAveraging averaging)
使用微或宏平均计算所有输出的平均Gmeasure
- 参数 averaging 平均方法.宏观或微观
- 返回平均G measure
accuracy
public double accuracy()
准确率: (TP + TN) / (P + N)
- 返回到目前为止猜测的准确率
topNAccuracy
public double topNAccuracy()
迄今为止预测的第N高的准确率。对于top n=1(默认值),相当于accuracy()
- 返回 前N 准确率
翻译:风一样的男子
如果您觉得我的文章给了您帮助,请为我买一杯饮料吧!以下是我的支付宝,意思一下我将非常感激!