多线程 – Keras Tensorflow – 从多个线程预测时的异常

前端之家收集整理的这篇文章主要介绍了多线程 – Keras Tensorflow – 从多个线程预测时的异常前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。
我正在使用keras 2.0.8和tensorflow 1.3.0后端.

我在类init中加载一个模型,然后用它来预测多线程.

import tensorflow as tf
from keras import backend as K
from keras.models import load_model


class CNN:
    def __init__(self,model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()

    def query_cnn(self,data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                return self.cnn_model.predict(X)

我初始化CNN一次,query_cnn方法从多个线程发生.

我在日志中得到的例外是:

File "/home/*/Similarity/CNN.py",line 43,in query_cnn
    return self.cnn_model.predict(X)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py",line 913,in predict
    return self.model.predict(x,batch_size=batch_size,verbose=verbose)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1713,in predict
    verbose=verbose,steps=steps)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py",line 1269,in _predict_loop
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py",line 2273,in __call__
    **self.session_kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 895,in run
    run_Metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1124,in _run
    Feed_dict_tensor,options,run_Metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py",line 1321,in _do_run
    options,line 1340,in _do_call
    raise type(e)(node_def,op,message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps

代码在大多数情况下工作正常,它可能是多线程的一些问题.

我该如何解决

解决方法

确保在创建其他线程之前完成图形创建.

在图表上调用finalize()可以帮助您.

def __init__(self,model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize()

更新1:finalize()将使您的图形为只读,以便可以安全地在多个线程中使用.作为副作用,它将帮助您找到无意的行为,有时还会发现内存泄漏,因为当您尝试修改图形时它会引发异常.

想象一下,你有一个线程可以做一个例如输入的热编码. (坏的例子:)

def preprocessing(self,data):
    one_hot_data = tf.one_hot(data,depth=self.num_classes)
    return self.session.run(one_hot_data)

如果在图表中打印对象数量,您会发现它会随着时间的推移而增加

# amount of nodes in tf graph
print(len(list(tf.get_default_graph().as_graph_def().node)))

但是,如果您首先定义图形不是这种情况(略微更好的代码):

def preprocessing(self,data):
    # run pre-created operation with self.input as placeholder
    return self.session.run(self.one_hot_data,Feed_dict={self.input: data})

更新2:根据此thread,您需要在执行多线程之前在keras模型上调用model._make_predict_function().

Keras builds the GPU function the first time you call predict(). That
way,if you never call predict,you save some time and resources.
However,the first time you call predict is slightly slower than every
other time.

更新的代码

def __init__(self,model_path):
    self.cnn_model = load_model(model_path)
    self.cnn_model._make_predict_function() # have to initialize before threading
    self.session = K.get_session()
    self.graph = tf.get_default_graph() 
    self.graph.finalize() # make graph read-only

更新3:我做了一个预热概念的证明,因为_make_predict_function()似乎没有按预期工作.
首先我创建了一个虚拟模型:

import tensorflow as tf
from keras.layers import *
from keras.models import *

model = Sequential()
model.add(Dense(256,input_shape=(2,)))
model.add(Dense(1,activation='softmax'))

model.compile(loss='mean_squared_error',optimizer='adam')

model.save("dummymodel")

然后在另一个脚本中我加载了该模型并使其在多个线程上运行

import tensorflow as tf
from keras import backend as K
from keras.models import load_model
import threading as t
import numpy as np

K.clear_session()

class CNN:
    def __init__(self,model_path):

        self.cnn_model = load_model(model_path)
        self.cnn_model.predict(np.array([[0,0]])) # warmup
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize() # finalize

    def preproccesing(self,data):
        # dummy
        return data

    def query_cnn(self,data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                prediction = self.cnn_model.predict(X)
        print(prediction)
        return prediction


cnn = CNN("dummymodel")

th = t.Thread(target=cnn.query_cnn,kwargs={"data": np.random.random((500,2))})
th2 = t.Thread(target=cnn.query_cnn,2))})
th3 = t.Thread(target=cnn.query_cnn,2))})
th4 = t.Thread(target=cnn.query_cnn,2))})
th5 = t.Thread(target=cnn.query_cnn,2))})
th.start()
th2.start()
th3.start()
th4.start()
th5.start()

th2.join()
th.join()
th3.join()
th5.join()
th4.join()

评论预热和最终确定的线条我能够重现你的第一个问题

猜你在找的Java相关文章