Using TensorDescriptor instead of DummyTensor in OpenCL inference context.

PiperOrigin-RevId: 401161076
Change-Id: I9750034906311eb9ac234e6fade6d09b2385b41c
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.cc b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
index 4a28d6f..8b3f8ba 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.cc
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.cc
@@ -330,8 +330,9 @@
     }
     RETURN_IF_ERROR(SelectBestStorageType(gpu_info, shape, storage_type,
                                           data_type, layout, &storage_type));
-    tensor_reserver_.Add(
-        t->id, {shape, TensorDescriptor{data_type, storage_type, layout}});
+    TensorDescriptor tensor_desc{data_type, storage_type, layout};
+    tensor_desc.shape = BHWDC(shape.b, shape.h, shape.w, 1, shape.c);
+    tensor_reserver_.Add(t->id, tensor_desc);
     max_id = std::max(max_id, t->id);
   }
   tensor_reserver_.SetNext(max_id + 1);
@@ -344,7 +345,7 @@
   std::map<ValueId, TensorDescriptor> tensor_descriptors;
   const auto values = graph.values();
   for (auto value : values) {
-    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id).descriptor;
+    tensor_descriptors[value->id] = tensor_reserver_.Get(value->id);
   }
   std::set<NodeId> consumed_nodes;
   std::vector<Node*> graph_nodes = graph.nodes();
@@ -365,7 +366,7 @@
           absl::any_cast<ConstTensorAttributes>(node.operation.attributes);
       auto outputs = graph.FindOutputs(node.id);
       const_tensors_descs_[outputs[0]->id] =
-          tensor_reserver_.Get(outputs[0]->id).descriptor;
+          tensor_reserver_.Get(outputs[0]->id);
       const_tensors_descs_[outputs[0]->id].UploadData(attr.tensor);
       continue;
     }
@@ -405,12 +406,10 @@
       OperationDef op_def;
       op_def.precision = precision_;
       for (int j = 0; j < inputs.size(); ++j) {
-        op_def.src_tensors.push_back(
-            tensor_reserver_.Get(inputs[j]->id).descriptor);
+        op_def.src_tensors.push_back(tensor_reserver_.Get(inputs[j]->id));
       }
       for (int j = 0; j < outputs.size(); ++j) {
-        op_def.dst_tensors.push_back(
-            tensor_reserver_.Get(outputs[j]->id).descriptor);
+        op_def.dst_tensors.push_back(tensor_reserver_.Get(outputs[j]->id));
       }
       RETURN_IF_ERROR(GPUOperationFromNode(gpu_info, op_def, hints, inputs,
                                            outputs, node, &gpu_subgraph));
@@ -418,7 +417,9 @@
     absl::flat_hash_map<int, ValueId> mapping_to_global_ids;
     for (int j = 0; j < gpu_subgraph.new_tensors.size(); ++j) {
       const auto& t = gpu_subgraph.new_tensors[j];
-      auto global_id = tensor_reserver_.Add({t.first, t.second});
+      TensorDescriptor td = t.second;
+      td.shape = BHWDC(t.first.b, t.first.h, t.first.w, 1, t.first.c);
+      auto global_id = tensor_reserver_.Add(td);
       mapping_to_global_ids[j] = global_id;
     }
     for (auto& gpu_op : gpu_subgraph.operations) {
@@ -525,8 +526,7 @@
     return TensorMemoryType::kConst;
   } else if (variable_ids_and_refs_.find(id) != variable_ids_and_refs_.end()) {
     return TensorMemoryType::kVariable;
-  } else if (IsBufferBased(gpu_info,
-                           tensor_reserver_.Get(id).descriptor.storage_type)) {
+  } else if (IsBufferBased(gpu_info, tensor_reserver_.Get(id).storage_type)) {
     return TensorMemoryType::kBuffer;
   } else {
     return TensorMemoryType::kStrongShape;
@@ -560,7 +560,7 @@
         ref_value_to_tensor_index.end()) {
       const auto& t = tensor_reserver_.Get(value_and_ref_value.first);
       const auto& shape = t.shape;
-      const auto& descriptor = t.descriptor;
+      const auto& descriptor = t;
 
       RETURN_IF_ERROR(
           CreateTensor(*context, shape, descriptor,
@@ -583,7 +583,7 @@
   for (auto& usage : buffer_usages) {
     const auto& t = tensor_reserver_.Get(usage.first);
     const auto& shape = t.shape;
-    const auto& descriptor = t.descriptor;
+    const auto& descriptor = t;
     const size_t element_size =
         descriptor.data_type == DataType::FLOAT32 ? 4 : 2;
     size_t buffer_size;
@@ -664,7 +664,8 @@
         continue;
       const int tensor_index = graph_ids_to_shared_buffer_tensors_[t.first];
       if (created_tensors[tensor_index]) continue;
-      const auto& shape = tensor_reserver_.Get(t.first).shape;
+      const auto& shape_5d = tensor_reserver_.Get(t.first).shape;
+      const auto shape = BHWC(shape_5d.b, shape_5d.h, shape_5d.w, shape_5d.c);
       const int buffer_index = use_offset_assignment
                                    ? tensor_index
                                    : buffer_assignment.object_ids[tensor_index];
@@ -698,7 +699,7 @@
       },
       &usages);
 
-  std::vector<TensorUsageRecord<DummyTensor>> usage_records;
+  std::vector<TensorUsageRecord<TensorDescriptor>> usage_records;
   std::map<ValueId, ValueId> remap_from_graph_ids;
   for (auto& usage : usages) {
     remap_from_graph_ids[usage.first] = usage_records.size();
@@ -707,7 +708,7 @@
                              static_cast<TaskId>(usage.second.y)});
   }
 
-  ObjectsAssignment<DummyTensor> assignment;
+  ObjectsAssignment<TensorDescriptor> assignment;
   RETURN_IF_ERROR(AssignObjectsToTensors(
       usage_records, MemoryStrategy::EQUALITY, &assignment));
 
diff --git a/tensorflow/lite/delegates/gpu/cl/inference_context.h b/tensorflow/lite/delegates/gpu/cl/inference_context.h
index 2266cb3..374ee8e 100644
--- a/tensorflow/lite/delegates/gpu/cl/inference_context.h
+++ b/tensorflow/lite/delegates/gpu/cl/inference_context.h
@@ -170,56 +170,32 @@
   //  anywhere.
   std::vector<CLNode> nodes_;
 
-  struct DummyTensor {
-    BHWC shape;
-    TensorDescriptor descriptor;
-
-    bool operator==(const DummyTensor& b) const {
-      return shape == b.shape && descriptor == b.descriptor;
-    }
-  };
-
   class TensorReserver {
    public:
     TensorReserver() : next_(0) {}
-    ValueId Add(const DummyTensor& dummy) {
+    ValueId Add(const TensorDescriptor& dummy) {
       reservations_[next_] = dummy;
       return next_++;
     }
-    void Add(ValueId id, const DummyTensor& dummy) {
+    void Add(ValueId id, const TensorDescriptor& dummy) {
       reservations_[id] = dummy;
     }
     void SetNext(ValueId id) { next_ = id; }
-    DummyTensor Get(ValueId id) { return reservations_[id]; }
+    TensorDescriptor Get(ValueId id) { return reservations_[id]; }
 
     std::vector<std::pair<ValueId, TensorDescriptor>> GetTensorDescs() const {
-      std::vector<std::pair<ValueId, TensorDescriptor>> result;
-      for (auto& v : reservations_) {
-        TensorDescriptor desc = v.second.descriptor;
-        desc.shape.b = v.second.shape.b;
-        desc.shape.h = v.second.shape.h;
-        desc.shape.w = v.second.shape.w;
-        desc.shape.d = 1;
-        desc.shape.c = v.second.shape.c;
-        result.push_back({v.first, desc});
-      }
-      return result;
+      return std::vector<std::pair<ValueId, TensorDescriptor>>(
+          reservations_.begin(), reservations_.end());
     }
 
     void Add(const std::vector<std::pair<ValueId, TensorDescriptor>>& tensors) {
       for (auto& v : tensors) {
-        DummyTensor dummy;
-        dummy.descriptor = v.second;
-        dummy.shape.b = v.second.shape.b;
-        dummy.shape.h = v.second.shape.h;
-        dummy.shape.w = v.second.shape.w;
-        dummy.shape.c = v.second.shape.c;
-        Add(v.first, dummy);
+        Add(v.first, v.second);
       }
     }
 
    private:
-    absl::flat_hash_map<ValueId, DummyTensor> reservations_;
+    absl::flat_hash_map<ValueId, TensorDescriptor> reservations_;
     ValueId next_;
   };
   TensorReserver tensor_reserver_;
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.cc b/tensorflow/lite/delegates/gpu/cl/tensor.cc
index 301ca6f..d8625fd 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.cc
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.cc
@@ -615,7 +615,17 @@
                                              const TensorDescriptor& descriptor,
                                              int row_bytes_alignment,
                                              Tensor* result) {
-  const int width = shape.b * shape.w;
+  BHWDC shape5d(shape.b, shape.h, shape.w, 1, shape.c);
+  return CreateSharedImage2DBufferTensor(context, memory, shape5d, descriptor,
+                                         row_bytes_alignment, result);
+}
+
+absl::Status CreateSharedImage2DBufferTensor(const CLContext& context,
+                                             cl_mem memory, const BHWDC& shape,
+                                             const TensorDescriptor& descriptor,
+                                             int row_bytes_alignment,
+                                             Tensor* result) {
+  const int width = shape.b * shape.w * shape.d;
   const int height =
       descriptor.storage_type == TensorStorageType::SINGLE_TEXTURE_2D
           ? shape.h
diff --git a/tensorflow/lite/delegates/gpu/cl/tensor.h b/tensorflow/lite/delegates/gpu/cl/tensor.h
index 1d711c8..b45f6a1 100644
--- a/tensorflow/lite/delegates/gpu/cl/tensor.h
+++ b/tensorflow/lite/delegates/gpu/cl/tensor.h
@@ -158,6 +158,12 @@
                                              int row_bytes_alignment,
                                              Tensor* result);
 
+absl::Status CreateSharedImage2DBufferTensor(const CLContext& context,
+                                             cl_mem memory, const BHWDC& shape,
+                                             const TensorDescriptor& descriptor,
+                                             int row_bytes_alignment,
+                                             Tensor* result);
+
 template <DataType T>
 absl::Status Tensor::WriteData(CLCommandQueue* queue,
                                const tflite::gpu::Tensor<BHWC, T>& src) {