python – Tensorflow:在方法中使用会话/图形

前端之家收集整理的这篇文章主要介绍了python – Tensorflow:在方法中使用会话/图形前端之家小编觉得挺不错的,现在分享给大家,也给大家做个参考。

我的情况是这样的:

我有一个训练张量流模型的脚本.在此脚本中,我实例化了一个提供训练数据的类.该类的初始化依次实例化另一个名为“image”的类,以进行数据扩充的各种操作,而不是.

main script -> instantiates data_Feed class -> instantiates image class

我的问题是我试图通过传递会话本身或图形来使用tensorflow在这个图像类中做一些操作.但我收效甚微.

 有效的方法(但速度太慢)

我现在所拥有的,但工作缓慢,就像这样(简化):

class image(object):
    def __init__(self,im):
        self.im = im

    def augment(self):
        aux_im = tf.image.random_saturation(self.im,0.6)

        sess = tf.Session(graph=aux_im.graph)
        self.im = sess.run(aux_im)

class data_Feed(object):
    def __init__(self,data_dir):
        self.images = load_data(data_dir)

    def process_data(self):
        for im in self.images:
            image = image(im)
            image.augment()

if __name__ == "__main__":
    # initialize everything tensorflow related here,including model
    sess = tf.Session()
    # next load the data
    data_Feed = data_Feed(TRAIN_DATA_DIR)
    train_data = data_Feed.process_data()

这种方法有效,但它为每个图像创建一个新的Session:

I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0,name: GeForce GTX 1070,pci bus id: 0000:01:00.0)
I tensorflow/core/common_runtime/gpu/gpu_device.cc:975] Creating TensorFlow device (/gpu:0) -> (device: 0,pci bus id: 0000:01:00.0)
etc ...

 不起作用的方法(应该更快)

例如,什么不起作用,我无法弄清楚为什么,是从我的主脚本传递图形或会话,如下所示:

class image(object):
    def __init__(self,im):
        self.im = im

    def augment(self,tf_sess):
        with tf_sess.as_default():
            aux_im = tf.image.random_saturation(self.im,0.6)

            self.im = tf_sess.run(aux_im)

class data_Feed(object):
    def __init__(self,data_dir,tf_sess):
        self.images = load_data(data_dir)
        self.tf_sess = tf_sess

    def process_data(self):
        for im in self.images:
            image = image(im)
            image.augment(self.tf_sess)

if __name__ == "__main__":
    # initialize everything tensorflow related here,including model
    sess = tf.Session()
    # next load the data
    data_Feed = data_Feed(TRAIN_DATA_DIR,sess)
    train_data = data_Feed.process_data()

这是我得到的错误

Traceback (most recent call last):
  File "/usr/lib/python2.7/threading.py",line 801,in __bootstrap_inner
    self.run()
  File "/usr/lib/python2.7/threading.py",line 754,in run
    self.__target(*self.__args,**self.__kwargs)
  File "/usr/local/lib/python2.7/dist-packages/keras/engine/training.py",line 409,in data_generator_task
    generator_output = next(generator)
  File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 198,in generate
    yield self.next_batch()
  File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 192,in next_batch
    X,y,l = self.process_image(json_im,X,l)
  File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 131,in process_image
    im.augment_with_tf(self.tf_sess)
  File "/home/mathetes/DropBox/ML/load_gluc_data.py",line 85,in augment_with_tf
    self.im = sess.run(saturation,{im_placeholder: np.asarray(self.im)})
  File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py",line 766,in run
    run_Metadata_ptr)
  File "/home/mathetes/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py",line 921,in _run
    + e.args[0])
TypeError: Cannot interpret Feed_dict key as Tensor: Tensor Tensor("Placeholder:0",shape=(96,96,3),dtype=float32) is not an element of this graph.

任何帮助将非常感激!

最佳答案
如何创建一个ImageAugmenter类,而不是使用Image类,该类在初始化时接受会话,然后使用Tensorflow处理您的图像?你可以这样做:

import tensorflow as tf
import numpy as np

class ImageAugmenter(object):
    def __init__(self,sess):
        self.sess = sess
        self.im_placeholder = tf.placeholder(tf.float32,shape=[1,784,3])

    def augment(self,image):
        augment_op = tf.image.random_saturation(self.im_placeholder,0.6,0.8)
        return self.sess.run(augment_op,{self.im_placeholder: image})

class DataFeed(object):
    def __init__(self,sess):
        self.images = load_data(data_dir)
        self.augmenter = ImageAugmenter(sess)

    def process_data(self):
        processed_images = []
        for im in self.images:
            processed_images.append(self.augmenter.augment(im))
        return processed_images

def load_data(data_dir):
    # True method would read images from disk
    # This is just a mockup
    images = []
    images.append(np.random.random([1,3]))
    images.append(np.random.random([1,3]))
    return images

if __name__ == "__main__":
    TRAIN_DATA_DIR = '/some/dir/'
    sess = tf.Session()
    data_Feed = DataFeed(TRAIN_DATA_DIR,sess)
    train_data = data_Feed.process_data()
    print(train_data)

有了这个,你不会为每个图像创建一个新的会话,它应该给你你想要的.

注意如何调用sess.run();我传递给它的Feed dict的关键是上面定义的占位符张量.根据您的错误跟踪,您可能尝试从未定义im_placeholder的代码的一部分调用sess.run(),或者将其定义为tf.placeholder以外的其他内容.

此外,您可以通过更改ImageAugmenter.augment()方法以接收较低和较高参数作为tf.image.random_saturation()方法的输入来进一步改进代码,或者您可以使用特定形状初始化ImageAugmenter而不是例如,让它硬编码.

猜你在找的Python相关文章