我的情况是这样的:
我有一个训练张量流模型的脚本.在此脚本中,我实例化了一个提供训练数据的类.该类的初始化依次实例化另一个名为“image”的类,以进行数据扩充的各种操作,而不是.
main script -> instantiates data_Feed class -> instantiates image class
我的问题是我试图通过传递会话本身或图形来使用tensorflow在这个图像类中做一些操作.但我收效甚微.
有效的方法(但速度太慢)
我现在所拥有的,但工作缓慢,就像这样(简化):
class image(object):
def __init__(self,im):
self.im = im
def augment(self):
aux_im = tf.image.random_saturation(self.im,0.6)
sess = tf.Session(graph=aux_im.graph)
self.im = sess.run(aux_im)
class data_Feed(object):
def __init__(self,data_dir):
self.images = load_data(data_dir)
def process_data(self):
for im in self.images:
image = image(im)
image.augment()
if __name__ == "__main__":
# initialize everything tensorflow related here,including model
sess = tf.Session()
# next load the data
data_Feed = data_Feed(TRAIN_DATA_DIR)
train_data = data_Feed.process_data()
这种方法有效,但它为每个图像创建一个新的Session:
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0,name: GeForce GTX 1070,pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0,pci bus id: 0000:01:00.0)
etc ...
不起作用的方法(应该更快)
例如,什么不起作用,我无法弄清楚为什么,是从我的主脚本传递图形或会话,如下所示:
class image(object):
def __init__(self,im):
self.im = im
def augment(self,tf_sess):
with tf_sess.as_default():
aux_im = tf.image.random_saturation(self.im,0.6)
self.im = tf_sess.run(aux_im)
class data_Feed(object):
def __init__(self,data_dir,tf_sess):
self.images = load_data(data_dir)
self.tf_sess = tf_sess
def process_data(self):
for im in self.images:
image = image(im)
image.augment(self.tf_sess)
if __name__ == "__main__":
# initialize everything tensorflow related here,including model
sess = tf.Session()
# next load the data
data_Feed = data_Feed(TRAIN_DATA_DIR,sess)
train_data = data_Feed.process_data()
这是我得到的错误:
Traceback (most recent call last):
File "/usr/lib/python2.7/threading.py",line 801,in __bootstrap_inner
self.run()
File "/usr/lib/python2.7/threading.py",line 754,in run
self.__target(*self.__args,**self.__kwargs)
File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py",line 409,in data_generator_task
generator_output = next(generator)
File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 198,in generate
yield self.next_batch()
File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 192,in next_batch
X,y,l = self.process_image(json_im,X,l)
File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 131,in process_image
im.augment_with_tf(self.tf_sess)
File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 85,in augment_with_tf
self.im = sess.run(saturation,{im_placeholder: np.asarray(self.im)})
File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py",line 766,in run
run_Metadata_ptr)
File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py",line 921,in _run
+ e.args[0])
TypeError: Cannot interpret Feed_dict key as Tensor: Tensor Tensor("Placeholder:0",shape=(96,96,3),dtype=float32) is not an element of this graph.
任何帮助将非常感激!
最佳答案
如何创建一个ImageAugmenter类,而不是使用Image类,该类在初始化时接受会话,然后使用Tensorflow处理您的图像?你可以这样做:
import tensorflow as tf
import numpy as np
class ImageAugmenter(object):
def __init__(self,sess):
self.sess = sess
self.im_placeholder = tf.placeholder(tf.float32,shape=[1,784,3])
def augment(self,image):
augment_op = tf.image.random_saturation(self.im_placeholder,0.6,0.8)
return self.sess.run(augment_op,{self.im_placeholder: image})
class DataFeed(object):
def __init__(self,sess):
self.images = load_data(data_dir)
self.augmenter = ImageAugmenter(sess)
def process_data(self):
processed_images = []
for im in self.images:
processed_images.append(self.augmenter.augment(im))
return processed_images
def load_data(data_dir):
# True method would read images from disk
# This is just a mockup
images = []
images.append(np.random.random([1,3]))
images.append(np.random.random([1,3]))
return images
if __name__ == "__main__":
TRAIN_DATA_DIR = '/some/dir/'
sess = tf.Session()
data_Feed = DataFeed(TRAIN_DATA_DIR,sess)
train_data = data_Feed.process_data()
print(train_data)
有了这个,你不会为每个图像创建一个新的会话,它应该给你你想要的.
注意如何调用sess.run();我传递给它的Feed dict的关键是上面定义的占位符张量.根据您的错误跟踪,您可能尝试从未定义im_placeholder的代码的一部分调用sess.run(),或者将其定义为tf.placeholder以外的其他内容.
此外,您可以通过更改ImageAugmenter.augment()方法以接收较低和较高参数作为tf.image.random_saturation()方法的输入来进一步改进代码,或者您可以使用特定形状初始化ImageAugmenter而不是例如,让它硬编码.