让keras的model能正常使用pickle
ref: http://zachmoshe.com/2017/04/03/pickling-keras-models.html
最近碰到了一个场景,需要能够让代码兼容keras的Model和sklearn的classifier,支持模型的保存和读取。想到这里我首先想到的是pickle,基本很少有场景是pickle不能序列化存储的。但这次但我尝试dump keras的Model对象时,还真报错不能存储了。尽管第一反应是,可以针对keras的特殊类型用keras自己的save/load机制,但如果pickle可以兼容,那应该对代码来说更友好。果不其然就找到了这篇博客,介绍了让keras兼容pickle的方案。代码如下
import types import tempfile import keras.models def make_keras_picklable(): def __getstate__(self): model_str = "" with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd: keras.models.save_model(self, fd.name, overwrite=True) model_str = fd.read() d = { 'model_str': model_str } return d def __setstate__(self, state): with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd: fd.write(state['model_str']) fd.flush() model = keras.models.load_model(fd.name) self.__dict__ = model.__dict__ cls = keras.models.Model cls.__getstate__ = __getstate__ cls.__setstate__ = __setstate__
使用的时候只需要执行一次make_keras_picklable方法就可以了。
本文出自 Tech Trace,转载时请注明出处及相应链接。
本文永久链接: https://www.qiujiahui.com/2017/08/28/%e8%ae%a9keras%e7%9a%84model%e8%83%bd%e6%ad%a3%e5%b8%b8%e4%bd%bf%e7%94%a8pickle/