Reduce dependency of interfaces on c_api.h

* Remove TF_Tensor, TFE_TensorDebugInfo, and TF_CancellationManager from
  the interfaces (new signatures sketched below).
* Clean up the code around TensorFromInterface, TensorHandleFromInterface,
  and OperationFromInterface (see the call-site sketch below).
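
For reference, the remaining interface methods traffic in
AbstractTensorInterface instead of C API types (TensorDebugInfo and
SetCancellationManager move into their C API call sites):

    virtual std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) = 0;
    virtual Status SetAttrTensor(
        const char* attr_name,
        std::unique_ptr<AbstractTensorInterface> tensor) = 0;

Call sites unwrap the interface wrappers through the new inline helpers
instead of down_cast'ing in place; a sketch drawn from the updated
c_api.cc:

    // Fetching the underlying EagerOperation (TFE_OpGetAttrs):
    tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
    *attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());

    // Resolving a handle to a TF_Tensor (TFE_TensorHandleResolve):
    std::unique_ptr<tensorflow::AbstractTensorInterface> t =
        h->handle->Resolve(&status->status);
    tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
    return tensorflow::TF_TensorFromTensor(tensor, &status->status);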

PiperOrigin-RevId: 303848592
Change-Id: Iff1f4b2a377fa9565f4e7ac1ba7a33ee07d77957
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 5830a7d..1e94987 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -827,8 +827,7 @@
   for (int i = 0; i < num_inputs; ++i) {
     node_def.add_input("dummy_input");
   }
-  tensorflow::down_cast<tensorflow::OperationInterface*>(
-      tfe_op->operation.get())
+  OperationFromInterface(tfe_op->operation)
       ->Attrs()
       .FillAttrValueMap(node_def.mutable_attr());
 
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 4e3915b..a38bdc6 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -1098,13 +1098,22 @@
     return nullptr;
   }
 
-  return h->handle->Resolve(&status->status);
+  std::unique_ptr<tensorflow::AbstractTensorInterface> t =
+      h->handle->Resolve(&status->status);
+  if (t == nullptr) {
+    return nullptr;
+  }
+
+  tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
+  return tensorflow::TF_TensorFromTensor(tensor, &status->status);
 }
 
-TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
+std::unique_ptr<tensorflow::AbstractTensorInterface>
+tensorflow::TensorHandleInterface::Resolve(Status* status) {
   if (!IsValid(status)) {
     return nullptr;
   }
+
   if (VariantDeviceIsCustom(handle_->device())) {
     tensorflow::CustomDevice* custom_device =
         absl::get<tensorflow::CustomDevice*>(handle_->device());
@@ -1133,7 +1142,7 @@
       h_cpu->Unref();
       return nullptr;
     }
-    TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
+    auto retval = std::make_unique<tensorflow::TensorInterface>(*t);
     h_cpu->Unref();
     return retval;
   } else {
@@ -1160,7 +1169,7 @@
         if (!status->ok()) return nullptr;
       }
     }
-    return tensorflow::TF_TensorFromTensor(tensor, status);
+    return std::make_unique<tensorflow::TensorInterface>(std::move(tensor));
   }
 }
 
@@ -1407,7 +1416,10 @@
 
 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
                          TF_Status* status) {
-  status->status = op->operation->SetAttrTensor(attr_name, tensor);
+  tensorflow::Tensor t;
+  status->status = TF_TensorToTensor(tensor, &t);
+  status->status = op->operation->SetAttrTensor(
+      attr_name, std::make_unique<tensorflow::TensorInterface>(t));
 }
 
 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
@@ -1657,16 +1669,14 @@
 }
 
 void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
-  auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
-      op->operation.get());
-  *attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
+  tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
+  *attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
 }
 
 void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
   tensorflow::AttrValueMap m;
   attrs->attributes->FillAttrValueMap(&m);
-  auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
-      op->operation.get());
+  tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
   tensorflow::AttrBuilder* destination = operation->MutableAttrs();
   for (auto attribute : m) {
     destination->Set(attribute.first, attribute.second);
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
index 2d6dd21..50f31fa 100644
--- a/tensorflow/c/eager/c_api_debug.cc
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -54,36 +54,32 @@
 
 TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
     TFE_TensorHandle* h, TF_Status* status) {
-  return h->handle->TensorDebugInfo(&status->status);
-}
-
-TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
-    Status* status) {
+  tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
   const tensorflow::Tensor* tensor;
-  *status = handle_->Tensor(&tensor);
-  if (!status->ok()) {
+  status->status = handle->Tensor(&tensor);
+  if (!status->status.ok()) {
     return nullptr;
   }
 
 #ifdef TENSORFLOW_EAGER_USE_XLA
-  tensorflow::Device* device = absl::get<Device*>(handle_->device());
+  auto* device = absl::get<tensorflow::Device*>(handle->device());
 
   // If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
-  tensorflow::XlaDevice* xla_device =
-      dynamic_cast<tensorflow::XlaDevice*>(device);
+  auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
   if (xla_device != nullptr) {
     tensorflow::XlaDevice::PaddedShapeFn shape_fn =
         xla_device->metadata().padded_shape_fn();
     xla::Shape padded_shape;
-    *status = shape_fn(*tensor, &padded_shape);
-    if (!status->ok()) {
+    status->status = shape_fn(*tensor, &padded_shape);
+    if (!status->status.ok()) {
       return nullptr;
     }
     if (VLOG_IS_ON(3)) {
-      std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
-      if (!status->ok()) {
+      std::vector<int64> shape_to_log =
+          TensorShapeAsVector(*handle, &status->status);
+      if (!status->status.ok()) {
         // Ignore the status here as we are simply logging.
-        *status = tensorflow::Status::OK();
+        status->status = tensorflow::Status::OK();
       } else {
         VLOG(3) << "Fully padded shape of ["
                 << absl::StrJoin(shape_to_log, ", ") << "] is "
@@ -96,7 +92,7 @@
         // Currently, the only case of XlaTensor containing a tuple shape is to
         // represent 64 bit ints, doubles, and complex numbers (we don't support
         // 64bit complex numbers).
-        *status = tensorflow::errors::InvalidArgument(
+        status->status = tensorflow::errors::InvalidArgument(
             "XlaTensors should only contain tuples of size 2. Shape: ",
             padded_shape.DebugString());
         return nullptr;
@@ -108,13 +104,13 @@
       const xla::Shape& shape1 =
           xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
       if (shape0.IsTuple() || shape1.IsTuple()) {
-        *status = tensorflow::errors::InvalidArgument(
+        status->status = tensorflow::errors::InvalidArgument(
             "XlaTensors should not contain nested tuples. Shape: ",
             padded_shape.DebugString());
         return nullptr;
       }
       if (!xla::ShapeUtil::Equal(shape0, shape1)) {
-        *status = tensorflow::errors::InvalidArgument(
+        status->status = tensorflow::errors::InvalidArgument(
             "Subshapes of XlaTensors should be the same. Shape: ",
             padded_shape.DebugString());
         return nullptr;
@@ -139,15 +135,15 @@
         dev_dims.push_back(padded_shape.dimensions(dim_index));
       }
     }
-    *status = tensorflow::Status::OK();
+    status->status = tensorflow::Status::OK();
     return new TFE_TensorDebugInfo(dev_dims);
   }
 #endif  // TENSORFLOW_EAGER_USE_XLA
 
   // If the tensor is not an XLA tensor, the device shape is
   // the same as regular tensor shape.
-  std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
-  if (!status->ok()) {
+  std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
+  if (!status->status.ok()) {
     return nullptr;
   }
   return new TFE_TensorDebugInfo(dev_dims);
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index 6f4bf82..4d01a06 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -526,7 +526,11 @@
 void TFE_OpSetCancellationManager(TFE_Op* op,
                                   TFE_CancellationManager* cancellation_manager,
                                   TF_Status* status) {
-  status->status = op->operation->SetCancellationManager(cancellation_manager);
+  tensorflow::EagerOperation* operation =
+      tensorflow::OperationFromInterface(op->operation);
+  operation->SetCancellationManager(
+      &cancellation_manager->cancellation_manager);
+  status->status = tensorflow::Status::OK();
 }
 
 TFE_Executor* TFE_NewExecutor(bool is_async) {
diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h
index fa1d8cb..665651c 100644
--- a/tensorflow/c/eager/context_interface.h
+++ b/tensorflow/c/eager/context_interface.h
@@ -22,6 +22,7 @@
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/tensor_interface.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/tstring.h"
 
 namespace tensorflow {
diff --git a/tensorflow/c/eager/operation_interface.cc b/tensorflow/c/eager/operation_interface.cc
index 6dd5a70..136fdef 100644
--- a/tensorflow/c/eager/operation_interface.cc
+++ b/tensorflow/c/eager/operation_interface.cc
@@ -98,9 +98,8 @@
   AttrValue attr_value;
   NameAttrList* func = attr_value.mutable_func();
   func->set_name(value->Name());
-  OperationInterface* value_operation =
-      tensorflow::down_cast<OperationInterface*>(value.get());
-  value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
+  EagerOperation* value_operation = OperationFromInterface(value);
+  value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
   operation_.MutableAttrs()->Set(attr_name, attr_value);
   return Status::OK();
 }
@@ -115,10 +114,9 @@
   return Status::OK();
 }
 
-Status OperationInterface::SetAttrTensor(const char* attr_name,
-                                         TF_Tensor* tensor) {
-  Tensor t;
-  TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
+Status OperationInterface::SetAttrTensor(
+    const char* attr_name, std::unique_ptr<AbstractTensorInterface> tensor) {
+  Tensor t = TensorFromInterface(tensor);
   operation_.MutableAttrs()->Set(attr_name, t);
   return Status::OK();
 }
@@ -208,11 +206,10 @@
                                                int num_values) {
   std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
   for (int i = 0; i < num_values; i++) {
-    auto value_operation =
-        tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
-    funcs[i].set_name(value_operation->operation_.Name());
-    value_operation->operation_.Attrs().FillAttrValueMap(
-        funcs[i].mutable_attr());
+    EagerOperation* value_operation =
+        OperationFromInterface(value[i]->operation);
+    funcs[i].set_name(value_operation->Name());
+    value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
   }
   operation_.MutableAttrs()->Set(
       attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
@@ -294,13 +291,6 @@
   return Status::OK();
 }
 
-Status OperationInterface::SetCancellationManager(
-    TFE_CancellationManager* cancellation_manager) {
-  operation_.SetCancellationManager(
-      &cancellation_manager->cancellation_manager);
-  return Status::OK();
-}
-
 Status OperationInterface::SetUseXla(bool enable) {
   operation_.SetUseXla(enable);
   return Status::OK();
diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h
index 499728c..eb81882 100644
--- a/tensorflow/c/eager/operation_interface.h
+++ b/tensorflow/c/eager/operation_interface.h
@@ -19,9 +19,12 @@
 
 #include "absl/container/fixed_array.h"
 #include "tensorflow/c/eager/c_api.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
 #include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/framework/tensor_interface.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 
@@ -60,7 +63,9 @@
       const std::unique_ptr<AbstractOperationInterface>& value) = 0;
   virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
                                      size_t length) = 0;
-  virtual Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) = 0;
+  virtual Status SetAttrTensor(
+      const char* attr_name,
+      std::unique_ptr<AbstractTensorInterface> tensor) = 0;
   virtual Status SetAttrStringList(const char* attr_name,
                                    const void* const* values,
                                    const size_t* lengths, int num_values) = 0;
@@ -82,14 +87,7 @@
   virtual Status OutputLength(const char* output_name, int* length) = 0;
 
   // Experimental
-  virtual Status SetUseXla(bool enable) {
-    return errors::Unimplemented("SetUseXla not implemented");
-  }
-
-  virtual Status SetCancellationManager(
-      TFE_CancellationManager* cancellation_manager) {
-    return errors::Unimplemented("SetCancellationManager not implemented");
-  }
+  virtual Status SetUseXla(bool enable) = 0;
 };
 
 class OpDef;
@@ -133,7 +131,9 @@
       const std::unique_ptr<AbstractOperationInterface>& value) override;
   Status SetAttrFunctionName(const char* attr_name, const char* data,
                              size_t length) override;
-  Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
+  Status SetAttrTensor(
+      const char* attr_name,
+      std::unique_ptr<AbstractTensorInterface> tensor) override;
   Status SetAttrStringList(const char* attr_name, const void* const* values,
                            const size_t* lengths, int num_values) override;
   Status SetAttrFloatList(const char* attr_name, const float* values,
@@ -153,8 +153,6 @@
   Status OutputLength(const char* output_name, int* length) override;
 
   Status SetUseXla(bool enable) override;
-  Status SetCancellationManager(
-      TFE_CancellationManager* cancellation_manager) override;
 
   // TODO(gjn): Remove once TFE_InferShapes is removed
   const AttrBuilder& Attrs() const { return operation_.Attrs(); }
@@ -162,11 +160,18 @@
 
   const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
 
+  EagerOperation* Operation() { return &operation_; }
+
  private:
   const tensorflow::OpDef* GetOpDef(Status* status);
   EagerOperation operation_;
 };
 
+inline EagerOperation* OperationFromInterface(
+    const std::unique_ptr<AbstractOperationInterface>& operation) {
+  return down_cast<OperationInterface*>(operation.get())->Operation();
+}
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
diff --git a/tensorflow/c/eager/tensor_handle_interface.h b/tensorflow/c/eager/tensor_handle_interface.h
index 2ea1a65..6d73ff3 100644
--- a/tensorflow/c/eager/tensor_handle_interface.h
+++ b/tensorflow/c/eager/tensor_handle_interface.h
@@ -15,11 +15,11 @@
 #ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
 #define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
 
-#include "tensorflow/c/c_api.h"
-#include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/tensor_interface.h"
 #include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
 
 namespace tensorflow {
 
@@ -52,9 +52,7 @@
   // Returns the device where the tensor was placed.
   virtual const char* BackingDeviceName(Status* status) const = 0;
   // Returns a tensor for the handle. If tensor is remote, it will be copied.
-  virtual TF_Tensor* Resolve(Status* status) = 0;
-  // Returns debug information about the tensor.
-  virtual TFE_TensorDebugInfo* TensorDebugInfo(Status* status) = 0;
+  virtual std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) = 0;
 
   // Return a copy of the handle.
   virtual AbstractTensorHandleInterface* Copy() = 0;
@@ -84,8 +82,7 @@
 
   const char* DeviceName(Status* status) const override;
   const char* BackingDeviceName(Status* status) const override;
-  TF_Tensor* Resolve(Status* status) override;
-  TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
+  std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) override;
 
   AbstractTensorHandleInterface* Copy() override;
 
diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h
index 08a55f2..2d31418 100644
--- a/tensorflow/c/tf_tensor_internal.h
+++ b/tensorflow/c/tf_tensor_internal.h
@@ -31,7 +31,7 @@
 // passed to or returned from C functions *by pointer*. Otherwise, changes to
 // its internal structure will break the C API's binary interface.
 typedef struct TF_Tensor {
-  std::unique_ptr<AbstractTensorInterface> tensor;
+  std::unique_ptr<tensorflow::AbstractTensorInterface> tensor;
 } TF_Tensor;
 
 class TF_ManagedBuffer : public tensorflow::TensorBuffer {
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc
index 7c4d046..81d0528 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.cc
+++ b/tensorflow/core/common_runtime/eager/eager_operation.cc
@@ -58,7 +58,15 @@
   executor_ = executor ? executor : &ctx_.Executor();
   remote_func_params_ = remote_func_params;
   op_name_ = op;
-  return SetDeviceName(raw_device_name, true);
+  if (raw_device_name != nullptr && strlen(raw_device_name) > 0) {
+    return SetDeviceName(raw_device_name);
+  } else {
+    raw_device_name_.clear();
+    device_name_.clear();
+    device_parsed_name_.Clear();
+    device_ = kVariantDeviceNull;
+    return Status::OK();
+  }
 }
 
 Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
@@ -128,7 +136,7 @@
   return Status::OK();
 }
 
-Status EagerOperation::SetDeviceName(const char* device, const bool reset) {
+Status EagerOperation::SetDeviceName(const char* device) {
   if (device != nullptr && strlen(device) > 0) {
     if (device != raw_device_name_) {
       if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
@@ -150,11 +158,6 @@
         device_ = kVariantDeviceNull;
       }
     }
-  } else if (reset) {
-    raw_device_name_.clear();
-    device_name_.clear();
-    device_parsed_name_.Clear();
-    device_ = kVariantDeviceNull;
   }
   return Status::OK();
 }
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index 7f15033..1ba55ea 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -98,7 +98,7 @@
   const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
     return device_parsed_name_;
   }
-  Status SetDeviceName(const char* device, const bool reset = false);
+  Status SetDeviceName(const char* device);
 
   // Indicates whether the op is assigned to a device that is local to the
   // current host.
diff --git a/tensorflow/core/framework/tensor_interface.h b/tensorflow/core/framework/tensor_interface.h
index 115e1a0..9388330 100644
--- a/tensorflow/core/framework/tensor_interface.h
+++ b/tensorflow/core/framework/tensor_interface.h
@@ -17,8 +17,11 @@
 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
 
 #include "tensorflow/c/tf_datatype.h"
-#include "tensorflow/c/tf_status.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
 
 // Abstract interface to a Tensor.
 //
@@ -49,8 +52,6 @@
   virtual bool CanMove() const = 0;
 };
 
-namespace tensorflow {
-
 class TensorInterface : public AbstractTensorInterface {
  public:
   TensorInterface() {}
@@ -72,12 +73,17 @@
 
   // TODO(gjn): This is not a very generic interface, but is needed for specific
   // use cases.
-  tensorflow::Tensor Tensor() { return tensor_; }
+  tensorflow::Tensor& Tensor() { return tensor_; }
 
  private:
   tensorflow::Tensor tensor_;
 };
 
+inline Tensor& TensorFromInterface(
+    const std::unique_ptr<AbstractTensorInterface>& tensor) {
+  return down_cast<TensorInterface*>(tensor.get())->Tensor();
+}
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_