Dataset是Tensorflow里面一个比较重要的概念,我们知道机器学习算法需要大概的数据来训练data model. 所以Dataset就是用来做这么一件重要的事情:定义数据pipline,为学习算法提供训练数据。
其实我们也可以将Dataset理解成一个数据源,指向某些包含训练数据的文件列表,或者是内存里面已有的数据结构(比如Tensor objects)。
Dataset 数据结构
组成Dataset的基本单元是element。每个element必需有相同的数据结构,其中每个element包含多个Tensor objects。比如:
# 创建一个dataset,里面包含一个2-Dimension (4x10) Tensor对象
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10]))
# 创建一个dataset,里面包含两个Tensor,tensor1的shape为(4x3),tensor2的shape为(4x5)
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4,3]),tf.random_uniform([4,5])))
创建Dataset
前面说了Dataset可以理解成数据源, 那么怎么创建一个Dataset并使它跟多个数据源关联呢?Tensorflow Dataset API提供了两种方式:
从已有的一个或者多个Tensors对象中创建
上一节的Dataset.from_tensor_slices()就是这用这种方式创建的Dataset
利用这种方式,同样地可以创建指向训练数据文件的Dataset,比如我们让每个element包含两个Tensor,第一个Tensor指向一堆汽车的图片文件,另外一个vector tensor表示对应的图片是否为一辆卡车:train_imgs = tf.constant(['train/img1.png','train/img2.png','train/img3.png','train/img4.png','train/img5.png','train/img6.png']) train_labels = tf.constant([0,0,1,1]) tr_data = Dataset.from_tensor_slices((train_imgs,train_labels))
这样dataset里面的每一个element其实就是一个tuple,包含了(feature,label)
对已有的Dataset进行转换(transformation),比如batch(),map(),filter(),后面会再介绍这些常用的API
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,10])) dataset2 = dataset1.batch(10)
dataset2就是用这里介绍的第二种方法创建的。
读取Dataset
从前面Dataset的定义以及结构可以看出,Dataset其实是对Tensor提供了一层封装,而Tensor又是对真实的训练数据的封装,这些数据可能是一个N-Dimension matrix,或者是指向一批数据文件的向量。其实我们可以会问为什么要设计的这么复杂,又是matrix,又是Tensor的,直接用Tensor/Matrix的API来读取训练数据不就行了么? 我觉得可以从下面几个方向来思考:
- 在训练我们的model的时候,需要把训练数据input到我们的算法model中。但有时候训练数据不是说只有几百条,而是成千上万的,这样如果直接把这些数据load到内存中的Tensor肯定是吃不消的,所以需要一种数据结构让算法能够批量地从disk中分批读取,然后用它们来训练我们的model, Dataset正是提供这种机制(transformation)来满足这方面的需求。
- 相比Tensor,Dataset对训练数据的读取更加灵活。当我们用常用的梯度下降算法来minimize我们的cost function时,需要不断地调整parameter的数值从而使cost不断地下降。这是一个迭代过程,每个迭代都需要读取不同batch size的训练数据来计算cost。Dataset提供了一些丰富的API可以读取不同batch size的数据。
回到正题,Dataset提供Iterator.get_next() API来读取它的每一个element,这个element包含一个或者多个我们需要的Tensor objects。
至于每次调用get_next()返回多少个element,则取决于batch size的大小。或者你可以认为batch size就是决定每次读取多少个训练数据,一个训练数据就是一个element。
Iterator的调用步骤:
定义一个Dataset
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4,5])) #dataset = dataset.batch(2)
定义一个Iterator
iterator = dataset.make_initializable_iterator() next_element = iterator.get_next()
初始化Iterator (one shot iterator 除外), 如果有parameter需要初始化,将初始化的值传递给Feed_dict
sess.run(iterator.initializer, Feed_dict={...})
用Iterator读取数据
sess.run(next_element) # output [ 0.58478916 0.3431859 0.23752177 0.19337153 0.05314612]
如果将dataset的batch size定义成2,那么next element将会包含两个数组:
sess.run(next_element) #output [[ 0.38093257 0.31324649 0.16414177 0.84969711 0.40212131] [ 0.18354928 0.55987918 0.09232235 0.98887277 0.21049285]]
这里需要特别提一下one shot iterator,它每次只读取一个element,而且 这种Iterator不需要初始化,也就是上面的第3步不需要显式地调用。但是只有当Dataset不包含任何参数时才可以为它创建one shot iterator, 前面例子里的Dataset都不能创建one shot iterator。
你可以这样来创建one shot iterator:
dataset2 = tf.data.Dataset.from_tensor_slices(tf.constant([[1,2,3],[2,4,6],[3,6,9]]))
iter2 = dataset2.make_one_shot_iterator()
用Dataset读取文件
前面的例子里很多的都是从Ternsor对象中创建Dataset, 所以用Iterator读取到的可能是一些常量数据,比如文件名,数组之类的。但是在真实的世界中,训练数据都是存放在文件中的,比如CSV,JPG,所以我们关心的其实并不是这些文件名本身,还是其中的内容。那么如果我的Tensor中存放的是一些文件名字,怎么用Dataset来读取其中的数据呢?
Dataset提供了一个数据预处理的API map()。 预处理的意思是可以对每一个element进行transformation,Iterator的get_next()拿到的可能是一个字符串代表某个文件名或者CSV文件里的一行,然后transformation的时候将这个文件的内容读取出来并保存在内存的Tensor对象。
读取文本文件@H_301_206@
这里用TextLineDataset读取csv文件:
def @H_502_215@readTextFile(filename):
_CSV_COLUMN_DEFAULTS = [[1],[0],[''],['']]
_CSV_COLUMNS = [
'age','workclass','education','education_num','marital_status','occupation','income_bracket'
]
dataset = tf.data.TextLineDataset(filename)
iterator = dataset.make_one_shot_iterator()
textline = iterator.get_next()
with tf.Session() as sess:
print(textline.eval())
# convert text to list of tensors for each column
def @H_502_215@parseCSVLine(value):
columns = tf.decode_csv(value,_CSV_COLUMN_DEFAULTS)
features = dict(zip(_CSV_COLUMNS,columns))
return features
dataset2 = dataset.map(parseCSVLine)
iterator2 = dataset2.make_one_shot_iterator()
textline2 = iterator2.get_next()
with tf.Session() as sess:
print(textline2)
这里parseCSVLine 将从csv读取到的每一行进行decode 处理(tf.decode_csv), 从而将每一列转成对应的Tensor object。