Added Winograd selection for Metal backend.
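
For 3x3 convolutions with unit stride and dilation, convolution lowering can now
choose a Winograd F(4x4, 3x3) pipeline instead of ConvolutionGeneric. The
heuristic in IsSuitableForWinograd4x4To6x6 requires at least 16 source and
destination depth slices (channels divided by 4) and at least 32 output 4x4
tiles, so the transform overhead is only paid where the reduction in multiplies
wins. The lowering expands one node into three tasks (Winograd4x4To36 input
transform, ConvolutionWino4x4To6x6, Winograd36To4x4 output transform with
bias), so Compile now tracks the largest existing node and value ids and
RegisterPrimaryOps allocates fresh ids for the two intermediate tensors.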
PiperOrigin-RevId: 303212484
Change-Id: Ia19dbaf8d9a5bdf344a7f3306edfdfc18eb9c78a
diff --git a/tensorflow/lite/delegates/gpu/metal/BUILD b/tensorflow/lite/delegates/gpu/metal/BUILD
index b407083..8407803 100644
--- a/tensorflow/lite/delegates/gpu/metal/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/BUILD
@@ -32,6 +32,7 @@
"//tensorflow/lite/delegates/gpu/common:operations",
"//tensorflow/lite/delegates/gpu/common:shape",
"//tensorflow/lite/delegates/gpu/common:status",
+ "//tensorflow/lite/delegates/gpu/common:util",
"//tensorflow/lite/delegates/gpu/metal/kernels",
"//tensorflow/lite/delegates/gpu/metal/kernels:custom_registry",
],
diff --git a/tensorflow/lite/delegates/gpu/metal/api.cc b/tensorflow/lite/delegates/gpu/metal/api.cc
index 744094c..1f111cb 100644
--- a/tensorflow/lite/delegates/gpu/metal/api.cc
+++ b/tensorflow/lite/delegates/gpu/metal/api.cc
@@ -22,6 +22,7 @@
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
+#include "tensorflow/lite/delegates/gpu/common/util.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
@@ -45,6 +46,7 @@
#include "tensorflow/lite/delegates/gpu/metal/kernels/softmax.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/space_to_depth.h"
#include "tensorflow/lite/delegates/gpu/metal/kernels/transpose_conv.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/winograd.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
@@ -142,10 +144,29 @@
return SpaceToDepth(id, input_id, output_id, attr);
}

+bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
+ const BHWC& dst_shape) {
+ const int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4);
+ const int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4);
+ const int src_depth = IntegralDivideRoundUp(attr.weights.shape.i, 4);
+ const int dst_depth = IntegralDivideRoundUp(attr.weights.shape.o, 4);
+ const bool suitable_attributes =
+ attr.weights.shape.w == 3 && attr.weights.shape.h == 3 &&
+ attr.dilations == HW(1, 1) && attr.strides == HW(1, 1);
+
+ const int min_depth = 16;
+ const int min_hw = 32;
+ const bool recommended_channels =
+ src_depth >= min_depth && dst_depth >= min_depth;
+ const bool recommended_hw = tiles_x * tiles_y >= min_hw;
+ return suitable_attributes && recommended_channels && recommended_hw;
+}
+
absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const std::vector<ValueId>& inputs,
const std::vector<ValueId>& outputs,
const RuntimeOptions& options,
+ int* last_node_id, int* last_value_id,
std::vector<ComputeTaskDescriptorPtr>* tasks) {
if (!IsBatchMatchesForAllValues(graph)) {
return absl::InvalidArgumentError(
@@ -185,8 +206,35 @@
const auto dst_shape = graph.FindOutputs(node_id)[0]->tensor.shape;
auto attr =
absl::any_cast<Convolution2DAttributes>(node->operation.attributes);
- *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
- attr, options);
+ if (IsSuitableForWinograd4x4To6x6(attr, dst_shape)) {
+ int tiles_x = IntegralDivideRoundUp(dst_shape.w, 4);
+ int tiles_y = IntegralDivideRoundUp(dst_shape.h, 4);
+
+ Winograd4x4To36Attributes wino_up_attr;
+ wino_up_attr.padding = attr.padding;
+ (*last_node_id) += 1;
+ int value_id = *last_value_id + 1;
+ *tasks =
+ Winograd4x4To36(*last_node_id, inputs[0], value_id, wino_up_attr);
+
+ BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
+ (*last_node_id) += 1;
+ auto t1 = ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
+ conv_shape, attr, options);
+ tasks->insert(tasks->end(), t1.begin(), t1.end());
+
+ Winograd36To4x4Attributes wino_down_attr;
+ wino_down_attr.output_shape = dst_shape;
+ wino_down_attr.biases = attr.bias;
+ (*last_node_id) += 1;
+ auto t2 = Winograd36To4x4(*last_node_id, value_id + 1, outputs[0],
+ options, wino_down_attr);
+ tasks->insert(tasks->end(), t2.begin(), t2.end());
+ (*last_value_id) += 2;
+ } else {
+ *tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
+ attr, options);
+ }
break;
}
case OperationType::CONVOLUTION_TRANSPOSED:
@@ -342,6 +390,14 @@

absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
CompiledModel* compiled_model) {
+ int last_node_id = 0;
+ for (const auto& node : graph.nodes()) {
+ last_node_id = std::max(last_node_id, static_cast<int>(node->id));
+ }
+ int last_value_id = 0;
+ for (const auto& value : graph.values()) {
+ last_value_id = std::max(last_value_id, static_cast<int>(value->id));
+ }
for (const auto& node : graph.nodes()) {
std::vector<ValueId> inputs;
for (auto& input : graph.FindInputs(node->id)) {
@@ -356,7 +412,8 @@
RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
if (!custom_status.ok()) {
auto primary_status =
- RegisterPrimaryOps(graph, node, inputs, outputs, options, &tasks);
+ RegisterPrimaryOps(graph, node, inputs, outputs, options,
+ &last_node_id, &last_value_id, &tasks);
if (!primary_status.ok()) {
return absl::UnimplementedError(
absl::Substitute("Unsupported op type: $0; custom registry error: "
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 5130415..b96f6b95 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -37,6 +37,7 @@
":softmax",
":space_to_depth",
":transpose_conv",
+ ":winograd",
],
)
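
For reference, here is a hedged, self-contained sketch of the bookkeeping the
new path performs: Compile() pre-scans the graph for the largest node and
value ids, and RegisterPrimaryOps then splits one convolution into three tasks
connected by two fresh intermediate values, with the Winograd-domain tensor
shaped {b, 36, tiles_x * tiles_y, c} (one 6x6 patch, i.e. 36 values, per 4x4
output tile). All names below are illustrative stand-ins, not the delegate's
types.

#include <algorithm>
#include <cstdio>
#include <vector>

struct BHWC { int b, h, w, c; };

int DivUp(int n, int d) { return (n + d - 1) / d; }  // ceil division

int main() {
  // Compile() pre-scans the graph for the largest existing ids so that new
  // intermediates never collide; the graph is modeled as plain id lists here.
  std::vector<int> node_ids = {0, 1, 2};
  std::vector<int> value_ids = {0, 1, 2, 3};
  int last_node_id = *std::max_element(node_ids.begin(), node_ids.end());
  int last_value_id = *std::max_element(value_ids.begin(), value_ids.end());

  // One convolution becomes three tasks; each gets a fresh node id, and two
  // fresh value ids connect them, mirroring the arithmetic in the diff.
  BHWC dst{1, 64, 64, 64};
  const int tiles = DivUp(dst.w, 4) * DivUp(dst.h, 4);
  BHWC conv_shape{dst.b, 36, tiles, dst.c};  // Winograd-domain tensor

  int transform_node = ++last_node_id;  // Winograd4x4To36
  int wino_input = last_value_id + 1;   // output of the input transform
  int conv_node = ++last_node_id;       // ConvolutionWino4x4To6x6
  int wino_output = last_value_id + 2;  // output of the Winograd conv
  int inverse_node = ++last_node_id;    // Winograd36To4x4 (+ bias)
  last_value_id += 2;                   // both intermediates now reserved

  printf("nodes %d,%d,%d values %d,%d conv_shape %dx%dx%dx%d\n",
         transform_node, conv_node, inverse_node, wino_input, wino_output,
         conv_shape.b, conv_shape.h, conv_shape.w, conv_shape.c);
  return 0;
}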