blob: f5085fdb926c9194d0c9194f2f3e15b6503b193b [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/core/common_runtime/eager/attr_builder.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace gradients {
Status GradientRegistry::Register(const string& op_name,
GradientFunctionFactory factory) {
auto iter = registry_.find(op_name);
if (iter != registry_.end()) {
const string error_msg = "Gradient already exists for op: " + op_name + ".";
return errors::AlreadyExists(error_msg);
}
registry_.insert({op_name, factory});
return Status::OK();
}
Status GradientRegistry::Lookup(
const ForwardOperation& op,
std::unique_ptr<GradientFunction>* grad_fn) const {
auto iter = registry_.find(op.op_name);
if (iter == registry_.end()) {
const string error_msg = "No gradient defined for op: " + op.op_name + ".";
return errors::NotFound(error_msg);
}
grad_fn->reset(iter->second(op));
return Status::OK();
}
int64 ToId(AbstractTensorHandle* t) {
return static_cast<int64>(reinterpret_cast<uintptr_t>(t));
}
TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx)
: handle_(handle), ctx_(ctx) {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
}
TapeTensor::TapeTensor(const TapeTensor& other) {
handle_ = other.handle_;
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Ref();
ctx_ = other.ctx_;
}
TapeTensor::~TapeTensor() {
// TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely
// on the client to keep this tensor live for the duration of the gradient
// computation.
// handle_->Unref();
}
tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); }
tensorflow::DataType TapeTensor::GetDType() const {
return handle_->DataType();
}
AbstractTensorHandle* TapeTensor::OnesLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("OnesLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
AbstractTensorHandle* TapeTensor::ZerosLike() const {
AbstractOperationPtr op(ctx_->CreateOperation());
// TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR.
Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
if (isa<tracing::TracingOperation>(op.get())) {
s = dyn_cast<tracing::TracingOperation>(op.get())->SetOpName(
absl::StrCat("OnesLike", ToId(handle_)).c_str());
if (!s.ok()) {
return nullptr;
}
}
s = op->AddInput(handle_);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
// TODO(srbs): Figure out who is in charge of releasing this.
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Returns the number of elements in the gradient tensor.
int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const {
// TODO(srbs): It seems like this is used only for performance optimization
// and not for correctness. The only downside of keeping this 1 seems to be
// that the gradient accumulation is unbounded and we will never
// aggressively aggregate accumulated gradients to recover memory.
// Revisit and fix.
return 1;
}
// Consumes references to the tensors in the gradient_tensors list and returns
// a tensor with the result.
AbstractTensorHandle* TapeVSpace::AggregateGradients(
gtl::ArraySlice<AbstractTensorHandle*> gradient_tensors) const {
if (gradient_tensors.size() == 1) {
return gradient_tensors[0];
}
AbstractOperationPtr op(ctx_->CreateOperation());
Status s = op->Reset("AddN", /*raw_device_name=*/nullptr);
if (!s.ok()) {
return nullptr;
}
s = op->AddInputList(gradient_tensors);
if (!s.ok()) {
return nullptr;
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
s = op->Execute(absl::Span<AbstractTensorHandle*>(outputs), &num_outputs);
if (!s.ok()) {
return nullptr;
}
return outputs[0];
}
// Calls the passed-in backward function.
Status TapeVSpace::CallBackwardFunction(
GradientFunction* backward_function,
const std::vector<int64>& unneeded_gradients,
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
}
// Looks up the ID of a Gradient.
int64 TapeVSpace::TensorId(AbstractTensorHandle* tensor) const {
return ToId(tensor);
}
// Converts a Gradient to a TapeTensor.
TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const {
return TapeTensor(g, ctx_);
}
void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {}
void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const {
gradient->Release();
}
// Helper functions which delegate to `AbstractOperation`, update
// the state of the ForwardOperation and call the tape as appropriate.
// These APIs are mainly to faciliate testing and are subject to change.
namespace internal {
Status Reset(AbstractOperation* op_, const char* op,
const char* raw_device_name, ForwardOperation* forward_op_) {
forward_op_->op_name = op;
return op_->Reset(op, raw_device_name);
}
Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInput(input));
forward_op_->inputs.push_back(input);
return Status::OK();
}
Status AddInputList(AbstractOperation* op_,
absl::Span<AbstractTensorHandle* const> inputs,
ForwardOperation* forward_op_) {
TF_RETURN_IF_ERROR(op_->AddInputList(inputs));
for (auto input : inputs) {
forward_op_->inputs.push_back(input);
}
return Status::OK();
}
Status SetAttrString(AbstractOperation* op_, const char* attr_name,
const char* data, size_t length,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, StringPiece(data, length));
return op_->SetAttrString(attr_name, data, length);
}
Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, static_cast<int64>(value));
return op_->SetAttrInt(attr_name, value);
}
Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrFloat(attr_name, value);
}
Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrBool(attr_name, value);
}
Status SetAttrType(AbstractOperation* op_, const char* attr_name,
DataType value, ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name, value);
return op_->SetAttrType(attr_name, value);
}
Status SetAttrShape(AbstractOperation* op_, const char* attr_name,
const int64_t* dims, const int num_dims,
ForwardOperation* forward_op_) {
if (num_dims > TensorShape::MaxDimensions()) {
return errors::InvalidArgument("Value specified for `", attr_name, "` has ",
num_dims,
" dimensions which is over the limit of ",
TensorShape::MaxDimensions(), ".");
}
TensorShapeProto proto;
if (num_dims < 0) {
proto.set_unknown_rank(true);
} else {
for (int d = 0; d < num_dims; ++d) {
proto.add_dim()->set_size(dims[d]);
}
}
forward_op_->attrs.Set(attr_name, proto);
return op_->SetAttrShape(attr_name, dims, num_dims);
}
Status SetAttrFunction(AbstractOperation* op_, const char* attr_name,
const AbstractOperation* value,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunction has not been implemented yet.");
}
Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name,
const char* value, size_t length,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented "
"yet.");
}
Status SetAttrTensor(AbstractOperation* op_, const char* attr_name,
AbstractTensorInterface* tensor,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrTensor has not been implemented yet.");
}
Status SetAttrStringList(AbstractOperation* op_, const char* attr_name,
const void* const* values, const size_t* lengths,
int num_values, ForwardOperation* forward_op_) {
std::vector<StringPiece> v(num_values);
for (int i = 0; i < num_values; ++i) {
v[i] = StringPiece(static_cast<const char*>(values[i]), lengths[i]);
}
forward_op_->attrs.Set(attr_name, v);
return op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name,
const float* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const float>(values, num_values));
return op_->SetAttrFloatList(attr_name, values, num_values);
}
Status SetAttrIntList(AbstractOperation* op_, const char* attr_name,
const int64_t* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<const int64>(
reinterpret_cast<const int64*>(values), num_values));
return op_->SetAttrIntList(attr_name, values, num_values);
}
Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name,
const DataType* values, int num_values,
ForwardOperation* forward_op_) {
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const DataType>(values, num_values));
return op_->SetAttrTypeList(attr_name, values, num_values);
}
Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name,
const unsigned char* values, int num_values,
ForwardOperation* forward_op_) {
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
forward_op_->attrs.Set(attr_name,
gtl::ArraySlice<const bool>(b.get(), num_values));
return op_->SetAttrBoolList(attr_name, values, num_values);
}
Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name,
const int64_t** dims, const int* num_dims,
int num_values, ForwardOperation* forward_op_) {
std::unique_ptr<TensorShapeProto[]> proto(new TensorShapeProto[num_values]);
for (int i = 0; i < num_values; ++i) {
const auto num_dims_i = num_dims[i];
if (num_dims_i > TensorShape::MaxDimensions()) {
return errors::InvalidArgument(
strings::StrCat("Value specified for `", attr_name, "` has ",
num_dims_i, " dimensions which is over the limit of ",
TensorShape::MaxDimensions(), "."));
}
if (num_dims_i < 0) {
proto[i].set_unknown_rank(true);
} else {
const int64_t* dims_i = dims[i];
auto proto_i = &proto[i];
for (int d = 0; d < num_dims_i; ++d) {
proto_i->add_dim()->set_size(dims_i[d]);
}
}
}
forward_op_->attrs.Set(
attr_name, gtl::ArraySlice<TensorShapeProto>(proto.get(), num_values));
return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name,
absl::Span<const AbstractOperation*> values,
ForwardOperation* forward_op_) {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionList has not been "
"implemented yet.");
}
Status Execute(AbstractOperation* op_, AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> retvals, int* num_retvals,
ForwardOperation* forward_op_, Tape* tape,
const GradientRegistry& registry) {
TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals));
std::vector<int64> input_ids(forward_op_->inputs.size());
std::vector<tensorflow::DataType> input_dtypes(forward_op_->inputs.size());
for (int i = 0; i < forward_op_->inputs.size(); i++) {
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));
}
tape->RecordOperation(
op_->Name(), tape_tensors, input_ids, input_dtypes,
[registry, forward_op_]() -> GradientFunction* {
std::unique_ptr<GradientFunction> grad_fn;
Status s = registry.Lookup(*forward_op_, &grad_fn);
if (!s.ok()) {
return nullptr;
}
return grad_fn.release();
},
[](GradientFunction* ptr) {
if (ptr) {
delete ptr;
}
});
return Status::OK();
}
} // namespace internal
} // namespace gradients
} // namespace tensorflow