首页 » 未分类 » 让keras的model能正常使用pickle

让keras的model能正常使用pickle

作者: vForce 分类: 未分类 发布时间: 2017-08-28 14:57 ė浏览 26,049 次 6没有评论

ref: http://zachmoshe.com/2017/04/03/pickling-keras-models.html

最近碰到了一个场景,需要能够让代码兼容keras的Model和sklearn的classifier,支持模型的保存和读取。想到这里我首先想到的是pickle,基本很少有场景是pickle不能序列化存储的。但这次但我尝试dump keras的Model对象时,还真报错不能存储了。尽管第一反应是,可以针对keras的特殊类型用keras自己的save/load机制,但如果pickle可以兼容,那应该对代码来说更友好。果不其然就找到了这篇博客,介绍了让keras兼容pickle的方案。代码如下

import types
import tempfile
import keras.models

def make_keras_picklable():
def __getstate__(self):
model_str = ""
with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
keras.models.save_model(self, fd.name, overwrite=True)
model_str = fd.read()
d = { 'model_str': model_str }
return d

def __setstate__(self, state):
with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
fd.write(state['model_str'])
fd.flush()
model = keras.models.load_model(fd.name)
self.__dict__ = model.__dict__


cls = keras.models.Model
cls.__getstate__ = __getstate__
cls.__setstate__ = __setstate__

使用的时候只需要执行一次make_keras_picklable方法就可以了。

本文出自 Tech Trace,转载时请注明出处及相应链接。

本文永久链接: https://www.qiujiahui.com/2017/08/28/%e8%ae%a9keras%e7%9a%84model%e8%83%bd%e6%ad%a3%e5%b8%b8%e4%bd%bf%e7%94%a8pickle/

0

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

Ɣ回顶部