Added Winograd selection for Metal backend.

PiperOrigin-RevId: 303212484
Change-Id: Ia19dbaf8d9a5bdf344a7f3306edfdfc18eb9c78a
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index b407083..8407803 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -32,6 +32,7 @@
         "//tensorflow/lite/delegates/gpu/common:operations",
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:status",
+        "//tensorflow/lite/delegates/gpu/common:util",
         "//tensorflow/lite/delegates/gpu/metal/kernels",
         "//tensorflow/lite/delegates/gpu/metal/kernels:custom_registry",
     ],
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index 744094c..1f111cb 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -22,6 +22,7 @@
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
 #include "tensorflow/lite/delegates/gpu/common/shape.h"
 #include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
 #include "tensorflow/lite/delegates/gpu/metal/environment.h"
@@ -45,6 +46,7 @@
 #include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
 #include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
 
 namespace tflite {
@@ -142,10 +144,29 @@
   return SpaceToDepth(id, input_id, output_id, attr);
 }
 
+bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
+                                   const BHWC& dst_shape) {
+  const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4);
+  const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4);
+  const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
+  const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
+  const bool suitable_attributes =
+      attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
+      attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
+
+  const int min_depth = 16;
+  const int min_hw = 32;
+  const bool recommended_channels =
+      src_depth >= min_depth && dst_depth >= min_depth;
+  const bool recommended_hw = tiles_x * tiles_y >= min_hw;
+  return suitable_attributes && recommended_channels && recommended_hw;
+}
+
 absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
                                 const std::vector<ValueId>& inputs,
                                 const std::vector<ValueId>& outputs,
                                 const RuntimeOptions& options,
+                                int* last_node_id, int* last_value_id,
                                 std::vector<ComputeTaskDescriptorPtr>* tasks) {
   if (!IsBatchMatchesForAllValues(graph)) {
     return absl::InvalidArgumentError(
@@ -185,8 +206,35 @@
       const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape;
       auto attr =
           absl::any_cast<Convolution2DAttributes>(node->operation.attributes);
-      *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
-                                  attr, options);
+      if (IsSuitableForWinograd4x4To6x6(attr, dst_shape)) {
+        int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4);
+        int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4);
+
+        Winograd4x4To36Attributes wino_up_attr;
+        wino_up_attr.padding = attr.padding;
+        (*last_node_id) += 1;
+        int value_id = *last_value_id + 1;
+        *tasks =
+            Winograd4x4To36(*last_node_id, inputs[0], value_id, wino_up_attr);
+
+        BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
+        (*last_node_id) += 1;
+        auto t1 = ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
+                                          conv_shape, attr, options);
+        tasks->insert(tasks->end(), t1.begin(), t1.end());
+
+        Winograd36To4x4Attributes wino_down_attr;
+        wino_down_attr.output_shape = dst_shape;
+        wino_down_attr.biases = attr.bias;
+        (*last_node_id) += 1;
+        auto t2 = Winograd36To4x4(*last_node_id, value_id + 1, outputs[0],
+                                  options, wino_down_attr);
+        tasks->insert(tasks->end(), t2.begin(), t2.end());
+        (*last_value_id) += 2;
+      } else {
+        *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
+                                    attr, options);
+      }
       break;
     }
     case OperationType::CONVOLUTION_TRANSPOSED:
@@ -342,6 +390,14 @@
 
 absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
                      CompiledModel* compiled_model) {
+  int last_node_id = 0;
+  for (const auto& node : graph.nodes()) {
+    last_node_id = std::max(last_node_id, static_cast<int>(node->id));
+  }
+  int last_value_id = 0;
+  for (const auto& value : graph.values()) {
+    last_value_id = std::max(last_value_id, static_cast<int>(value->id));
+  }
   for (const auto& node : graph.nodes()) {
     std::vector<ValueId> inputs;
     for (auto& input : graph.FindInputs(node->id)) {
@@ -356,7 +412,8 @@
         RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
     if (!custom_status.ok()) {
       auto primary_status =
-          RegisterPrimaryOps(graph, node, inputs, outputs, options, &tasks);
+          RegisterPrimaryOps(graph, node, inputs, outputs, options,
+                             &last_node_id, &last_value_id, &tasks);
       if (!primary_status.ok()) {
         return absl::UnimplementedError(
             absl::Substitute("Unsupported op type: $0; custom registry error: "
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 5130415..b96f6b95 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -37,6 +37,7 @@
         ":softmax",
         ":space_to_depth",
         ":transpose_conv",
+        ":winograd",
     ],
 )