Reduce dependency of interfaces on c_api.h
* Remove TF_Tensor, TFE_TensorDebugInfo, TF_CancellationManager from
interfaces.
* Some clean-ups around TensorFromInterface, TensorHandleFromInterface &
OperationFromInterface.
PiperOrigin-RevId: 303848592
Change-Id: Iff1f4b2a377fa9565f4e7ac1ba7a33ee07d77957
diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc
index 5830a7d..1e94987 100644
--- a/tensorflow/c/c_api_experimental.cc
+++ b/tensorflow/c/c_api_experimental.cc
@@ -827,8 +827,7 @@
for (int i = 0; i < num_inputs; ++i) {
node_def.add_input("dummy_input");
}
- tensorflow::down_cast<tensorflow::OperationInterface*>(
- tfe_op->operation.get())
+ OperationFromInterface(tfe_op->operation)
->Attrs()
.FillAttrValueMap(node_def.mutable_attr());
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 4e3915b..a38bdc6 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -1098,13 +1098,22 @@
return nullptr;
}
- return h->handle->Resolve(&status->status);
+ std::unique_ptr<tensorflow::AbstractTensorInterface> t =
+ h->handle->Resolve(&status->status);
+ if (t == nullptr) {
+ return nullptr;
+ }
+
+ tensorflow::Tensor tensor = tensorflow::TensorFromInterface(t);
+ return tensorflow::TF_TensorFromTensor(tensor, &status->status);
}
-TF_Tensor* tensorflow::TensorHandleInterface::Resolve(Status* status) {
+std::unique_ptr<tensorflow::AbstractTensorInterface>
+tensorflow::TensorHandleInterface::Resolve(Status* status) {
if (!IsValid(status)) {
return nullptr;
}
+
if (VariantDeviceIsCustom(handle_->device())) {
tensorflow::CustomDevice* custom_device =
absl::get<tensorflow::CustomDevice*>(handle_->device());
@@ -1133,7 +1142,7 @@
h_cpu->Unref();
return nullptr;
}
- TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
+ auto retval = std::make_unique<tensorflow::TensorInterface>(*t);
h_cpu->Unref();
return retval;
} else {
@@ -1160,7 +1169,7 @@
if (!status->ok()) return nullptr;
}
}
- return tensorflow::TF_TensorFromTensor(tensor, status);
+ return std::make_unique<tensorflow::TensorInterface>(std::move(tensor));
}
}
@@ -1407,7 +1416,10 @@
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
- status->status = op->operation->SetAttrTensor(attr_name, tensor);
+ tensorflow::Tensor t;
+ status->status = TF_TensorToTensor(tensor, &t);
+ status->status = op->operation->SetAttrTensor(
+ attr_name, std::make_unique<tensorflow::TensorInterface>(t));
}
void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
@@ -1657,16 +1669,14 @@
}
void TFE_OpGetAttrs(TFE_Op* op, TFE_OpAttrs* attrs) {
- auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
- op->operation.get());
- *attrs = TFE_OpAttrs(&operation->Attrs(), op->operation->Name().c_str());
+ tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
+ *attrs = TFE_OpAttrs(&operation->Attrs(), operation->Name().c_str());
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
attrs->attributes->FillAttrValueMap(&m);
- auto operation = tensorflow::down_cast<tensorflow::OperationInterface*>(
- op->operation.get());
+ tensorflow::EagerOperation* operation = OperationFromInterface(op->operation);
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (auto attribute : m) {
destination->Set(attribute.first, attribute.second);
diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc
index 2d6dd21..50f31fa 100644
--- a/tensorflow/c/eager/c_api_debug.cc
+++ b/tensorflow/c/eager/c_api_debug.cc
@@ -54,36 +54,32 @@
TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo(
TFE_TensorHandle* h, TF_Status* status) {
- return h->handle->TensorDebugInfo(&status->status);
-}
-
-TFE_TensorDebugInfo* tensorflow::TensorHandleInterface::TensorDebugInfo(
- Status* status) {
+ tensorflow::TensorHandle* handle = TensorHandleFromInterface(h->handle);
const tensorflow::Tensor* tensor;
- *status = handle_->Tensor(&tensor);
- if (!status->ok()) {
+ status->status = handle->Tensor(&tensor);
+ if (!status->status.ok()) {
return nullptr;
}
#ifdef TENSORFLOW_EAGER_USE_XLA
- tensorflow::Device* device = absl::get<Device*>(handle_->device());
+ auto* device = absl::get<tensorflow::Device*>(handle->device());
// If tensor resides on an XLA device, use XLA device's PaddedShapeFn.
- tensorflow::XlaDevice* xla_device =
- dynamic_cast<tensorflow::XlaDevice*>(device);
+ auto* xla_device = dynamic_cast<tensorflow::XlaDevice*>(device);
if (xla_device != nullptr) {
tensorflow::XlaDevice::PaddedShapeFn shape_fn =
xla_device->metadata().padded_shape_fn();
xla::Shape padded_shape;
- *status = shape_fn(*tensor, &padded_shape);
- if (!status->ok()) {
+ status->status = shape_fn(*tensor, &padded_shape);
+ if (!status->status.ok()) {
return nullptr;
}
if (VLOG_IS_ON(3)) {
- std::vector<int64> shape_to_log = TensorShapeAsVector(*handle_, status);
- if (!status->ok()) {
+ std::vector<int64> shape_to_log =
+ TensorShapeAsVector(*handle, &status->status);
+ if (!status->status.ok()) {
// Ignore the status here as we are simply logging.
- *status = tensorflow::Status::OK();
+ status->status = tensorflow::Status::OK();
} else {
VLOG(3) << "Fully padded shape of ["
<< absl::StrJoin(shape_to_log, ", ") << "] is "
@@ -96,7 +92,7 @@
// Currently, the only case of XlaTensor containing a tuple shape is to
// represent 64 bit ints, doubles, and complex numbers (we don't support
// 64bit complex numbers).
- *status = tensorflow::errors::InvalidArgument(
+ status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should only contain tuples of size 2. Shape: ",
padded_shape.DebugString());
return nullptr;
@@ -108,13 +104,13 @@
const xla::Shape& shape1 =
xla::ShapeUtil::GetTupleElementShape(padded_shape, 1);
if (shape0.IsTuple() || shape1.IsTuple()) {
- *status = tensorflow::errors::InvalidArgument(
+ status->status = tensorflow::errors::InvalidArgument(
"XlaTensors should not contain nested tuples. Shape: ",
padded_shape.DebugString());
return nullptr;
}
if (!xla::ShapeUtil::Equal(shape0, shape1)) {
- *status = tensorflow::errors::InvalidArgument(
+ status->status = tensorflow::errors::InvalidArgument(
"Subshapes of XlaTensors should be the same. Shape: ",
padded_shape.DebugString());
return nullptr;
@@ -139,15 +135,15 @@
dev_dims.push_back(padded_shape.dimensions(dim_index));
}
}
- *status = tensorflow::Status::OK();
+ status->status = tensorflow::Status::OK();
return new TFE_TensorDebugInfo(dev_dims);
}
#endif // TENSORFLOW_EAGER_USE_XLA
// If the tensor is not an XLA tensor, the device shape is
// the same as regular tensor shape.
- std::vector<int64> dev_dims = TensorShapeAsVector(*handle_, status);
- if (!status->ok()) {
+ std::vector<int64> dev_dims = TensorShapeAsVector(*handle, &status->status);
+ if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorDebugInfo(dev_dims);
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index 6f4bf82..4d01a06 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -526,7 +526,11 @@
void TFE_OpSetCancellationManager(TFE_Op* op,
TFE_CancellationManager* cancellation_manager,
TF_Status* status) {
- status->status = op->operation->SetCancellationManager(cancellation_manager);
+ tensorflow::EagerOperation* operation =
+ tensorflow::OperationFromInterface(op->operation);
+ operation->SetCancellationManager(
+ &cancellation_manager->cancellation_manager);
+ status->status = tensorflow::Status::OK();
}
TFE_Executor* TFE_NewExecutor(bool is_async) {
diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/context_interface.h
index fa1d8cb..665651c 100644
--- a/tensorflow/c/eager/context_interface.h
+++ b/tensorflow/c/eager/context_interface.h
@@ -22,6 +22,7 @@
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
namespace tensorflow {
diff --git a/tensorflow/c/eager/operation_interface.cc b/tensorflow/c/eager/operation_interface.cc
index 6dd5a70..136fdef 100644
--- a/tensorflow/c/eager/operation_interface.cc
+++ b/tensorflow/c/eager/operation_interface.cc
@@ -98,9 +98,8 @@
AttrValue attr_value;
NameAttrList* func = attr_value.mutable_func();
func->set_name(value->Name());
- OperationInterface* value_operation =
- tensorflow::down_cast<OperationInterface*>(value.get());
- value_operation->operation_.Attrs().FillAttrValueMap(func->mutable_attr());
+ EagerOperation* value_operation = OperationFromInterface(value);
+ value_operation->Attrs().FillAttrValueMap(func->mutable_attr());
operation_.MutableAttrs()->Set(attr_name, attr_value);
return Status::OK();
}
@@ -115,10 +114,9 @@
return Status::OK();
}
-Status OperationInterface::SetAttrTensor(const char* attr_name,
- TF_Tensor* tensor) {
- Tensor t;
- TF_RETURN_IF_ERROR(TF_TensorToTensor(tensor, &t));
+Status OperationInterface::SetAttrTensor(
+ const char* attr_name, std::unique_ptr<AbstractTensorInterface> tensor) {
+ Tensor t = TensorFromInterface(tensor);
operation_.MutableAttrs()->Set(attr_name, t);
return Status::OK();
}
@@ -208,11 +206,10 @@
int num_values) {
std::unique_ptr<NameAttrList[]> funcs(new NameAttrList[num_values]);
for (int i = 0; i < num_values; i++) {
- auto value_operation =
- tensorflow::down_cast<OperationInterface*>(value[i]->operation.get());
- funcs[i].set_name(value_operation->operation_.Name());
- value_operation->operation_.Attrs().FillAttrValueMap(
- funcs[i].mutable_attr());
+ EagerOperation* value_operation =
+ OperationFromInterface(value[i]->operation);
+ funcs[i].set_name(value_operation->Name());
+ value_operation->Attrs().FillAttrValueMap(funcs[i].mutable_attr());
}
operation_.MutableAttrs()->Set(
attr_name, gtl::ArraySlice<const NameAttrList>(funcs.get(), num_values));
@@ -294,13 +291,6 @@
return Status::OK();
}
-Status OperationInterface::SetCancellationManager(
- TFE_CancellationManager* cancellation_manager) {
- operation_.SetCancellationManager(
- &cancellation_manager->cancellation_manager);
- return Status::OK();
-}
-
Status OperationInterface::SetUseXla(bool enable) {
operation_.SetUseXla(enable);
return Status::OK();
diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/operation_interface.h
index 499728c..eb81882 100644
--- a/tensorflow/c/eager/operation_interface.h
+++ b/tensorflow/c/eager/operation_interface.h
@@ -19,9 +19,12 @@
#include "absl/container/fixed_array.h"
#include "tensorflow/c/eager/c_api.h"
-#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
+#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
+#include "tensorflow/core/framework/tensor_interface.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@@ -60,7 +63,9 @@
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
- virtual Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) = 0;
+ virtual Status SetAttrTensor(
+ const char* attr_name,
+ std::unique_ptr<AbstractTensorInterface> tensor) = 0;
virtual Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) = 0;
@@ -82,14 +87,7 @@
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
- virtual Status SetUseXla(bool enable) {
- return errors::Unimplemented("SetUseXla not implemented");
- }
-
- virtual Status SetCancellationManager(
- TFE_CancellationManager* cancellation_manager) {
- return errors::Unimplemented("SetCancellationManager not implemented");
- }
+ virtual Status SetUseXla(bool enable) = 0;
};
class OpDef;
@@ -133,7 +131,9 @@
const std::unique_ptr<AbstractOperationInterface>& value) override;
Status SetAttrFunctionName(const char* attr_name, const char* data,
size_t length) override;
- Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) override;
+ Status SetAttrTensor(
+ const char* attr_name,
+ std::unique_ptr<AbstractTensorInterface> tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
@@ -153,8 +153,6 @@
Status OutputLength(const char* output_name, int* length) override;
Status SetUseXla(bool enable) override;
- Status SetCancellationManager(
- TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
@@ -162,11 +160,18 @@
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
+ EagerOperation* Operation() { return &operation_; }
+
private:
const tensorflow::OpDef* GetOpDef(Status* status);
EagerOperation operation_;
};
+inline EagerOperation* OperationFromInterface(
+ const std::unique_ptr<AbstractOperationInterface>& operation) {
+ return down_cast<OperationInterface*>(operation.get())->Operation();
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_
diff --git a/tensorflow/c/eager/tensor_handle_interface.h b/tensorflow/c/eager/tensor_handle_interface.h
index 2ea1a65..6d73ff3 100644
--- a/tensorflow/c/eager/tensor_handle_interface.h
+++ b/tensorflow/c/eager/tensor_handle_interface.h
@@ -15,11 +15,11 @@
#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
-#include "tensorflow/c/c_api.h"
-#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
+#include "tensorflow/core/framework/tensor_interface.h"
#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@@ -52,9 +52,7 @@
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
- virtual TF_Tensor* Resolve(Status* status) = 0;
- // Returns debug information about the tensor.
- virtual TFE_TensorDebugInfo* TensorDebugInfo(Status* status) = 0;
+ virtual std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
@@ -84,8 +82,7 @@
const char* DeviceName(Status* status) const override;
const char* BackingDeviceName(Status* status) const override;
- TF_Tensor* Resolve(Status* status) override;
- TFE_TensorDebugInfo* TensorDebugInfo(Status* status) override;
+ std::unique_ptr<AbstractTensorInterface> Resolve(Status* status) override;
AbstractTensorHandleInterface* Copy() override;
diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h
index 08a55f2..2d31418 100644
--- a/tensorflow/c/tf_tensor_internal.h
+++ b/tensorflow/c/tf_tensor_internal.h
@@ -31,7 +31,7 @@
// passed to or returned from C functions *by pointer*. Otherwise, changes to
// its internal structure will break the C API's binary interface.
typedef struct TF_Tensor {
- std::unique_ptr<AbstractTensorInterface> tensor;
+ std::unique_ptr<tensorflow::AbstractTensorInterface> tensor;
} TF_Tensor;
class TF_ManagedBuffer : public tensorflow::TensorBuffer {
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc
index 7c4d046..81d0528 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.cc
+++ b/tensorflow/core/common_runtime/eager/eager_operation.cc
@@ -58,7 +58,15 @@
executor_ = executor ? executor : &ctx_.Executor();
remote_func_params_ = remote_func_params;
op_name_ = op;
- return SetDeviceName(raw_device_name, true);
+ if (raw_device_name != nullptr && strlen(raw_device_name) > 0) {
+ return SetDeviceName(raw_device_name);
+ } else {
+ raw_device_name_.clear();
+ device_name_.clear();
+ device_parsed_name_.Clear();
+ device_ = kVariantDeviceNull;
+ return Status::OK();
+ }
}
Status EagerOperation::MaybeInferSingleInputAttrs(TensorHandle* handle) {
@@ -128,7 +136,7 @@
return Status::OK();
}
-Status EagerOperation::SetDeviceName(const char* device, const bool reset) {
+Status EagerOperation::SetDeviceName(const char* device) {
if (device != nullptr && strlen(device) > 0) {
if (device != raw_device_name_) {
if (!DeviceNameUtils::ParseFullName(device, &device_parsed_name_)) {
@@ -150,11 +158,6 @@
device_ = kVariantDeviceNull;
}
}
- } else if (reset) {
- raw_device_name_.clear();
- device_name_.clear();
- device_parsed_name_.Clear();
- device_ = kVariantDeviceNull;
}
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h
index 7f15033..1ba55ea 100644
--- a/tensorflow/core/common_runtime/eager/eager_operation.h
+++ b/tensorflow/core/common_runtime/eager/eager_operation.h
@@ -98,7 +98,7 @@
const DeviceNameUtils::ParsedName& GetDeviceParsedName() const {
return device_parsed_name_;
}
- Status SetDeviceName(const char* device, const bool reset = false);
+ Status SetDeviceName(const char* device);
// Indicates whether the op is assigned to a device that is local to the
// current host.
diff --git a/tensorflow/core/framework/tensor_interface.h b/tensorflow/core/framework/tensor_interface.h
index 115e1a0..9388330 100644
--- a/tensorflow/core/framework/tensor_interface.h
+++ b/tensorflow/core/framework/tensor_interface.h
@@ -17,8 +17,11 @@
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_
#include "tensorflow/c/tf_datatype.h"
-#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace tensorflow {
// Abstract interface to a Tensor.
//
@@ -49,8 +52,6 @@
virtual bool CanMove() const = 0;
};
-namespace tensorflow {
-
class TensorInterface : public AbstractTensorInterface {
public:
TensorInterface() {}
@@ -72,12 +73,17 @@
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
- tensorflow::Tensor Tensor() { return tensor_; }
+ tensorflow::Tensor& Tensor() { return tensor_; }
private:
tensorflow::Tensor tensor_;
};
+inline Tensor& TensorFromInterface(
+ const std::unique_ptr<AbstractTensorInterface>& tensor) {
+ return down_cast<TensorInterface*>(tensor.get())->Tensor();
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_INTERFACE_H_