python - 没有类型形状的TensorFlow 2.0层Tensor

我正在尝试遵循TF 2.0调用tf.keras.layers.Lambda()下面的函数。 inputsoutputs张量将是具有3个颜色通道的相同尺寸的两个图像。我的目标是从outputs张量提取遮罩,将其应用于inputs张量,然后返回结果张量。拉平张量的动机是由于tf.tensor_scatter_nd_update()函数的限制。当我构建模型时,由于updatesindices.shape[0]值,因此无法初始化None。如果我在模型外部用两个tf.constant()张量调用此层以初始化x,则它在渴望执行时运行得很好(因为x张量已定义值)。不幸的是,当我使用tf.keras.layers.Lambda()调用此函数时,出现以下错误:

TypeError: can't multiply sequence by non-int of type 'NoneType'

@tf.function
def applyMask(x):
  # Extract Tensors
  inputs = x[0]
  outputs = x[1]

  # Flatten the Outputs Tensor and Extract Mask Indices 
  outputs = tf.reshape(outputs,(tf.size(outputs),))
  indices = tf.where(outputs==1.)
  indices = tf.cast(indices, tf.int32)

  # Construct Updates Tensor from Mask Indices
  updates = tf.constant([1.]*indices.shape[0])

  # Flatten Input Tensor and Apply Mask
  out_dim = inputs.shape
  inputs = tf.reshape(inputs,(tf.size(inputs),))
  tensor = tf.tensor_scatter_nd_update(inputs, indices, updates)

  # Reconstruct Input Into Tensor
  tensor = tf.reshape(tensor, out_dim)
  return tensor

最佳答案

不需要这么复杂。简单地做,

inp1 = Input(shape=(None, None, 3)) # Inputs
inp2 = Input(shape=(None, None, 3)) # Outputs

out = Lambda(lambda x: tf.where(tf.equal(x[1], 1), x[1], x[0]))([inp1, inp2])


您甚至可以使用 heightwidth None,只要传递给 inp1inp2的并行样本完全相同(按形状), tf.where就可以正常工作。