Moved common kernel code processing from cl/metal arguments to gpu_operation.
PiperOrigin-RevId: 411747825
Change-Id: I02cd5bec64929323069d924b26b710c295b78f9a
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
index db0bddc..2d5554e 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.cc
@@ -144,8 +144,13 @@
// Static
constexpr char CLArguments::kArgsPrefix[];
-absl::Status CLArguments::Init(const GpuInfo& gpu_info, CLContext* context,
- Arguments* args, std::string* code) {
+absl::Status CLArguments::Init(
+ const GpuInfo& gpu_info,
+ const std::map<std::string, std::string>& linkables, CLContext* context,
+ Arguments* args, std::string* code) {
+ RETURN_IF_ERROR(args->AddObjectsScalarArgs(gpu_info));
+ RETURN_IF_ERROR(args->ResolveSelectorsPass(gpu_info, linkables, code));
+ args->GetActiveArguments(*code);
RETURN_IF_ERROR(AllocateObjects(*args, context));
RETURN_IF_ERROR(AddObjectArgs(gpu_info, *args));
object_refs_ = std::move(args->object_refs_);
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
index 5832cb6..aad4579 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
+++ b/tensorflow/lite/delegates/gpu/cl/cl_arguments.h
@@ -35,6 +35,7 @@
CLArguments() = default;
absl::Status Init(const GpuInfo& gpu_info,
+ const std::map<std::string, std::string>& linkables,
CLContext* context, Arguments* args, std::string* code);
absl::Status Init(const GpuInfo& gpu_info, Arguments* args,
CLContext* context);
diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
index c11b999..7b2537c 100644
--- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
+++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc
@@ -181,12 +181,13 @@
}
absl::Status ClOperation::Compile(const CreationContext& creation_context) {
- RETURN_IF_ERROR(operation_->AssembleCode(creation_context.GetGpuInfo()));
+ operation_->AssembleCode(creation_context.GetGpuInfo());
operation_->code_ =
GetCommonOpenCLDefines(operation_->definition_.precision) +
operation_->code_;
RETURN_IF_ERROR(cl_args_.Init(
creation_context.GetGpuInfo(),
+ {{operation_->dst_tensors_names_[0], operation_->elementwise_code_}},
creation_context.context, &operation_->args_, &operation_->code_));
RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel(
operation_->code_, "main_function", operation_->compiler_options_,
diff --git a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
index 58a1b38..e5a361e 100644
--- a/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
+++ b/tensorflow/lite/delegates/gpu/cl/kernels/converter.cc
@@ -162,9 +162,7 @@
context_ = &environment->context();
shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h,
input_def.dimensions.w, input_def.dimensions.c);
- RETURN_IF_ERROR(
- args.Compile(environment->device().GetInfo(), {}, &shader_src));
- RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), nullptr,
+ RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), {}, nullptr,
&args, &shader_src));
return environment->program_cache()->GetOrCreateCLKernel(
shader_src, "tensor_to_tensor", environment->context(),
@@ -257,9 +255,7 @@
context_ = &environment->context();
shape_ = BHWC(input_def.dimensions.b, input_def.dimensions.h,
input_def.dimensions.w, input_def.dimensions.c);
- RETURN_IF_ERROR(
- args.Compile(environment->device().GetInfo(), {}, &shader_src));
- RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), nullptr,
+ RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), {}, nullptr,
&args, &shader_src));
return environment->program_cache()->GetOrCreateCLKernel(
shader_src, "tensor_to_bhwc", environment->context(),
@@ -354,9 +350,7 @@
context_ = &environment->context();
shape_ = BHWC(output_def.dimensions.b, output_def.dimensions.h,
output_def.dimensions.w, output_def.dimensions.c);
- RETURN_IF_ERROR(
- args.Compile(environment->device().GetInfo(), {}, &shader_src));
- RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), nullptr,
+ RETURN_IF_ERROR(cl_args_.Init(environment->device().GetInfo(), {}, nullptr,
&args, &shader_src));
return environment->program_cache()->GetOrCreateCLKernel(
shader_src, "bhwc_to_tensor", environment->context(),
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.cc b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
index 281be99..60f839a 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.cc
@@ -286,15 +286,6 @@
}
}
-absl::Status Arguments::Compile(
- const GpuInfo& gpu_info,
- const std::map<std::string, std::string>& linkables, std::string* code) {
- RETURN_IF_ERROR(AddObjectsScalarArgs(gpu_info));
- RETURN_IF_ERROR(ResolveSelectorsPass(gpu_info, linkables, code));
- GetActiveArguments(*code);
- return absl::OkStatus();
-}
-
absl::Status Arguments::ResolveSelectorsPass(
const GpuInfo& gpu_info,
const std::map<std::string, std::string>& linkables,
diff --git a/tensorflow/lite/delegates/gpu/common/task/arguments.h b/tensorflow/lite/delegates/gpu/common/task/arguments.h
index df87322..b706ab9 100644
--- a/tensorflow/lite/delegates/gpu/common/task/arguments.h
+++ b/tensorflow/lite/delegates/gpu/common/task/arguments.h
@@ -77,6 +77,8 @@
void ReleaseCPURepresentation();
+ void GetActiveArguments(const std::string& code);
+
void SetStateValueForAllObjects(const std::string& key,
const std::string& value);
@@ -122,18 +124,6 @@
*result = std::move(object_refs_);
}
- absl::Status Compile(const GpuInfo& gpu_info,
- const std::map<std::string, std::string>& linkables,
- std::string* code);
-
- private:
- friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
- const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
- friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
- Arguments* args);
-
- void GetActiveArguments(const std::string& code);
-
absl::Status ResolveSelectorsPass(
const GpuInfo& gpu_info,
const std::map<std::string, std::string>& linkables,
@@ -152,6 +142,12 @@
absl::Status AddObjectsScalarArgs(const GpuInfo& gpu_info);
void ResolveArgsPass(std::string* code) const;
+ private:
+ friend flatbuffers::Offset<tflite::gpu::data::Arguments> Encode(
+ const Arguments& args, flatbuffers::FlatBufferBuilder* builder);
+ friend absl::Status Decode(const tflite::gpu::data::Arguments* fb_args,
+ Arguments* args);
+
friend class cl::CLArguments;
friend class metal::MetalArguments;
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
index 68bbb2e..9400269 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.cc
@@ -183,7 +183,7 @@
args_.AddObjectRef(tensor_name, AccessType::WRITE, std::move(desc_new));
}
-absl::Status GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
+void GPUOperation::AssembleCode(const GpuInfo& gpu_info) {
if (elementwise_) {
auto src_desc =
absl::make_unique<TensorDescriptor>(definition_.src_tensors[0]);
@@ -204,9 +204,6 @@
elementwise_code_ = "{\n" + code_ + "\n}\n" + elementwise_code_;
code_ = GetElementWiseCode(definition_, check_src_channels_size_);
}
- RETURN_IF_ERROR(args_.Compile(
- gpu_info, {{dst_tensors_names_[0], elementwise_code_}}, &code_));
- return absl::OkStatus();
}
void GPUOperation::GetPossibleKernelWorkGroups(
diff --git a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
index 9018967..9196f72 100644
--- a/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
+++ b/tensorflow/lite/delegates/gpu/common/task/gpu_operation.h
@@ -120,7 +120,7 @@
TuningType tuning_type, const GpuInfo& gpu_info,
const KernelInfo& kernel_info, std::vector<int3>* work_groups) const;
- absl::Status AssembleCode(const GpuInfo& gpu_info);
+ void AssembleCode(const GpuInfo& gpu_info);
virtual absl::Status PostCompileCheck(const GpuInfo& gpu_info,
const KernelInfo& kernel_info) {
diff --git a/tensorflow/lite/delegates/gpu/metal/compute_task.cc b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
index f67e500..714a0ef 100644
--- a/tensorflow/lite/delegates/gpu/metal/compute_task.cc
+++ b/tensorflow/lite/delegates/gpu/metal/compute_task.cc
@@ -118,8 +118,10 @@
}
absl::Status ComputeTask::Compile(MetalDevice* device) {
- RETURN_IF_ERROR(operation_->AssembleCode(device->GetInfo()));
- RETURN_IF_ERROR(metal_args_.Init(use_arguments_buffer_, device,
+ operation_->AssembleCode(device->GetInfo());
+ const std::map<std::string, std::string> linkables = {
+ {operation_->dst_tensors_names_[0], operation_->elementwise_code_}};
+ RETURN_IF_ERROR(metal_args_.Init(linkables, use_arguments_buffer_, device,
&operation_->args_, &operation_->code_));
operation_->args_.ReleaseCPURepresentation();
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
index 94be450..a8245ab 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.cc
@@ -114,8 +114,13 @@
constexpr char MetalArguments::kArgsPrefix[];
absl::Status MetalArguments::Init(
+ const std::map<std::string, std::string>& linkables,
bool use_arguments_buffer, MetalDevice* device, Arguments* args,
std::string* code) {
+ RETURN_IF_ERROR(args->AddObjectsScalarArgs(device->GetInfo()));
+ RETURN_IF_ERROR(
+ args->ResolveSelectorsPass(device->GetInfo(), linkables, code));
+ args->GetActiveArguments(*code);
RETURN_IF_ERROR(AllocateObjects(*args, device->device()));
RETURN_IF_ERROR(AddObjectArgs(device->GetInfo(), *args));
object_refs_ = std::move(args->object_refs_);
diff --git a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
index 849de33..d6afa88 100644
--- a/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
+++ b/tensorflow/lite/delegates/gpu/metal/metal_arguments.h
@@ -35,7 +35,8 @@
public:
MetalArguments() = default;
- absl::Status Init(bool use_arguments_buffer, MetalDevice* device,
+ absl::Status Init(const std::map<std::string, std::string>& linkables,
+ bool use_arguments_buffer, MetalDevice* device,
Arguments* args, std::string* code);
// Move only