Moved common kernel code processing from gpu_operation into cl/metal arguments Init.

PiperOrigin-RevId: 411747825
Change-Id: I02cd5bec64929323069d924b26b710c295b78f9a
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
index db0bddc..2d5554e 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
@@ -144,8 +144,13 @@
 // Static
 constexpr char CLArguments::kArgsPrefix[];
 
-absl::Status CLArguments::Init(const GpuInfo& gpu_info, CLContext* context,
-                               Arguments* args, std::string* code) {
+absl::Status CLArguments::Init(
+    const GpuInfo& gpu_info,
+    const std::map<std::string, std::string>& linkables, CLContext* context,
+    Arguments* args, std::string* code) {
+  RETURN_IF_ERROR(args->AddObjectsScalarArgs(gpu_info));
+  RETURN_IF_ERROR(args->ResolveSelectorsPass(gpu_info, linkables, code));
+  args->GetActiveArguments(*code);
   RETURN_IF_ERROR(AllocateObjects(*args, context));
   RETURN_IF_ERROR(AddObjectArgs(gpu_info, *args));
   object_refs_ = std::move(args->object_refs_);
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
index 5832cb6..aad4579 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
@@ -35,6 +35,7 @@
   CLArguments() = default;
 
   absl::Status Init(const GpuInfo& gpu_info,
+                    const std::map<std::string, std::string>& linkables,
                     CLContext* context, Arguments* args, std::string* code);
   absl::Status Init(const GpuInfo& gpu_info, Arguments* args,
                     CLContext* context);
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
index c11b999..7b2537c 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
@@ -181,12 +181,13 @@
 }
 
 absl::Status ClOperation::Compile(const CreationContext& creation_context) {
-  RETURN_IF_ERROR(operation_->AssembleCode(creation_context.GetGpuInfo()));
+  operation_->AssembleCode(creation_context.GetGpuInfo());
   operation_->code_ =
       GetCommonOpenCLDefines(operation_->definition_.precision) +
       operation_->code_;
   RETURN_IF_ERROR(cl_args_.Init(
       creation_context.GetGpuInfo(),
+      {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
       creation_context.context, &operation_->args_, &operation_->code_));
   RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
       operation_->code_, "main_function", operation_->compiler_options_,
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
index 58a1b38..e5a361e 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
@@ -162,9 +162,7 @@
     context_ = &environment->context();
     shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h,
                   input_def.dimensions.w, input_def.dimensions.c);
-    RETURN_IF_ERROR(
-        args.Compile(environment->device().GetInfo(), {}, &shader_src));
-    RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), nullptr,
+    RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), {}, nullptr,
                                   &args, &shader_src));
     return environment->program_cache()->GetOrCreateCLKernel(
         shader_src, "tensor_to_tensor", environment->context(),
@@ -257,9 +255,7 @@
     context_ = &environment->context();
     shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h,
                   input_def.dimensions.w, input_def.dimensions.c);
-    RETURN_IF_ERROR(
-        args.Compile(environment->device().GetInfo(), {}, &shader_src));
-    RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), nullptr,
+    RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), {}, nullptr,
                                   &args, &shader_src));
     return environment->program_cache()->GetOrCreateCLKernel(
         shader_src, "tensor_to_bhwc", environment->context(),
@@ -354,9 +350,7 @@
     context_ = &environment->context();
     shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h,
                   output_def.dimensions.w, output_def.dimensions.c);
-    RETURN_IF_ERROR(
-        args.Compile(environment->device().GetInfo(), {}, &shader_src));
-    RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), nullptr,
+    RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), {}, nullptr,
                                   &args, &shader_src));
     return environment->program_cache()->GetOrCreateCLKernel(
         shader_src, "bhwc_to_tensor", environment->context(),
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.cc b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
index 281be99..60f839a 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
@@ -286,15 +286,6 @@
   }
 }
 
-absl::Status Arguments::Compile(
-    const GpuInfo& gpu_info,
-    const std::map<std::string, std::string>& linkables, std::string* code) {
-  RETURN_IF_ERROR(AddObjectsScalarArgs(gpu_info));
-  RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, linkables, code));
-  GetActiveArguments(*code);
-  return absl::OkStatus();
-}
-
 absl::Status Arguments::ResolveSelectorsPass(
     const GpuInfo& gpu_info,
     const std::map<std::string, std::string>& linkables,
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.h b/tensorflow/lite/delegates/gpu/common/task/arguments.h
index df87322..b706ab9 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.h
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.h
@@ -77,6 +77,8 @@
 
   void ReleaseCPURepresentation();
 
+  void GetActiveArguments(const std::string& code);
+
   void SetStateValueForAllObjects(const std::string& key,
                                   const std::string& value);
 
@@ -122,18 +124,6 @@
     *result = std::move(object_refs_);
   }
 
-  absl::Status Compile(const GpuInfo& gpu_info,
-                       const std::map<std::string, std::string>& linkables,
-                       std::string* code);
-
- private:
-  friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
-      const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
-  friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
-                             Arguments* args);
-
-  void GetActiveArguments(const std::string& code);
-
   absl::Status ResolveSelectorsPass(
       const GpuInfo& gpu_info,
       const std::map<std::string, std::string>& linkables,
@@ -152,6 +142,12 @@
   absl::Status AddObjectsScalarArgs(const GpuInfo& gpu_info);
   void ResolveArgsPass(std::string* code) const;
 
+ private:
+  friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
+      const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
+  friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
+                             Arguments* args);
+
   friend class cl::CLArguments;
   friend class metal::MetalArguments;
 
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
index 68bbb2e..9400269 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
@@ -183,7 +183,7 @@
   args_.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
 }
 
-absl::Status GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
+void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
   if (elementwise_) {
     auto src_desc =
         absl::make_unique<TensorDescriptor>(definition_.src_tensors[0]);
@@ -204,9 +204,6 @@
     elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
     code_ = GetElementWiseCode(definition_, check_src_channels_size_);
   }
-  RETURN_IF_ERROR(args_.Compile(
-      gpu_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_));
-  return absl::OkStatus();
 }
 
 void GPUOperation::GetPossibleKernelWorkGroups(
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
index 9018967..9196f72 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
@@ -120,7 +120,7 @@
       TuningType tuning_type, const GpuInfo& gpu_info,
       const KernelInfo& kernel_info, std::vector<int3>* work_groups) const;
 
-  absl::Status AssembleCode(const GpuInfo& gpu_info);
+  void AssembleCode(const GpuInfo& gpu_info);
 
   virtual absl::Status PostCompileCheck(const GpuInfo& gpu_info,
                                         const KernelInfo& kernel_info) {
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
index f67e500..714a0ef 100644
--- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc
+++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
@@ -118,8 +118,10 @@
 }
 
 absl::Status ComputeTask::Compile(MetalDevice* device) {
-  RETURN_IF_ERROR(operation_->AssembleCode(device->GetInfo()));
-  RETURN_IF_ERROR(metal_args_.Init(use_arguments_buffer_, device,
+  operation_->AssembleCode(device->GetInfo());
+  const std::map<std::string, std::string> linkables = {
+      {operation_->dst_tensors_names_[0], operation_->elementwise_code_}};
+  RETURN_IF_ERROR(metal_args_.Init(linkables, use_arguments_buffer_, device,
                                    &operation_->args_, &operation_->code_));
 
   operation_->args_.ReleaseCPURepresentation();
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
index 94be450..a8245ab 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
@@ -114,8 +114,13 @@
 constexpr char MetalArguments::kArgsPrefix[];
 
 absl::Status MetalArguments::Init(
+    const std::map<std::string, std::string>& linkables,
     bool use_arguments_buffer, MetalDevice* device, Arguments* args,
     std::string* code) {
+  RETURN_IF_ERROR(args->AddObjectsScalarArgs(device->GetInfo()));
+  RETURN_IF_ERROR(
+      args->ResolveSelectorsPass(device->GetInfo(), linkables, code));
+  args->GetActiveArguments(*code);
   RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
   RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), *args));
   object_refs_ = std::move(args->object_refs_);
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
index 849de33..d6afa88 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
@@ -35,7 +35,8 @@
  public:
   MetalArguments() = default;
 
-  absl::Status Init(bool use_arguments_buffer, MetalDevice* device,
+  absl::Status Init(const std::map<std::string, std::string>& linkables,
+                    bool use_arguments_buffer, MetalDevice* device,
                     Arguments* args, std::string* code);
 
   // Move only