我想使用from_generator()函数创建一些tf.data.Dataset.我想向生成器函数(raw_data_gen)发送一个参数.这个想法是生成器函数将根据发送的参数产生不同的数据.通过这种方式,我希望raw_data_gen能够提供培训,验证或测试数据.
@H_403_6@training_dataset = tf.data.Dataset.from_generator(raw_data_gen,(tf.float32,tf.uint8),([None,1],[None]),args=([1])) validation_dataset = tf.data.Dataset.from_generator(raw_data_gen,args=([2])) test_dataset = tf.data.Dataset.from_generator(raw_data_gen,args=([3]))
我尝试以这种方式调用from_generator()时得到的错误消息是:
@H_403_6@TypeError: from_generator() got an unexpected keyword argument 'args'
这是raw_data_gen函数,虽然我不确定你是否需要这个,因为我的预感是问题是调用from_generator():
@H_403_6@def raw_data_gen(train_val_or_test): if train_val_or_test == 1: #For every filename collected in the list for filename,lab in training_filepath_label_dict.items(): raw_data,samplerate = soundfile.read(filename) try: #assume the audio is stereo,ready to be sliced raw_data = raw_data[:,0] #raw_data is a np.array,just take first channel with slice except IndexError: pass #this must be mono audio yield raw_data,lab elif train_val_or_test == 2: #For every filename collected in the list for filename,lab in validation_filepath_label_dict.items(): raw_data,lab elif train_val_or_test == 3: #For every filename collected in the list for filename,lab in test_filepath_label_dict.items(): raw_data,lab else: print("generator function called with an argument not in [1,2,3]") raise ValueError()
最佳答案
原文链接:https://www.f2er.com/python/438922.html