Added proper Batch dimension support in CalculateOutputShape for Concat.

PiperOrigin-RevId: 312867341
Change-Id: I089c71c5e913d089488f80a923caa81f6f156f7b
diff --git a/tensorflow/lite/delegates/gpu/common/operations.cc b/tensorflow/lite/delegates/gpu/common/operations.cc
index bdcf6f6..8fcbe37 100644
--- a/tensorflow/lite/delegates/gpu/common/operations.cc
+++ b/tensorflow/lite/delegates/gpu/common/operations.cc
@@ -534,9 +534,10 @@
   switch (attr.axis) {
     case Axis::CHANNELS:
       for (int i = 1; i < input.size(); i++) {
-        if (input[i].h != new_shape.h || input[i].w != new_shape.w) {
+        if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
+            input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Height and Width must be the same when concatenating "
+              "Height, Width and Batch must be the same when concatenating "
               "by channels axis");
         }
         new_shape.c += input[i].c;
@@ -544,9 +545,10 @@
       break;
     case Axis::HEIGHT:
       for (int i = 1; i < input.size(); i++) {
-        if (input[i].w != new_shape.w || input[i].c != new_shape.c) {
+        if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
+            input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Channels and Width must be the same when concatenating "
+              "Channels, Width and Batch must be the same when concatenating "
               "by height axis");
         }
         new_shape.h += input[i].h;
@@ -554,14 +556,26 @@
       break;
     case Axis::WIDTH:
       for (int i = 1; i < input.size(); i++) {
-        if (input[i].h != new_shape.h || input[i].c != new_shape.c) {
+        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
+            input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Height and Channels must be the same when concatenating "
+              "Height, Channels and Batch must be the same when concatenating "
               "by width axis");
         }
         new_shape.w += input[i].w;
       }
       break;
+    case Axis::BATCH:
+      for (int i = 1; i < input.size(); i++) {
+        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
+            input[i].w != new_shape.w) {
+          return absl::InvalidArgumentError(
+              "Width, Height and Channels must be the same when concatenating "
+              "by batch axis");
+        }
+        new_shape.b += input[i].b;
+      }
+      break;
     default:
       return absl::InvalidArgumentError("Invalid axis");
       break;
@@ -578,9 +592,10 @@
     case Axis::CHANNELS:
       for (int i = 1; i < input.size(); ++i) {
         if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
-            input[i].d != new_shape.d) {
+            input[i].d != new_shape.d || input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Height, Width and Depth must be the same when concatenating "
+              "Height, Width, Batch and Depth must be the same when "
+              "concatenating "
               "by channels axis");
         }
         new_shape.c += input[i].c;
@@ -589,9 +604,10 @@
     case Axis::HEIGHT:
       for (int i = 1; i < input.size(); ++i) {
         if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
-            input[i].d != new_shape.d) {
+            input[i].d != new_shape.d || input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Width, Depth and Channels must be the same when concatenating "
+              "Width, Depth, Batch and Channels must be the same when "
+              "concatenating "
               "by height axis");
         }
         new_shape.h += input[i].h;
@@ -600,9 +616,10 @@
     case Axis::WIDTH:
       for (int i = 1; i < input.size(); ++i) {
         if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
-            input[i].d != new_shape.d) {
+            input[i].d != new_shape.d || input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Height, Depth and Channels must be the same when concatenating "
+              "Height, Depth, Batch and Channels must be the same when "
+              "concatenating "
               "by width axis");
         }
         new_shape.w += input[i].w;
@@ -611,14 +628,27 @@
     case Axis::DEPTH:
       for (int i = 1; i < input.size(); ++i) {
         if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
-            input[i].c != new_shape.c) {
+            input[i].c != new_shape.c || input[i].b != new_shape.b) {
           return absl::InvalidArgumentError(
-              "Width, Height and Channels must be the same when concatenating "
+              "Width, Height, Batch and Channels must be the same when "
+              "concatenating "
               "by depth axis");
         }
         new_shape.d += input[i].d;
       }
       break;
+    case Axis::BATCH:
+      for (int i = 1; i < input.size(); ++i) {
+        if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
+            input[i].c != new_shape.c || input[i].d != new_shape.d) {
+          return absl::InvalidArgumentError(
+              "Width, Height, Depth and Channels must be the same when "
+              "concatenating "
+              "by batch axis");
+        }
+        new_shape.b += input[i].b;
+      }
+      break;
     default:
       return absl::InvalidArgumentError("Invalid axis");
   }