Added proper Batch dimension support in CalculateOutputShape for Concat.
PiperOrigin-RevId: 312867341
Change-Id: I089c71c5e913d089488f80a923caa81f6f156f7b
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index bdcf6f6..8fcbe37 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -534,9 +534,10 @@
switch (attr.axis) {
case Axis::CHANNELS:
for (int i = 1; i < input.size(); i++) {
- if (input[i].h != new_shape.h || input[i].w != new_shape.w) {
+ if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
+ input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Height and Width must be the same when concatenating "
+ "Height, Width and Batch must be the same when concatenating "
"by channels axis");
}
new_shape.c += input[i].c;
@@ -544,9 +545,10 @@
break;
case Axis::HEIGHT:
for (int i = 1; i < input.size(); i++) {
- if (input[i].w != new_shape.w || input[i].c != new_shape.c) {
+ if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
+ input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Channels and Width must be the same when concatenating "
+ "Channels, Width and Batch must be the same when concatenating "
"by height axis");
}
new_shape.h += input[i].h;
@@ -554,14 +556,26 @@
break;
case Axis::WIDTH:
for (int i = 1; i < input.size(); i++) {
- if (input[i].h != new_shape.h || input[i].c != new_shape.c) {
+ if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
+ input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Height and Channels must be the same when concatenating "
+ "Height, Channels and Batch must be the same when concatenating "
"by width axis");
}
new_shape.w += input[i].w;
}
break;
+ case Axis::BATCH:
+ for (int i = 1; i < input.size(); i++) {
+ if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
+ input[i].w != new_shape.w) {
+ return absl::InvalidArgumentError(
+ "Width, Height and Channels must be the same when concatenating "
+ "by batch axis");
+ }
+ new_shape.b += input[i].b;
+ }
+ break;
default:
return absl::InvalidArgumentError("Invalid axis");
break;
@@ -578,9 +592,10 @@
case Axis::CHANNELS:
for (int i = 1; i < input.size(); ++i) {
if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
- input[i].d != new_shape.d) {
+ input[i].d != new_shape.d || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Height, Width and Depth must be the same when concatenating "
+ "Height, Width, Batch and Depth must be the same when "
+ "concatenating "
"by channels axis");
}
new_shape.c += input[i].c;
@@ -589,9 +604,10 @@
case Axis::HEIGHT:
for (int i = 1; i < input.size(); ++i) {
if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
- input[i].d != new_shape.d) {
+ input[i].d != new_shape.d || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Width, Depth and Channels must be the same when concatenating "
+ "Width, Depth, Batch and Channels must be the same when "
+ "concatenating "
"by height axis");
}
new_shape.h += input[i].h;
@@ -600,9 +616,10 @@
case Axis::WIDTH:
for (int i = 1; i < input.size(); ++i) {
if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
- input[i].d != new_shape.d) {
+ input[i].d != new_shape.d || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Height, Depth and Channels must be the same when concatenating "
+ "Height, Depth, Batch and Channels must be the same when "
+ "concatenating "
"by width axis");
}
new_shape.w += input[i].w;
@@ -611,14 +628,27 @@
case Axis::DEPTH:
for (int i = 1; i < input.size(); ++i) {
if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
- input[i].c != new_shape.c) {
+ input[i].c != new_shape.c || input[i].b != new_shape.b) {
return absl::InvalidArgumentError(
- "Width, Height and Channels must be the same when concatenating "
+ "Width, Height, Batch and Channels must be the same when "
+ "concatenating "
"by depth axis");
}
new_shape.d += input[i].d;
}
break;
+ case Axis::BATCH:
+ for (int i = 1; i < input.size(); ++i) {
+ if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
+ input[i].c != new_shape.c || input[i].d != new_shape.d) {
+ return absl::InvalidArgumentError(
+ "Width, Height, Depth and Channels must be the same when "
+ "concatenating "
+ "by batch axis");
+ }
+ new_shape.b += input[i].b;
+ }
+ break;
default:
return absl::InvalidArgumentError("Invalid axis");
}