input_tensors在tf.keras.models.clone_model中的作用

我正在尝试复制现有的keras模型。以下是我创建的示例代码,看来它可以按预期工作。

model = CreateSimpleModel()
model.compile(loss="sparse_categorical_crossentropy",optimizer="adam",metrics=["accuracy"])

model.summary()


model_cloned = tf.keras.models.clone_model(model)
model_cloned.set_weights(model.get_weights())

print(model(np.array([[1,2]])))
print(model_cloned(np.array([[1,2]])))

但是,如果我们在下一页中查看有关tf.keras.models.clone_model的官方文档,则会有一个名为input_tensors的参数。

https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model

我不太确定此参数的作用。从上面的示例代码中,我不太清楚为什么在某些情况下需要使用它。有人可以举例说明吗?

huayuan1005 回答:input_tensors在tf.keras.models.clone_model中的作用

编辑:请不要做我在下面所做的事情。使用 GradCAM 查看第一个卷积层中使用的权重后,似乎 input_tensors 参数没有任何影响,尽管输入发生变化,但 base_model 的所有克隆都具有相同的权重。

就我而言,我使用 tf.keras.models.clone_model 来克隆基础的预训练神经网络,以便我的所有多个输入都有自己的路径:

# inputs:
x1 = tf.keras.layers.Input(shape=(None,None,3),name="x1")
x2 = tf.keras.layers.Input(shape=(None,name="x2")
x3 = tf.keras.layers.Input(shape=(None,name="x3")

# load base model:
base_model = tf.keras.applications.DenseNet169(input_tensor=x1,input_shape=(224,224,include_top=False,pooling='avg')

# create copies:
base_model2 = tf.keras.models.clone_model(base_model,input_tensors=x2)
base_model3 = tf.keras.models.clone_model(base_model,input_tensors=x3)

# you have to rename the layers in each model so there aren't any conflicts:
cnt = 0
for mod in [base_model1,base_model2,base_model3,base_model4,base_model5,base_model6]:
    cnt += 1
    for layer in mod.layers:
        old_name = layer.name
        layer._name = f"base_model{cnt}_{old_name}"

# this bit isn't necessary unless you want to access weights easily later on:
base1_out = base_model.output
base2_out = base_model2.output
base3_out = base_model3.output

# concatenate the outputs:
concatenated = tf.keras.layers.concatenate([base1_out,base2_out,base3_out],axis=-1)

# add dense layers if you want:
concat_dense = tf.keras.layers.Dense(2048)(concatenated)
out = tf.keras.layers.Dense(class_count,activation='softmax')(concat_dense)

tf.keras.models.Model(inputs=[x1,x2,x3],outputs=[out])

请注意,我的输入(以字典的形式)来自使用 TensorFlow 的 tf.data.Dataset 创建的符号张量。

本文链接:https://www.f2er.com/2782212.html

大家都在问