Metal DepthWiseConvolution kernels converted from ComputeTaskDescriptor to GPUOperation specializations.
Tests updated.

PiperOrigin-RevId: 354617532
Change-Id: I0726385a558bf7c7d77766fce538702223b936ac
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
index 3e03f92..957f7fc 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/BUILD
@@ -136,6 +136,7 @@
         "//tensorflow/lite/delegates/gpu/common:shape",
         "//tensorflow/lite/delegates/gpu/common:types",
         "//tensorflow/lite/delegates/gpu/common:util",
+        "//tensorflow/lite/delegates/gpu/common/task:gpu_operation",
         "//tensorflow/lite/delegates/gpu/metal:compute_task_descriptor",
         "@com_google_absl//absl/strings",
     ],
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
index 4dea617..3e421cf 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.cc
@@ -363,7 +363,11 @@
 
 }  // namespace
 
-ComputeTaskDescriptor DepthWiseConvolution(
+int3 DepthWiseConvolution::GetGridSize() const {
+  return int3(dst_[0]->Width(), dst_[0]->Height(), dst_[0]->Slices());
+}
+
+DepthWiseConvolution CreateDepthWiseConvolution(
     const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr) {
   int channels_multiplier = attr.weights.shape.o;
@@ -429,20 +433,20 @@
   args.dst_tensor.Write(res, dst_x, dst_y, dst_z);
 }
 )";
-  ComputeTaskDescriptor desc(definition);
-  desc.shader_source = shader_source;
+  DepthWiseConvolution desc(definition);
+  desc.code_ = shader_source;
   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
 
-  desc.args.AddInt("padding_x", -attr.padding.prepended.w);
-  desc.args.AddInt("padding_y", -attr.padding.prepended.h);
-  desc.args.AddInt("dilation_x", attr.dilations.w);
-  desc.args.AddInt("dilation_y", attr.dilations.h);
-  desc.args.AddInt("stride_x", attr.strides.w);
-  desc.args.AddInt("stride_y", attr.strides.h);
-  desc.args.AddInt("kernel_size_x", attr.weights.shape.w);
-  desc.args.AddInt("kernel_size_y", attr.weights.shape.h);
-  desc.args.AddInt("channel_multiplier", attr.weights.shape.o);
+  desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  desc.args_.AddInt("dilation_x", attr.dilations.w);
+  desc.args_.AddInt("dilation_y", attr.dilations.h);
+  desc.args_.AddInt("stride_x", attr.strides.w);
+  desc.args_.AddInt("stride_y", attr.strides.h);
+  desc.args_.AddInt("kernel_size_x", attr.weights.shape.w);
+  desc.args_.AddInt("kernel_size_y", attr.weights.shape.h);
+  desc.args_.AddInt("channel_multiplier", attr.weights.shape.o);
 
   auto data_type = DeduceDataTypeFromPrecision(definition.precision);
   const int output_channels_count = attr.weights.shape.i * attr.weights.shape.o;
@@ -453,7 +457,7 @@
   weights_desc.data =
       GetByteBufferConverted(ConvertToPIOHW4(attr.weights), data_type);
   weights_desc.size = weights_desc.data.size();
-  desc.args.AddObject(
+  desc.args_.AddObject(
       "weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
 
   BufferDescriptor bias_desc;
@@ -462,31 +466,43 @@
   bias_desc.data =
       GetByteBufferConvertedResized(attr.bias.data, data_type, dst_ch_aligned);
   bias_desc.size = bias_desc.data.size();
-  desc.args.AddObject(
+  desc.args_.AddObject(
       "biases", absl::make_unique<BufferDescriptor>(std::move(bias_desc)));
 
-  desc.resize_function = [](const std::vector<BHWC>& src_shapes,
-                            const std::vector<BHWC>& dst_shapes) {
-    uint3 groups_size{8, 4, 1};
-    uint3 groups_count{DivideRoundUp(dst_shapes[0].w, groups_size.x),
-                       DivideRoundUp(dst_shapes[0].h, groups_size.y),
-                       DivideRoundUp(dst_shapes[0].c, 4)};
-    return std::make_pair(groups_size, groups_count);
-  };
+  desc.work_group_size_ = int3(8, 4, 1);
 
   return desc;
 }
 
-ComputeTaskDescriptor DepthWiseConv3x3Stride1x1(
+void DepthWiseConv3x3Stride1x1::GetPossibleKernelWorkGroups(
+    TuningType tuning_type, const GpuInfo& gpu_info,
+    const KernelInfo& kernel_info, std::vector<int3>* work_groups) const {
+  const int grid_x = DivideRoundUp(dst_[0]->Width(), 2);
+  const int grid_z = dst_[0]->Slices();
+  int3 group_size{8, 4, 1};
+  if (grid_x <= 4) {
+    group_size.x = 4;
+    group_size.z = grid_z % 2 == 0 ? 2 : 1;
+  }
+  work_groups->push_back(group_size);
+}
+int3 DepthWiseConv3x3Stride1x1::GetGridSize() const {
+  const int grid_x = DivideRoundUp(dst_[0]->Width(), 2);
+  const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
+  const int grid_z = dst_[0]->Slices();
+  return int3(grid_x, grid_y, grid_z);
+}
+
+DepthWiseConv3x3Stride1x1 CreateDepthWiseConv3x3Stride1x1(
     const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr) {
-  ComputeTaskDescriptor desc(definition);
-  desc.shader_source = GetKernelDepthWiseConv3x3Stride1x1();
+  DepthWiseConv3x3Stride1x1 desc(definition);
+  desc.code_ = GetKernelDepthWiseConv3x3Stride1x1();
   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
 
-  desc.args.AddInt("padding_x", -attr.padding.prepended.w);
-  desc.args.AddInt("padding_y", -attr.padding.prepended.h);
+  desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
 
   // For this operation we keep weights and biases in one buffer
   auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride1x1(attr);
@@ -496,25 +512,9 @@
   weights_desc.element_size = 4;
   weights_desc.data = GetByteBufferConverted(weights_reordered, data_type);
   weights_desc.size = weights_desc.data.size();
-  desc.args.AddObject(
+  desc.args_.AddObject(
       "weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
 
-  desc.resize_function = [](const std::vector<BHWC>& src_shapes,
-                            const std::vector<BHWC>& dst_shapes) {
-    const int grid_x = DivideRoundUp(dst_shapes[0].w, 2);
-    const int grid_y = DivideRoundUp(dst_shapes[0].h, 2);
-    const int grid_z = DivideRoundUp(dst_shapes[0].c, 4);
-    uint3 group_size{8, 4, 1};
-    if (grid_x <= 4) {
-      group_size.x = 4;
-      group_size.z = grid_z % 2 == 0 ? 2 : 1;
-    }
-    const int groups_x = DivideRoundUp(grid_x, group_size.x);
-    const int groups_y = DivideRoundUp(grid_y, group_size.y);
-    const int groups_z = DivideRoundUp(grid_z, group_size.z);
-    return std::make_pair(group_size, uint3(groups_x, groups_y, groups_z));
-  };
-
   return desc;
 }
 
@@ -525,18 +525,25 @@
          attr.strides.w == 1 && attr.dilations.h == 1 && attr.dilations.w == 1;
 }
 
-ComputeTaskDescriptor DepthWiseConv3x3Stride2(
+int3 DepthWiseConv3x3Stride2::GetGridSize() const {
+  const int grid_x = dst_[0]->Width();
+  const int grid_y = DivideRoundUp(dst_[0]->Height(), 2);
+  const int grid_z = dst_[0]->Slices();
+  return int3(grid_x, grid_y, grid_z);
+}
+
+DepthWiseConv3x3Stride2 CreateDepthWiseConv3x3Stride2(
     const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr) {
-  ComputeTaskDescriptor desc(definition);
-  desc.shader_source = GetKernelDepthWiseConv3x3Stride2();
+  DepthWiseConv3x3Stride2 desc(definition);
+  desc.code_ = GetKernelDepthWiseConv3x3Stride2();
   desc.AddSrcTensor("src_tensor", definition.src_tensors[0]);
   desc.AddDstTensor("dst_tensor", definition.dst_tensors[0]);
 
-  desc.args.AddInt("padding_x", -attr.padding.prepended.w);
-  desc.args.AddInt("padding_y", -attr.padding.prepended.h);
-  desc.args.AddInt("stride_x", attr.strides.w);
-  desc.args.AddInt("dilation_x", attr.dilations.w);
+  desc.args_.AddInt("padding_x", -attr.padding.prepended.w);
+  desc.args_.AddInt("padding_y", -attr.padding.prepended.h);
+  desc.args_.AddInt("stride_x", attr.strides.w);
+  desc.args_.AddInt("dilation_x", attr.dilations.w);
 
   // For this operation we keep weights and biases in one buffer
   auto weights_reordered = ReorderWeightsDepthWiseConv3x3Stride2(attr);
@@ -546,21 +553,10 @@
   weights_desc.element_size = 4;
   weights_desc.data = GetByteBufferConverted(weights_reordered, data_type);
   weights_desc.size = weights_desc.data.size();
-  desc.args.AddObject(
+  desc.args_.AddObject(
       "weights", absl::make_unique<BufferDescriptor>(std::move(weights_desc)));
 
-  desc.resize_function = [](const std::vector<BHWC>& src_shapes,
-                            const std::vector<BHWC>& dst_shapes) {
-    const int grid_x = dst_shapes[0].w;
-    const int grid_y = DivideRoundUp(dst_shapes[0].h, 2);
-    const int grid_z = DivideRoundUp(dst_shapes[0].c, 4);
-    const uint3 group_size{8, 4, 1};
-    const int groups_x = DivideRoundUp(grid_x, group_size.x);
-    const int groups_y = DivideRoundUp(grid_y, group_size.y);
-    const int groups_z = DivideRoundUp(grid_z, group_size.z);
-    return std::make_pair(group_size, uint3(groups_x, groups_y, groups_z));
-  };
-
+  desc.work_group_size_ = int3(8, 4, 1);
   return desc;
 }
 
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h
index 25490ea..057c5c9 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h
@@ -20,13 +20,38 @@
 
 #include "tensorflow/lite/delegates/gpu/common/model.h"
 #include "tensorflow/lite/delegates/gpu/common/operations.h"
-#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/common/task/gpu_operation.h"
 
 namespace tflite {
 namespace gpu {
 namespace metal {
 
-ComputeTaskDescriptor DepthWiseConvolution(
+class DepthWiseConvolution : public GPUOperation {
+ public:
+  DepthWiseConvolution() = default;
+  void GetPossibleKernelWorkGroups(
+      TuningType tuning_type, const GpuInfo& gpu_info,
+      const KernelInfo& kernel_info,
+      std::vector<int3>* work_groups) const override {
+    work_groups->push_back(work_group_size_);
+  }
+  int3 GetGridSize() const override;
+
+  // Move only
+  DepthWiseConvolution(DepthWiseConvolution&& kernel) = default;
+  DepthWiseConvolution& operator=(DepthWiseConvolution&& kernel) = default;
+  DepthWiseConvolution(const DepthWiseConvolution&) = delete;
+  DepthWiseConvolution& operator=(const DepthWiseConvolution&) = delete;
+
+ private:
+  explicit DepthWiseConvolution(const OperationDef& definition)
+      : GPUOperation(definition) {}
+  friend DepthWiseConvolution CreateDepthWiseConvolution(
+      const OperationDef& definition,
+      const DepthwiseConvolution2DAttributes& attr);
+};
+
+DepthWiseConvolution CreateDepthWiseConvolution(
     const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr);
 
@@ -36,7 +61,32 @@
 //   kernel_size = 3x3;
 //   dilation = 1x1;
 //   stride = 1x1;
-ComputeTaskDescriptor DepthWiseConv3x3Stride1x1(
+class DepthWiseConv3x3Stride1x1 : public GPUOperation {
+ public:
+  DepthWiseConv3x3Stride1x1() = default;
+  void GetPossibleKernelWorkGroups(
+      TuningType tuning_type, const GpuInfo& gpu_info,
+      const KernelInfo& kernel_info,
+      std::vector<int3>* work_groups) const override;
+  int3 GetGridSize() const override;
+
+  // Move only
+  DepthWiseConv3x3Stride1x1(DepthWiseConv3x3Stride1x1&& kernel) = default;
+  DepthWiseConv3x3Stride1x1& operator=(DepthWiseConv3x3Stride1x1&& kernel) =
+      default;
+  DepthWiseConv3x3Stride1x1(const DepthWiseConv3x3Stride1x1&) = delete;
+  DepthWiseConv3x3Stride1x1& operator=(const DepthWiseConv3x3Stride1x1&) =
+      delete;
+
+ private:
+  explicit DepthWiseConv3x3Stride1x1(const OperationDef& definition)
+      : GPUOperation(definition) {}
+  friend DepthWiseConv3x3Stride1x1 CreateDepthWiseConv3x3Stride1x1(
+      const OperationDef& definition,
+      const DepthwiseConvolution2DAttributes& attr);
+};
+
+DepthWiseConv3x3Stride1x1 CreateDepthWiseConv3x3Stride1x1(
     const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr);
 
@@ -50,7 +100,33 @@
 //   kernel_size = 3x3;
 //   dilation.y = 1;
 //   stride.y = 2;
-ComputeTaskDescriptor DepthWiseConv3x3Stride2(
+class DepthWiseConv3x3Stride2 : public GPUOperation {
+ public:
+  DepthWiseConv3x3Stride2() = default;
+  void GetPossibleKernelWorkGroups(
+      TuningType tuning_type, const GpuInfo& gpu_info,
+      const KernelInfo& kernel_info,
+      std::vector<int3>* work_groups) const override {
+    work_groups->push_back(work_group_size_);
+  }
+  int3 GetGridSize() const override;
+
+  // Move only
+  DepthWiseConv3x3Stride2(DepthWiseConv3x3Stride2&& kernel) = default;
+  DepthWiseConv3x3Stride2& operator=(DepthWiseConv3x3Stride2&& kernel) =
+      default;
+  DepthWiseConv3x3Stride2(const DepthWiseConv3x3Stride2&) = delete;
+  DepthWiseConv3x3Stride2& operator=(const DepthWiseConv3x3Stride2&) = delete;
+
+ private:
+  explicit DepthWiseConv3x3Stride2(const OperationDef& definition)
+      : GPUOperation(definition) {}
+  friend DepthWiseConv3x3Stride2 CreateDepthWiseConv3x3Stride2(
+      const OperationDef& definition,
+      const DepthwiseConvolution2DAttributes& attr);
+};
+
+DepthWiseConv3x3Stride2 CreateDepthWiseConv3x3Stride2(
     const OperationDef& definition,
     const DepthwiseConvolution2DAttributes& attr);
 
diff --git a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
index 763e6a9..bc543c3 100644
--- a/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
+++ b/tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv_test.mm
@@ -26,181 +26,181 @@
 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
 #include "tensorflow/lite/delegates/gpu/common/util.h"
 #include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
+#include "tensorflow/lite/delegates/gpu/metal/kernels/depthwise_conv.h"
 #include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
 
-using ::tflite::gpu::Axis;
-using ::tflite::gpu::BHWC;
-using ::tflite::gpu::DataType;
-using ::tflite::gpu::DepthwiseConvolution2DAttributes;
-using ::tflite::gpu::HW;
-using ::tflite::gpu::Linear;
-using ::tflite::gpu::OHWI;
-using ::tflite::gpu::OperationType;
-using ::tflite::gpu::Tensor;
-using ::tflite::gpu::TensorRef;
-using ::tflite::gpu::metal::CompareVectors;
-using ::tflite::gpu::metal::SingleOpModel;
-
-@interface DepthwiseConvTest : XCTestCase
+@interface DepthwiseConvMetalTest : XCTestCase
 @end
 
-@implementation DepthwiseConvTest {
+@implementation DepthwiseConvMetalTest {
   tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
 }
 
-- (void)testO4H1W1I2Strides1x1Dilation1x1 {
-  TensorRef<BHWC> input;
-  input.type = DataType::FLOAT32;
-  input.ref = 0;
-  input.shape = BHWC(1, 1, 1, 2);
+namespace tflite {
+namespace gpu {
+namespace metal {
+
+absl::Status DepthWiseO4H1W1I2Strides1x1Dilation1x1Test(TestExecutionEnvironment* env) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 1, 1, 2);
+  src_tensor.data = {1, 3};
 
   DepthwiseConvolution2DAttributes attr;
-  Tensor<Linear, DataType::FLOAT32> bias;
-  bias.shape.v = 4;
-  bias.id = 1;
-  bias.data = {1, 2, 3, 4};
-  attr.bias = std::move(bias);
-
-  Tensor<OHWI, DataType::FLOAT32> weights;
-  weights.shape = OHWI(2, 1, 1, 2);
-  weights.id = 2;
-  weights.data = {1, 3, 2, 4};
-
-  attr.weights = std::move(weights);
-
+  attr.weights.shape = OHWI(2, 1, 1, 2);
+  attr.weights.data = {1, 3, 2, 4};
+  attr.bias.shape = Linear(4);
+  attr.bias.data = {1, 2, 3, 4};
   attr.dilations = HW(1, 1);
   attr.padding.prepended = HW(0, 0);
   attr.padding.appended = HW(0, 0);
   attr.strides = HW(1, 1);
 
-  TensorRef<BHWC> output;
-  output.type = DataType::FLOAT32;
-  output.ref = 3;
-  output.shape = BHWC(1, 1, 1, 4);
-
-  SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input},
-                      {output});
-  XCTAssertTrue(model.PopulateTensor(0, {1, 3}));
-  auto status = model.Invoke();
-  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
-  status = CompareVectors({2, 4, 12, 16}, model.GetOutput(0), 1e-6f);
-  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  for (auto storage : env->GetSupportedStorages()) {
+    for (auto precision : env->GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      DepthWiseConvolution operation = CreateDepthWiseConvolution(op_def, attr);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, absl::make_unique<DepthWiseConvolution>(std::move(operation)),
+          BHWC(1, 1, 1, 4), &dst_tensor));
+      RETURN_IF_ERROR(PointWiseNear({2, 4, 12, 16}, dst_tensor.data, eps))
+          << "Failed using precision " << ToString(precision);
+    }
+  }
+  return absl::OkStatus();
 }
 
-- (void)testO2H1W1I1Strides2x2Dilation1x1 {
-  TensorRef<BHWC> input;
-  input.type = DataType::FLOAT32;
-  input.ref = 0;
-  input.shape = BHWC(1, 3, 3, 1);
+absl::Status DepthWiseO2H1W1I1Strides2x2Dilation1x1Test(TestExecutionEnvironment* env) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 3, 3, 1);
+  src_tensor.data = {1, 0, 1, 1, 0, 1, 1, 0, 1};
 
   DepthwiseConvolution2DAttributes attr;
-  Tensor<Linear, DataType::FLOAT32> bias;
-  bias.shape.v = 4;
-  bias.id = 1;
-  bias.data = {0, 0};
-  attr.bias = std::move(bias);
-
-  Tensor<OHWI, DataType::FLOAT32> weights;
-  weights.shape = OHWI(2, 1, 1, 1);
-  weights.id = 1;
-  weights.data = {1, 3};
-
-  attr.weights = std::move(weights);
-
+  attr.weights.shape = OHWI(2, 1, 1, 1);
+  attr.weights.data = {1, 3};
+  attr.bias.shape = Linear(2);
+  attr.bias.data = {0.0f, 0.0f};
   attr.dilations = HW(1, 1);
   attr.padding.prepended = HW(0, 0);
   attr.padding.appended = HW(0, 0);
   attr.strides = HW(2, 2);
 
-  TensorRef<BHWC> output;
-  output.type = DataType::FLOAT32;
-  output.ref = 3;
-  output.shape = BHWC(1, 2, 2, 2);
-
-  SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input},
-                      {output});
-  XCTAssertTrue(model.PopulateTensor(0, {1, 0, 1, 1, 0, 1, 1, 0, 1}));
-  auto status = model.Invoke();
-  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
-  status = CompareVectors({1, 3, 1, 3, 1, 3, 1, 3}, model.GetOutput(0), 1e-6f);
-  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  for (auto storage : env->GetSupportedStorages()) {
+    for (auto precision : env->GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      DepthWiseConvolution operation = CreateDepthWiseConvolution(op_def, attr);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, absl::make_unique<DepthWiseConvolution>(std::move(operation)),
+          BHWC(1, 2, 2, 2), &dst_tensor));
+      RETURN_IF_ERROR(PointWiseNear({1, 3, 1, 3, 1, 3, 1, 3}, dst_tensor.data, eps))
+          << "Failed using precision " << ToString(precision);
+    }
+  }
+  return absl::OkStatus();
 }
 
-- (void)testO2H2W2I1Strides1x1Dilation2x2 {
-  TensorRef<BHWC> input;
-  input.type = DataType::FLOAT32;
-  input.ref = 0;
-  input.shape = BHWC(1, 3, 3, 1);
+absl::Status DepthWiseO2H2W2I1Strides1x1Dilation2x2Test(TestExecutionEnvironment* env) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 3, 3, 1);
+  src_tensor.data = {1, 0, 1, 1, 0, 1, 1, 0, 1};
 
   DepthwiseConvolution2DAttributes attr;
-  Tensor<Linear, DataType::FLOAT32> bias;
-  bias.shape.v = 4;
-  bias.id = 1;
-  bias.data = {0, 0};
-  attr.bias = std::move(bias);
-
-  Tensor<OHWI, DataType::FLOAT32> weights;
-  weights.shape = OHWI(2, 2, 2, 1);
-  weights.id = 1;
-  weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
-
-  attr.weights = std::move(weights);
-
+  attr.weights.shape = OHWI(2, 2, 2, 1);
+  attr.weights.data = {1, 2, 3, 4, 5, 6, 7, 8};
+  attr.bias.shape = Linear(2);
+  attr.bias.data = {0.0f, 0.0f};
   attr.dilations = HW(2, 2);
   attr.padding.prepended = HW(0, 0);
   attr.padding.appended = HW(0, 0);
   attr.strides = HW(1, 1);
 
-  TensorRef<BHWC> output;
-  output.type = DataType::FLOAT32;
-  output.ref = 3;
-  output.shape = BHWC(1, 1, 1, 2);
-
-  SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input},
-                      {output});
-  XCTAssertTrue(model.PopulateTensor(0, {1, 0, 1, 1, 0, 1, 1, 0, 1}));
-  auto status = model.Invoke();
-  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
-  status = CompareVectors({10, 26}, model.GetOutput(0), 1e-6f);
-  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+  for (auto storage : env->GetSupportedStorages()) {
+    for (auto precision : env->GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      DepthWiseConvolution operation = CreateDepthWiseConvolution(op_def, attr);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, absl::make_unique<DepthWiseConvolution>(std::move(operation)),
+          BHWC(1, 1, 1, 2), &dst_tensor));
+      RETURN_IF_ERROR(PointWiseNear({10, 26}, dst_tensor.data, eps))
+          << "Failed using precision " << ToString(precision);
+    }
+  }
+  return absl::OkStatus();
 }
 
-- (void)testShape2x2Kernel2x2 {
-  TensorRef<BHWC> input;
-  input.type = DataType::FLOAT32;
-  input.ref = 0;
-  input.shape = BHWC(1, 2, 2, 1);
+absl::Status DepthWiseShape2x2Kernel2x2Test(TestExecutionEnvironment* env) {
+  TensorFloat32 src_tensor;
+  src_tensor.shape = BHWC(1, 2, 2, 1);
+  src_tensor.data = {1, 4, 9, 16};
 
   DepthwiseConvolution2DAttributes attr;
-  Tensor<Linear, DataType::FLOAT32> bias;
-  bias.shape.v = 1;
-  bias.id = 1;
-  bias.data = {0};
-  attr.bias = std::move(bias);
-
-  Tensor<OHWI, DataType::FLOAT32> weights;
-  weights.shape = OHWI(1, 2, 2, 1);
-  weights.id = 1;
-  weights.data = {1, 2, 3, 4};
-
-  attr.weights = std::move(weights);
-
+  attr.weights.shape = OHWI(1, 2, 2, 1);
+  attr.weights.data = {1, 2, 3, 4};
+  attr.bias.shape = Linear(1);
+  attr.bias.data = {0.0f};
   attr.dilations = HW(1, 1);
   attr.padding.prepended = HW(0, 0);
   attr.padding.appended = HW(1, 1);
   attr.strides = HW(1, 1);
 
-  TensorRef<BHWC> output;
-  output.type = DataType::FLOAT32;
-  output.ref = 3;
-  output.shape = BHWC(1, 2, 2, 1);
+  for (auto storage : env->GetSupportedStorages()) {
+    for (auto precision : env->GetSupportedPrecisions()) {
+      const float eps = precision == CalculationsPrecision::F32 ? 1e-6f : 1e-3f;
+      OperationDef op_def;
+      op_def.precision = precision;
+      auto data_type = DeduceDataTypeFromPrecision(precision);
+      op_def.src_tensors.push_back({data_type, storage, Layout::HWC});
+      op_def.dst_tensors.push_back({data_type, storage, Layout::HWC});
+      TensorFloat32 dst_tensor;
+      DepthWiseConvolution operation = CreateDepthWiseConvolution(op_def, attr);
+      RETURN_IF_ERROR(env->ExecuteGPUOperation(
+          src_tensor, absl::make_unique<DepthWiseConvolution>(std::move(operation)),
+          BHWC(1, 2, 2, 1), &dst_tensor));
+      RETURN_IF_ERROR(PointWiseNear({100, 52, 41, 16}, dst_tensor.data, eps))
+          << "Failed using precision " << ToString(precision);
+    }
+  }
+  return absl::OkStatus();
+}
 
-  SingleOpModel model({ToString(OperationType::DEPTHWISE_CONVOLUTION), std::move(attr)}, {input},
-                      {output});
-  XCTAssertTrue(model.PopulateTensor(0, {1, 4, 9, 16}));
-  auto status = model.Invoke();
+}  // namespace metal
+}  // namespace gpu
+}  // namespace tflite
+
+- (void)testO4H1W1I2Strides1x1Dilation1x1 {
+  auto status = tflite::gpu::metal::DepthWiseO4H1W1I2Strides1x1Dilation1x1Test(&exec_env_);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
-  status = CompareVectors({100, 52, 41, 16}, model.GetOutput(0), 1e-6f);
+}
+
+- (void)testO2H1W1I1Strides2x2Dilation1x1 {
+  auto status = tflite::gpu::metal::DepthWiseO2H1W1I1Strides2x2Dilation1x1Test(&exec_env_);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testO2H2W2I1Strides1x1Dilation2x2 {
+  auto status = tflite::gpu::metal::DepthWiseO2H2W2I1Strides1x1Dilation2x2Test(&exec_env_);
+  XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
+}
+
+- (void)testShape2x2Kernel2x2 {
+  auto status = tflite::gpu::metal::DepthWiseShape2x2Kernel2x2Test(&exec_env_);
   XCTAssertTrue(status.ok(), @"%s", std::string(status.message()).c_str());
 }
 
diff --git a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
index 2b3002e..06d78be 100644
--- a/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
+++ b/tensorflow/lite/delegates/gpu/metal/selectors/operation_selector.cc
@@ -60,17 +60,17 @@
 namespace metal {
 namespace {
 
-std::unique_ptr<ComputeTaskDescriptor> SelectDepthWiseConv(
+std::unique_ptr<GPUOperation> SelectDepthWiseConv(
     const OperationDef& op_def, const DepthwiseConvolution2DAttributes& attr) {
   if (CheckDepthWiseConv3x3Stride1x1Support(attr)) {
-    auto gpu_op = DepthWiseConv3x3Stride1x1(op_def, attr);
-    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+    auto gpu_op = CreateDepthWiseConv3x3Stride1x1(op_def, attr);
+    return absl::make_unique<DepthWiseConv3x3Stride1x1>(std::move(gpu_op));
   } else if (CheckDepthWiseConv3x3Stride2Support(attr)) {
-    auto gpu_op = DepthWiseConv3x3Stride2(op_def, attr);
-    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+    auto gpu_op = CreateDepthWiseConv3x3Stride2(op_def, attr);
+    return absl::make_unique<DepthWiseConv3x3Stride2>(std::move(gpu_op));
   } else {
-    auto gpu_op = DepthWiseConvolution(op_def, attr);
-    return absl::make_unique<ComputeTaskDescriptor>(std::move(gpu_op));
+    auto gpu_op = CreateDepthWiseConvolution(op_def, attr);
+    return absl::make_unique<DepthWiseConvolution>(std::move(gpu_op));
   }
 }
 
@@ -387,7 +387,7 @@
             "DepthWise Convolution does not support more than 1 runtime "
             "tensor");
       }
-      gpu_operation->task_desc = SelectDepthWiseConv(
+      gpu_operation->operation = SelectDepthWiseConv(
           op_def, absl::any_cast<DepthwiseConvolution2DAttributes>(
                       node.operation.attributes));
       break;