Added support of 3x2 padding in model builder.
PiperOrigin-RevId: 314350524
Change-Id: I151c40662b94b02fe648d41da168f851f08c0046
diff --git a/tensorflow/lite/delegates/gpu/common/model_builder.cc b/tensorflow/lite/delegates/gpu/common/model_builder.cc
index 475f2ba..c110d46 100644
--- a/tensorflow/lite/delegates/gpu/common/model_builder.cc
+++ b/tensorflow/lite/delegates/gpu/common/model_builder.cc
@@ -1330,9 +1330,11 @@
"Invalid paddings tensor dimension: expected 2 dim, got ",
pad_tensor->dims->size, " dim"));
}
- if (pad_tensor->dims->data[0] != 4 || pad_tensor->dims->data[1] != 2) {
+ bool supported =
+ pad_tensor->dims->data[0] == 3 || pad_tensor->dims->data[0] == 4;
+ if (!supported || pad_tensor->dims->data[1] != 2) {
return absl::InvalidArgumentError(absl::StrCat(
- "Invalid paddings tensor shape: expected 4x2, got ",
+ "Invalid paddings tensor shape: expected 4x2 or 3x2, got ",
pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1]));
}
return absl::OkStatus();
@@ -1356,16 +1358,23 @@
Tensor<HW, DataType::INT32> paddings;
RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
- // 4x2 tensor with paddings.
- if (paddings.shape.h != 4 || paddings.shape.w != 2) {
+ if (paddings.shape.h == 4 && paddings.shape.w == 2) {
+ // 4x2 tensor with paddings.
+ attr.prepended = BHWC(paddings.data[0], paddings.data[2],
+ paddings.data[4], paddings.data[6]);
+ attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
+ paddings.data[7]);
+ } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
+ // 3x2 tensor with paddings.
+ attr.prepended =
+ BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
+ attr.appended =
+ BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
+ } else {
// It shouldn't fail here since it's checked at IsSupported().
return absl::InvalidArgumentError(
"Paddings tensor has unexpected shape.");
}
- attr.prepended = BHWC(paddings.data[0], paddings.data[2], paddings.data[4],
- paddings.data[6]);
- attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
- paddings.data[7]);
node->operation.attributes = attr;
return absl::OkStatus();
}