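Our company's codebase is in Java, but the algorithm part uses Python's sklearn (scikit-learn). The plan is to export the model to a PMML file with sklearn2pmml and load that file from Java, so the model can be used across platforms.
- Install sklearn2pmml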
pip install sklearn2pmml
Note the following:
- scikit-learn must be version 0.20.4 or earlier. Later versions fail with
AttributeError: module 'sklearn.externals.joblib' has no attribute '__version__'
This is because sklearn.externals.joblib was deprecated in 0.21 and is slated for removal in 0.23:
DeprecationWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+
- Java must be version 1.7 or later.
My environment:
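The two version constraints above can be checked before running the conversion. Below is a minimal sketch using only the standard library; `version_tuple` and `check_env` are hypothetical helper names, and the version strings are passed in by hand (in practice they would come from `sklearn.__version__` and the output of `java -version`).

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn '0.20.4' or '1.8.0_144' into (0, 20, 4) or (1, 8, 0, 144).

    Splits on '.', '_' and '-', keeps the leading digits of each piece and
    stops at the first piece that does not start with a digit.
    """
    numbers = []
    for piece in re.split(r"[._-]", version):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


def check_env(sklearn_version: str, java_version: str) -> list:
    """Return a list of problems; an empty list means both constraints hold."""
    problems = []
    # Constraint from the list above: scikit-learn <= 0.20.4
    if version_tuple(sklearn_version) > (0, 20, 4):
        problems.append(f"scikit-learn {sklearn_version} is newer than 0.20.4")
    # Constraint from the list above: Java >= 1.7
    if version_tuple(java_version) < (1, 7):
        problems.append(f"Java {java_version} is older than 1.7")
    return problems


if __name__ == "__main__":
    print(check_env("0.20.4", "1.8.0_144"))  # the setup listed in this post -> []
    print(check_env("0.21.3", "1.8.0_144"))
```

Comparing integer tuples avoids the classic string-comparison trap where "0.9" would sort after "0.20".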
python: 3.6.8
sklearn: 0.20.4
sklearn.externals.joblib: 0.13.2
pandas: 0.24.1
sklearn_pandas: 1.8.0
sklearn2pmml: 0.48.0
java: 1.8.0_144
- Python test code
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn2pmml import PMMLPipeline, sklearn2pmml
iris = load_iris()
train, test, train_labels, test_labels = train_test_split(iris.data, iris.target, test_size=0.2, random_state=0)
pipeline = PMMLPipeline([
    ("classifier", tree.DecisionTreeClassifier(random_state=9))
])
pipeline.fit(train, train_labels)
sklearn2pmml(pipeline, 'result.pmml', with_repr=True, debug=True)
The generated PMML file is shown in the figure below.
When running your own code, you may see the following error:
RuntimeError: The JPMML-SkLearn conversion application has failed. The Java executable should have printed more information about the failure into its standard output and/or standard error streams
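In that case, check the column names of train and train_labels: they must be unique and well-formed, since they end up as field names in the PMML file.
- Java test code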
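A quick pre-flight check on the names can save a trip through the Java stack trace. This is a minimal standard-library sketch; `check_column_names` is a hypothetical helper, and the rules it enforces (string, non-empty, no surrounding whitespace, no duplicates) are an assumption of what "well-formed" means, not a documented JPMML requirement.

```python
from collections import Counter


def check_column_names(names):
    """Return a list of problems found in the column names (empty list = OK).

    Assumed rules (hypothetical, not taken from the JPMML documentation):
    each name is a non-empty string without leading/trailing whitespace,
    and no name appears more than once.
    """
    names = list(names)
    problems = []
    for i, name in enumerate(names):
        if not isinstance(name, str):
            problems.append(f"column {i}: {name!r} is not a string")
        elif not name or name != name.strip():
            problems.append(f"column {i}: {name!r} is empty or padded with whitespace")
    # Duplicated names would map two columns onto the same PMML field
    for name, count in Counter(names).items():
        if count > 1:
            problems.append(f"{name!r} appears {count} times")
    return problems


if __name__ == "__main__":
    print(check_column_names(["x1", "x2", "x3", "x4"]))  # -> []
    print(check_column_names(["x1", "x1", " x3"]))
```

With pandas inputs, the call would look like `check_column_names(list(train.columns))`; for a label Series, check `[train_labels.name]`.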
Download jpmml-sklearn-executable-1.5.7.jar and pmml-evaluator-1.4.3.jar, and add them to a new project.
I have verified that these jar versions work without errors; other versions may fail.
The Java code:
package javaTopython;

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;
import org.jpmml.model.PMMLUtil;

public class PmmlFile {

    public static void main(String[] args) throws Exception {
        // PMML file written by the Python example (sklearn2pmml(pipeline, 'result.pmml', ...))
        String pathxml = "result.pmml";
        // Keys must match the <DataField> names in the PMML file
        Map<String, Double> map = new HashMap<String, Double>();
        map.put("x1", 5.1);
        map.put("x2", 3.5);
        map.put("x3", 1.4);
        map.put("x4", 0.2);
        predictLrHeart(map, pathxml);
    }

    public static void predictLrHeart(Map<String, Double> irismap, String pathxml) throws Exception {
        PMML pmml;
        // Load the model; try-with-resources closes the stream
        try (InputStream is = new FileInputStream(pathxml)) {
            pmml = PMMLUtil.unmarshal(is);
        }
        ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
        Evaluator evaluator = modelEvaluatorFactory.newModelEvaluator(pmml);

        // Walk the model's raw input features, look each one up in the input map
        // and convert it into a model input value
        Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();
        for (InputField inputField : evaluator.getInputFields()) {
            FieldName inputFieldName = inputField.getName();
            Object rawValue = irismap.get(inputFieldName.getValue());
            FieldValue inputFieldValue = inputField.prepare(rawValue);
            arguments.put(inputFieldName, inputFieldValue);
        }

        Map<FieldName, ?> results = evaluator.evaluate(arguments);

        // A model can have several target fields (e.g. classifiers); print each one
        for (TargetField targetField : evaluator.getTargetFields()) {
            FieldName targetFieldName = targetField.getName();
            Object targetFieldValue = results.get(targetFieldName);
            System.out.println("target: " + targetFieldName.getValue()
                    + " value: " + targetFieldValue);
        }
    }
}
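The keys x1 to x4 in main are not arbitrary: the Python example fits on plain NumPy arrays, so the exported PMML has no column names to reuse, and the features appear to be named x1, x2, ... with the target named y (the target name printed below). When the Java caller is fed from Python-side feature rows, a small helper can build the same keys. This is a hypothetical sketch: `default_field_names` and `to_pmml_inputs` are not part of sklearn2pmml, and the x1... naming is assumed to apply only when no column names were available at fit time.

```python
def default_field_names(n_features):
    """Names assumed for unnamed features in the exported PMML: x1 ... xN."""
    return [f"x{i}" for i in range(1, n_features + 1)]


def to_pmml_inputs(row):
    """Map one feature row (a sequence of numbers) to {field name: value}."""
    values = [float(v) for v in row]
    return dict(zip(default_field_names(len(values)), values))


if __name__ == "__main__":
    # Same sample as in the Java main method
    print(to_pmml_inputs([5.1, 3.5, 1.4, 0.2]))
    # -> {'x1': 5.1, 'x2': 3.5, 'x3': 1.4, 'x4': 0.2}
```

If the model is trained on a DataFrame with named columns instead, those names replace x1 to x4 and the Java map keys must change accordingly.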
Running the Java program prints:
target y value: ProbabilityDistribution{result=0, probability_entries=[0=0.8876504283659372, 1=0.11232695495162393, 2=2.2616682438804697E-5]}
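In the printed line, result is the predicted class and probability_entries holds one probability per class; here result is the class with the highest probability, and the probabilities sum to 1. A minimal Python sketch for reading such lines back, for example from a log, assuming the toString format shown above (it may differ between jpmml-evaluator versions); `parse_distribution` is a hypothetical helper.

```python
import re

# The line printed above, copied verbatim
LINE = ("target y value: ProbabilityDistribution{result=0, probability_entries="
        "[0=0.8876504283659372, 1=0.11232695495162393, 2=2.2616682438804697E-5]}")


def parse_distribution(line):
    """Extract (result, {class label: probability}) from a printed ProbabilityDistribution."""
    match = re.search(r"result=([^,]+), probability_entries=\[(.*)\]", line)
    if match is None:
        raise ValueError("no ProbabilityDistribution found")
    result, entries = match.group(1), match.group(2)
    probabilities = {}
    for entry in entries.split(", "):
        label, value = entry.split("=")
        probabilities[label] = float(value)
    return result, probabilities


if __name__ == "__main__":
    result, probs = parse_distribution(LINE)
    best = max(probs, key=probs.get)
    print(result, best)                   # -> 0 0
    print(round(sum(probs.values()), 6))  # -> 1.0
```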
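Watch out for model simplification: the <DataField> elements in the PMML file may leave out columns whose coefficients are zero, so it is best to add a check.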
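One way to do that check is to compare the expected feature names against the <DataField> names actually present in the file. A minimal standard-library sketch; `datafield_names` and `missing_fields` are hypothetical helpers, and the sample PMML below is a hand-written fragment for illustration, not real sklearn2pmml output.

```python
import xml.etree.ElementTree as ET


def datafield_names(pmml_text):
    """Return the name attribute of every <DataField> in a PMML document."""
    root = ET.fromstring(pmml_text)
    names = []
    for elem in root.iter():
        # Tags carry the PMML namespace, e.g. '{http://www.dmg.org/PMML-4_3}DataField'
        if elem.tag.rsplit("}", 1)[-1] == "DataField":
            names.append(elem.get("name"))
    return names


def missing_fields(pmml_text, expected):
    """Expected feature names that have no <DataField> in the PMML document."""
    present = set(datafield_names(pmml_text))
    return [name for name in expected if name not in present]


if __name__ == "__main__":
    # Hand-written fragment: x2 and x4 were dropped
    sample = (
        '<PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3">'
        '<DataDictionary>'
        '<DataField name="y" optype="categorical" dataType="integer"/>'
        '<DataField name="x1" optype="continuous" dataType="double"/>'
        '<DataField name="x3" optype="continuous" dataType="double"/>'
        '</DataDictionary>'
        '</PMML>'
    )
    print(missing_fields(sample, ["x1", "x2", "x3", "x4"]))  # -> ['x2', 'x4']
```

For a real file, pass its bytes: `missing_fields(open("result.pmml", "rb").read(), ["x1", "x2", "x3", "x4"])`. On the Java side, features missing from the PMML simply do not show up in getInputFields(), so their values in the input map are silently ignored.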