问题
用sklearn训练的模型,如何将其参数保存,方便下次调用
模型
gbr = GBR(random_state=1412) # 实例化
gbr.fit(X, y.ravel()) # 训练模型
方法
常用方法 joblib 和 pickle 库
保存模型
- joblib
# from sklearn.externals import joblib # 低版本Scikit-learn 0.21版本以下
import joblib # 新版本 Scikit-learn
joblib.dump(gbr, "train_model.m")
- pickle
import pickle
with open('train_model.pkl', 'wb') as f:
pickle.dump(gbr, f)
读取模型
- joblib
import joblib
gbr = joblib.load("train_model.m")
- pickle
import pickle
with open('train_model.pkl', 'rb') as f:
gbr = pickle.load(f)
不同架构(Java、C++等)
ONNX 是模型的
二进制序列化
。它的开发是为了提高数据模型的可互操作表示的可用性。它旨在促进数据模型在不同机器学习框架之间的转换,并提高它们在不同计算架构上的可移植性。更多详细信息可从ONNX 教程中获得。为了将 scikit-learn 模型转换为 ONNX,我们开发了一个特定的工具sklearn-onnx。
PMML 是
XML文档
标准的一种实现,定义为表示数据模型以及用于生成它们的数据。PMML 是人类和机器可读的,是在不同平台上进行模型验证和长期存档的不错选择。另一方面,与一般的 XML 一样,当性能至关重要时,它的冗长对生产没有帮助。要将 scikit-learn 模型转换为 PMML,您可以使用在 Affero GPLv3 许可下分发的例如sklearn2pmml 。
- 参考文章3给出了存储成
json格式
的方式
import json
import numpy as np
class MyLogReg(LogisticRegression):
# Override the class constructor
def __init__(self, C=1.0, solver='liblinear', max_iter=100, X_train=None, Y_train=None):
LogisticRegression.__init__(self, C=C, solver=solver, max_iter=max_iter)
self.X_train = X_train
self.Y_train = Y_train
# A method for saving object data to JSON file
def save_json(self, filepath):
dict_ = {}
dict_['C'] = self.C
dict_['max_iter'] = self.max_iter
dict_['solver'] = self.solver
dict_['X_train'] = self.X_train.tolist() if self.X_train is not None else 'None'
dict_['Y_train'] = self.Y_train.tolist() if self.Y_train is not None else 'None'
# Creat json and save to file
json_txt = json.dumps(dict_, indent=4)
with open(filepath, 'w') as file:
file.write(json_txt)
# A method for loading data from JSON file
def load_json(self, filepath):
with open(filepath, 'r') as file:
dict_ = json.load(file)
self.C = dict_['C']
self.max_iter = dict_['max_iter']
self.solver = dict_['solver']
self.X_train = np.asarray(dict_['X_train']) if dict_['X_train'] != 'None' else None
self.Y_train = np.asarray(dict_['Y_train']) if dict_['Y_train'] != 'None' else None
存储和查看方法
filepath = "mylogreg.json"
# Create a model and train it
mylogreg = MyLogReg(X_train=Xtrain, Y_train=Ytrain)
mylogreg.save_json(filepath)
# Create a new object and load its data from JSON file
json_mylogreg = MyLogReg()
json_mylogreg.load_json(filepath)
json_mylogreg
限制
Pickle 和 Joblib
- 兼容性问题
Pickle 和 Joblib 的最大缺点就是其兼容性问题,可能与不同模型不同版本的 scikit-learn 或 Python 版本有关。 - 安全问题
Pickle(以及扩展的 Joblib)在可维护性和安全性方面存在一些问题。
JSON
- 安全性较低
- 适用于实例变量较少的对象
使用 JSON 进行数据序列化实际上是将对象保存为字符串格式,所以我们可以用文本编辑器打开和修改 mylogreg.json 文件。尽管这种方法对开发人员来说很方便,但其他人员也可以随意查看和修改 JSON 文件的内容,因此安全性较低。而且,这种方法更适用于实例变量较少的对象,例如 sklearn 模型,因为任何新变量的添加都需要更改保存和载入的方法。
相关文章: