Bootstrap

JPMML调用PMML机器学习模型零基础总结(内含新版本jpmml解决方法)

起因

由于公司要上线机器学习的预测模型,而我用的是python语言,要在java上部署,所以需要我提供pmml文件,再通过jpmml库调用。但我不知道这种方式可不可行,同时交给别人调试也不方便,于是乎产生了自己搭建java-jpmml环境的想法。

流程

总体的流程大致是:
python侧sklearn2pmml 直接生成 .pmml 模型文件
下载JDK 16.0
下载intellij idea社区版
下载Maven(好像也可以不下载,毕竟intellij自带创建maven项目)
在intellij中创建新的maven项目
pom.xml中导入dependcy依赖环境,右击空白处maven-reload project自动下载jpmml及相关环境
创建新class,导入测试代码,run。

好像很简单,其实内部坑非常多!!

关键问题

先看一下我的pom.xml里面的dependcy:

<dependencies>
        <dependency>
            <groupId>org.jpmml</groupId>
            <artifactId>pmml-evaluator</artifactId>
            <version>1.5.15</version>
        </dependency>
        <dependency>
            <groupId>com.sun.xml.bind</groupId>
            <artifactId>jaxb-impl</artifactId>
            <version>2.1.2</version>
        </dependency>
        <dependency>
            <groupId>javax.xml.bind</groupId>
            <artifactId>jaxb-api</artifactId>
            <version>2.3.0</version>
        </dependency>
        <!-- https://mvnrepository.com/artifact/javax.activation/javax.activation -->
        <dependency>
            <groupId>javax.activation</groupId>
            <artifactId>activation</artifactId>
            <version>1.1.1</version>
        </dependency>
    </dependencies>

jpmml-evaluator:官网上所谓的jpmml安装方式,通过maven仓库直接导入,之前网上搜,说是会自动下载,结果我的idea自动下载莫名其妙有时有有时没有,这里推荐大家鼠标右键点空白选择maven-reload project,刷新一下马上自动下载下来了!

剩下的dependcy:为什么这里有个javaX,是因为网上所有找到的旧版本的代码,都需要javaX的相关依赖支持,虽然最后添加了解决了报错(javax.xml.bind.JAXBException: Implementation of JAXB-API has not been found),但是代码依旧出错!原因是代码本身就是错的!

说完dependcy来说说网上的旧版本代码到底错在哪里
原始代码出处举例:https://zhuanlan.zhihu.com/p/73245462

里面在读取模型时都有共同逻辑的一段代码:

private Evaluator loadPmml(String fp) throws FileNotFoundException, JAXBException, SAXException {
        InputStream is = new FileInputStream(fp);
        PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
        try {
            is.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        ModelEvaluatorFactory factory = ModelEvaluatorFactory.newInstance();
        return factory.newModelEvaluator(pmml);
    }

factory.newModelEvaluator(pmml),在IDEA中会标红pmml提示无法解析,运行后报错:
org.jpmml.evaluator.ModelEvaluatorFactory.newModelEvaluator(org.dmg.pmml.PMML,org.dmg.pmml.Model)
形参和参数列表不匹配

实际上,这个方法已经无法传入单个参数,需要传入两个参数和三个参数。

解决方法
放弃factory初始化方法,实际上在jpmml1.4.3的更新文档中已经推荐使用新的builder方法:

String fp = "iris.pmml";
TestPmml obj = new TestPmml();
Evaluator model = new LoadingModelEvaluatorBuilder()
                .load(new File("iris.pmml"))
                .build()

代码中直接把那段又长又丑的load函数删掉,上面的Evaluator开始一句话就可搞定初始化。
此外,需要注意这里还有一个坑,刚用intellij idea不知道文件目录该放在哪里,哪里新建代码段,看我的示例目录:
在这里插入图片描述
你的代码段是在src/main/java下面新创建类,注意类名和代码里面的类名应保持一致。
pmml文件的相对路径读取是从untitled3开始的,所以放在最外面可以直接读取。

测试代码

如果看了上面还是看不懂,可以看完整的示例代码。
参考的https://zhuanlan.zhihu.com/p/73245462
Python侧一致。
JAVA侧修改了jpmml初始化代码:

import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.*;
import java.io.*;
import java.util.*;


public class TestPmml {
    public static void main(String args[]) throws Exception {
        String fp = "iris.pmml";
        TestPmml obj = new TestPmml();
        Evaluator model = new LoadingModelEvaluatorBuilder()
                .load(new File("iris.pmml"))
                .build();
        List<Map<String, Object>> inputs = new ArrayList<>();
        inputs.add(obj.getRawMap(5.1, 3.5, 1.4, 0.2));
        inputs.add(obj.getRawMap(4.9, 3, 1.4, 0.2));
        for (int i = 0; i < inputs.size(); i++) {
            Map<String, Object> output = obj.predict(model, inputs.get(i));
            System.out.println("X=" + inputs.get(i) + " -> y=" + output.get("y"));
        }
    }


    private Map<String, Object> getRawMap(Object a, Object b, Object c, Object d) {
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("x1", a);
        data.put("x2", b);
        data.put("x3", c);
        data.put("x4", d);
        return data;
    }

    /**
     * 运行模型得到结果。
     */
    private Map<String, Object> predict(Evaluator evaluator, Map<String, Object> data) {
        Map<FieldName, FieldValue> input = getFieldMap(evaluator, data);
        Map<String, Object> output = evaluate(evaluator, input);
        return output;
    }

    /**
     * 把原始输入转换成PMML格式的输入。
     */
    private Map<FieldName, FieldValue> getFieldMap(Evaluator evaluator, Map<String, Object> input) {
        List<InputField> inputFields = evaluator.getInputFields();
        Map<FieldName, FieldValue> map = new LinkedHashMap<FieldName, FieldValue>();
        for (InputField field : inputFields) {
            FieldName fieldName = field.getName();
            Object rawValue = input.get(fieldName.getValue());
            FieldValue value = field.prepare(rawValue);
            map.put(fieldName, value);
        }
        return map;
    }

    /**
     * 运行模型得到结果。
     */
    private Map<String, Object> evaluate(Evaluator evaluator, Map<FieldName, FieldValue> input) {
        Map<FieldName, ?> results = evaluator.evaluate(input);
        List<TargetField> targetFields = evaluator.getTargetFields();
        Map<String, Object> output = new LinkedHashMap<String, Object>();
        for (int i = 0; i < targetFields.size(); i++) {
            TargetField field = targetFields.get(i);
            FieldName fieldName = field.getName();
            Object value = results.get(fieldName);
            if (value instanceof Computable) {
                Computable computable = (Computable) value;
                value = computable.getResult();
            }
            output.put(fieldName.getValue(), value);
        }
        return output;
    }

}
;