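Fixed accumulator precision for generic DepthWise implementation (the sum now uses ACCUM_FLT4 instead of a hard-coded float4).
Removed inlined constants for kernel sizes; the kernel size is now passed to the shader in the uniform buffer.
Added test for a 2x2 input with a 2x2 kernel.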

PiperOrigin-RevId: 304026983
Change-Id: I4f9eac57ba1ec4e6f929d3ab7c9176f0d6f3b4ce
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
index 9fa627b..6c26a87 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
@@ -475,91 +475,93 @@
   std::string shader_source = R"(
     #include <metal_stdlib>
     using namespace metal;
-    constant int kernel_x = $0;
-    constant int kernel_y = $1;
     struct uniforms {
-      int4 stride;
-      int4 padding;
-      int4 dilation;
-      int4 size;
+      int4 src_size;
+      int4 dst_size;
+      int2 stride;
+      int2 padding;
+      int2 dilation;
+      int2 kernel_size;
       int4 channel_multiplier;
     };
-    $$0
+    $0
     kernel void ComputeFunction(
-                                $$1
+                                $1
                                 uint tid[[thread_index_in_threadgroup]],
                                 uint3 gid[[thread_position_in_grid]]) {
-      const bool outside = static_cast<int>(gid.x) >= params.size.z ||
-        static_cast<int>(gid.y) >= params.size.w;
-      if (outside) {
-        return;
-      }
-      device FLT4* temp = filters + gid.z * kernel_y * kernel_x;
-      float4 sum0 = float4(0.0f, 0.0f, 0.0f, 0.0f);
+      int dst_x = static_cast<int>(gid.x);
+      int dst_y = static_cast<int>(gid.y);
+      int dst_z = static_cast<int>(gid.z);
 
-      for(int ky = 0; ky < kernel_y; ++ky) {
-        for(int kx = 0; kx < kernel_x; ++kx) {
-          int2 coords  = int2(gid.xy) * params.stride.xy + int2(kx, ky) * params.dilation.xy -
-            params.padding.xy;
-          const bool outside = coords.x < 0 || coords.y < 0 ||
-            coords.x >= params.size.x || coords.y >= params.size.y;
-          if (outside) continue;
+      if (dst_x >= U.dst_size.x || dst_y >= U.dst_size.y) return;
+
+      device FLT4* temp = filters + dst_z * U.kernel_size.x * U.kernel_size.y;
+      ACCUM_FLT4 sum0 = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);
+
+      int src_x = dst_x * U.stride.x + U.padding.x;
+      int src_y = dst_y * U.stride.y + U.padding.y;
+
+      for(int ky = 0; ky < U.kernel_size.y; ++ky) {
+        int yc = ky * U.dilation.y + src_y;
+        if (yc < 0 || yc >= U.src_size.y) continue;
+        for(int kx = 0; kx < U.kernel_size.x; ++kx) {
+          int xc = kx * U.dilation.x + src_x;
+          if (xc < 0 || xc >= U.src_size.x) continue;
 )";
   if (channels_multiplier == 1) {
     shader_source += R"(
-        const int src_layer = gid.z;
-        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
-        const FLT4 src_modified = src_buffer[src_index];
+        int src_layer = dst_z;
+        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
+        FLT4 src_modified = src_buffer[src_index];
 )";
   } else if (channels_multiplier == 2) {
     shader_source += R"(
-        const int src_layer = gid.z / 2;
-        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
-        const FLT4 src = src_buffer[src_index];
-        const FLT2 t0 = gid.z % 2 == 0 ? src.xy : src.zw;
-        const FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
+        int src_layer = dst_z / 2;
+        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
+        FLT4 src = src_buffer[src_index];
+        FLT2 t0 = dst_z % 2 == 0 ? src.xy : src.zw;
+        FLT4 src_modified = FLT4(t0.x, t0.x, t0.y, t0.y);
 )";
   } else if (channels_multiplier == 4) {
     shader_source += R"(
-        const int src_layer = gid.z / 4;
-        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
-        const FLT4 src = src_buffer[src_index];
-        const FLT t0 = src[gid.z % 4];
-        const FLT4 src_modified = FLT4(t0, t0, t0, t0);
+        int src_layer = dst_z / 4;
+        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
+        FLT4 src = src_buffer[src_index];
+        FLT t0 = src[dst_z % 4];
+        FLT4 src_modified = FLT4(t0, t0, t0, t0);
 )";
   } else {
     shader_source += R"(
-        const int src_layer = gid.z / params.channel_multiplier.x;
-        const int src_index = (src_layer * params.size.y + coords.y) * params.size.x + coords.x;
-        const FLT4 src = src_buffer[src_index];
+        int src_layer = dst_z / U.channel_multiplier.x;
+        int src_index = (src_layer * U.src_size.y + yc) * U.src_size.x + xc;
+        FLT4 src = src_buffer[src_index];
         FLT4 src_modified;
-        const int src_layer_offset = (gid.z % params.channel_multiplier.x) * 4;
-        src_modified.x = src[(src_layer_offset + 0) / params.channel_multiplier.x];
-        src_modified.y = src[(src_layer_offset + 1) / params.channel_multiplier.x];
-        src_modified.z = src[(src_layer_offset + 2) / params.channel_multiplier.x];
-        src_modified.w = src[(src_layer_offset + 3) / params.channel_multiplier.x];
+        const int src_layer_offset = (dst_z % U.channel_multiplier.x) * 4;
+        src_modified.x = src[(src_layer_offset + 0) / U.channel_multiplier.x];
+        src_modified.y = src[(src_layer_offset + 1) / U.channel_multiplier.x];
+        src_modified.z = src[(src_layer_offset + 2) / U.channel_multiplier.x];
+        src_modified.w = src[(src_layer_offset + 3) / U.channel_multiplier.x];
 )";
   }
   shader_source += R"(
-          sum0 += float4(src_modified * temp[ky * kernel_x + kx]);
+          sum0 += TO_ACCUM4_TYPE(src_modified * temp[ky * U.kernel_size.x + kx]);
         }
       }
-      FLT4 res = FLT4(sum0 + float4(biases[gid.z]));
-      const int linear_index = (gid.z * params.size.w + int(gid.y)) * params.size.z + int(gid.x);
+      FLT4 res = FLT4(sum0) + biases[dst_z];
+      const int linear_index = (dst_z * U.dst_size.y + dst_y) * U.dst_size.x + dst_x;
       FLT4 value = res;
-      $$2
-      output_buffer[linear_index] = value;
+      $2
+      dst_buffer[linear_index] = value;
     }
   )";
-  desc->shader_source = absl::Substitute(shader_source, attr.weights.shape.w,
-                                         attr.weights.shape.h);
+  desc->shader_source = shader_source;
 
   desc->input_buffers = {
       {input_id, "device FLT4* const src_buffer"},
   };
 
   desc->output_buffer = {
-      output_id, "device FLT4* output_buffer",
+      output_id, "device FLT4* dst_buffer",
       [input_id, attr](const std::map<ValueId, BHWC>& buffers) {
         auto out_shape =
             CalculateOutputShape(buffers.find(input_id)->second, attr);
@@ -577,27 +579,27 @@
   };
 
   desc->uniform_buffers = {
-      {"constant uniforms& params",
+      {"constant uniforms& U",
        [input_id, output_id, attr](const std::map<ValueId, BHWC>& buffers) {
          const auto& dimension = buffers.find(input_id)->second;
          const auto& output_dimension = buffers.find(output_id)->second;
          std::vector<int> uniform_params{
-             attr.strides.w,
-             attr.strides.h,
-             1,
-             1,
-             attr.padding.prepended.w,
-             attr.padding.prepended.h,
-             1,
-             1,
-             attr.dilations.w,
-             attr.dilations.h,
-             1,
-             1,
              dimension.w,
              dimension.h,
+             IntegralDivideRoundUp(dimension.c, 4),
+             0,
              output_dimension.w,
              output_dimension.h,
+             IntegralDivideRoundUp(output_dimension.c, 4),
+             0,
+             attr.strides.w,
+             attr.strides.h,
+             -attr.padding.prepended.w,
+             -attr.padding.prepended.h,
+             attr.dilations.w,
+             attr.dilations.h,
+             attr.weights.shape.w,
+             attr.weights.shape.h,
              attr.weights.shape.o,
              0,
              0,
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
index d765072..dcf550f 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
@@ -167,4 +167,43 @@
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
+- (void)testShape2x2Kernel2x2 {  // Depthwise conv: 2x2 input, 2x2 kernel, stride 1, 2x2 output.
+  TensorRef<BHWC> input;
+  input.type = DataType::FLOAT32;
+  input.ref = 0;
+  input.shape = BHWC(1, 2, 2, 1);
+
+  DepthwiseConvolution2DAttributes attr;
+  Tensor<Linear, DataType::FLOAT32> bias;
+  bias.shape.v = 1;
+  bias.id = 1;
+  bias.data = {0};  // Zero bias, so the expected outputs are pure weighted sums.
+  attr.bias = std::move(bias);
+
+  Tensor<OHWI, DataType::FLOAT32> weights;
+  weights.shape = OHWI(1, 2, 2, 1);
+  weights.id = 1;  // NOTE(review): same id as the bias tensor above -- confirm this is intended.
+  weights.data = {1, 2, 3, 4};
+
+  attr.weights = std::move(weights);
+
+  attr.dilations = HW(1, 1);
+  attr.padding.prepended = HW(0, 0);
+  attr.padding.appended = HW(1, 1);  // Only trailing padding; output shape below stays 2x2.
+  attr.strides = HW(1, 1);
+
+  TensorRef<BHWC> output;
+  output.type = DataType::FLOAT32;
+  output.ref = 3;
+  output.shape = BHWC(1, 2, 2, 1);
+  // Expected: 100 = 1*1+4*2+9*3+16*4, 52 = 4*1+16*3, 41 = 9*1+16*2, 16 = 16*1.
+  SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input},
+                      {output});
+  XCTAssertTrue(model.PopulateTensor(0, {1, 4, 9, 16}));
+  auto status = model.Invoke();
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  status = CompareVectors({100, 52, 41, 16}, model.GetOutput(0), 1e-6f);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
 @end