blob: 4c2c0664d5c00efaea3c80dc2d3fa6e9e7474846 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include <numeric>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/value_inference.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
// Wraps `context` for the duration of a single XLA op compilation. Value
// inference is bound to the builder of the XlaContext attached to `context`.
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
    : context_(context),
      dynamic_dimension_is_minus_one_(false),
      value_inference_(xla_context()->builder()) {}

// Forwards shape-equality validation of the inputs to the wrapped context.
bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const bool same_shape = context_->ValidateInputsAreSameShape(op);
  return same_shape;
}

// Returns the XlaContext associated with the wrapped OpKernelContext.
XlaContext* XlaOpKernelContext::xla_context() const {
  XlaContext& ctx = XlaContext::Get(context_);
  return &ctx;
}

// Returns the XlaBuilder owned by the XlaContext.
xla::XlaBuilder* XlaOpKernelContext::builder() const {
  XlaContext* const ctx = xla_context();
  return ctx->builder();
}

// Returns the value-inference helper constructed against builder().
xla::ValueInference& XlaOpKernelContext::value_inference() {
  return value_inference_;
}

// Returns the XlaCompiler recorded on the XlaContext (may be null).
XlaCompiler* XlaOpKernelContext::compiler() const {
  XlaContext* const ctx = xla_context();
  return ctx->compiler();
}
// Returns the XlaExpression that is stored (masqueraded) inside input
// tensor `index`.
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
  const Tensor& tensor = context_->input(index);
  return *XlaExpression::CastExpressionFromTensor(tensor);
}

// Same as above, but looks the input up by argument name.
const XlaExpression& XlaOpKernelContext::InputExpression(
    absl::string_view name) {
  const Tensor& tensor = GetInputTensorByName(name);
  return *XlaExpression::CastExpressionFromTensor(tensor);
}

// Returns input `index` as an XlaOp emitted into builder().
xla::XlaOp XlaOpKernelContext::Input(int index) {
  const XlaExpression& expression = InputExpression(index);
  return expression.AsXlaOp(builder());
}

// Returns the input named `name` as an XlaOp emitted into builder().
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
  const XlaExpression& expression = InputExpression(name);
  return expression.AsXlaOp(builder());
}
// Returns the TensorFlow shape of input `index`.
TensorShape XlaOpKernelContext::InputShape(int index) {
  const Tensor& tensor = context_->input(index);
  return tensor.shape();
}

// Returns the TensorFlow shape of the input named `name`.
TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
  const Tensor& tensor = GetInputTensorByName(name);
  return tensor.shape();
}

// Returns the XLA shape the builder reports for input `index`.
StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
  const xla::XlaOp op = Input(index);
  return builder()->GetShape(op);
}

// Returns the XLA shape the builder reports for the input named `name`.
StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(absl::string_view name) {
  const xla::XlaOp op = Input(name);
  return builder()->GetShape(op);
}
// Returns the dtype of input `index`. A DT_UINT8 input may be a masquerade
// for an XlaExpression of a different dtype (see
// XlaOpKernelContext::SetOutputExpression), so in that case the expression's
// dtype is reported instead.
DataType XlaOpKernelContext::input_type(int index) const {
  const DataType declared = context_->input_dtype(index);
  if (declared != DT_UINT8) {
    return declared;
  }
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(context_->input(index));
  return expression->dtype();
}

// Name-based variant of input_type(); same DT_UINT8 masquerade handling.
DataType XlaOpKernelContext::InputType(absl::string_view name) {
  const Tensor& tensor = GetInputTensorByName(name);
  const DataType declared = tensor.dtype();
  if (declared != DT_UINT8) {
    return declared;
  }
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  return expression->dtype();
}
// Converts the dtype of input `index` to an XLA PrimitiveType. On conversion
// failure the error is recorded via SetStatus() and PRIMITIVE_TYPE_INVALID is
// returned.
xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
  xla::PrimitiveType xla_type;
  const Status conversion =
      DataTypeToPrimitiveType(input_type(index), &xla_type);
  if (conversion.ok()) {
    return xla_type;
  }
  SetStatus(conversion);
  return xla::PRIMITIVE_TYPE_INVALID;
}

// Name-based variant of input_xla_type().
xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
  xla::PrimitiveType xla_type;
  const Status conversion = DataTypeToPrimitiveType(InputType(name), &xla_type);
  if (conversion.ok()) {
    return xla_type;
  }
  SetStatus(conversion);
  return xla::PRIMITIVE_TYPE_INVALID;
}
// Resolves input `index` as a compile-time constant literal with the input's
// TensorFlow shape.
//
// Returns InvalidArgument if the input's XLA shape is dynamic, and propagates
// any error from querying that XLA shape or from ConstantInputReshaped().
Status XlaOpKernelContext::ConstantInput(int index,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  // Query the XLA shape once. Propagate a failed query as a Status instead of
  // dereferencing a non-OK StatusOr, which would abort the process.
  TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, InputXlaShape(index));
  if (xla_shape.is_dynamic()) {
    return errors::InvalidArgument(
        "Reading input as constant from a dynamic tensor is not yet supported. "
        "Xla shape: ",
        xla_shape.ToString());
  }
  return ConstantInputReshaped(index,
                               context_->input(index).shape().dim_sizes(),
                               constant_literal, mode);
}
// Maps the argument name `name` to its positional input index. Returns
// InvalidArgument if `name` refers to a list-valued input (one whose range
// spans anything other than exactly one input).
static StatusOr<int> InputIndex(XlaOpKernelContext* context,
                                absl::string_view name) {
  int range_begin;
  int range_end;
  TF_RETURN_IF_ERROR(
      context->op_kernel().InputRange(name, &range_begin, &range_end));
  if (range_end == range_begin + 1) {
    return range_begin;
  }
  return errors::InvalidArgument("OpKernel used list-valued input name '",
                                 name,
                                 "' when single-valued input was expected");
}
// Resolves which elements of input `index` are dynamic, keeping the input's
// TensorFlow shape.
Status XlaOpKernelContext::ResolveInputDynamism(
    int index, xla::Literal* dynamism_literal) {
  const TensorShape& input_shape = context_->input(index).shape();
  return ResolveInputDynamismReshaped(index, input_shape.dim_sizes(),
                                      dynamism_literal);
}

// Name-based variant of ResolveInputDynamism().
Status XlaOpKernelContext::ResolveInputDynamism(
    absl::string_view name, xla::Literal* dynamism_literal) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ResolveInputDynamism(input_index, dynamism_literal);
}

// Name-based variant of ConstantInput().
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInput(input_index, constant_literal, mode);
}
// Resolves input `index` to a compile-time constant, reshapes it to
// `new_dims`, and stores the result in `*constant_literal`.
//
// Errors:
//  - the resolver's own error, annotated with the input index and op type;
//  - InvalidArgument if the input could not be resolved to a constant;
//  - InvalidArgument if the constant cannot be reshaped to `new_dims`.
Status XlaOpKernelContext::ConstantInputReshaped(
    int index, absl::Span<const int64_t> new_dims,
    xla::Literal* constant_literal, xla::ValueInferenceMode mode) {
  XlaExpression expression = InputExpression(index);
  XlaCompiler* const xla_compiler = compiler();
  auto* client = xla_compiler != nullptr ? xla_compiler->client() : nullptr;

  StatusOr<absl::optional<Tensor>> resolved = expression.ResolveConstant(
      client, dynamic_dimension_is_minus_one_, mode);
  if (!resolved.ok()) {
    Status status = resolved.status();
    errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
                            context_->op_kernel().type_string(),
                            " operator as a compile-time constant.");
    return status;
  }

  const absl::optional<Tensor> constant = resolved.ValueOrDie();
  if (!constant.has_value()) {
    return errors::InvalidArgument(
        "Input ", index, " to node `", context_->op_kernel().name(),
        "` with op ", context_->op_kernel().type_string(),
        " must be a compile-time constant.\n\n"
        "XLA compilation requires that operator arguments that represent "
        "shapes or dimensions be evaluated to concrete values at compile time. "
        "This error means that a shape or dimension argument could not be "
        "evaluated at compile time, usually because the value of the argument "
        "depends on a parameter to the computation, on a variable, or on a "
        "stateful operation such as a random number generator.");
  }

  // CopyFrom shares the buffer under a new shape; it fails when the element
  // counts differ.
  const TensorShape target_shape(new_dims);
  Tensor reshaped(constant->dtype());
  if (!reshaped.CopyFrom(*constant, target_shape)) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        constant->shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        target_shape.DebugString());
  }
  TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(reshaped));
  return Status::OK();
}
// Converts a rank-0 S32 or S64 literal to an int64_t. Returns InvalidArgument
// for non-scalar literals or any other element type.
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
                                   int64_t* out) {
  const xla::Shape& shape = literal.shape();
  if (shape.rank() != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  switch (shape.element_type()) {
    case xla::S32:
      *out = literal.Get<int32>({});
      return Status::OK();
    case xla::S64:
      *out = literal.Get<int64_t>({});
      return Status::OK();
    default:
      return errors::InvalidArgument("value must be either int32 or int64");
  }
}

// Converts a rank-0 F32 or F64 literal to a double. Returns InvalidArgument
// for non-scalar literals or any other element type.
static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
                                     double* out) {
  const xla::Shape& shape = literal.shape();
  if (shape.rank() != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  switch (shape.element_type()) {
    case xla::F32:
      *out = literal.Get<float>({});
      return Status::OK();
    case xla::F64:
      *out = literal.Get<double>({});
      return Status::OK();
    default:
      return errors::InvalidArgument("value must be either float32 or float64");
  }
}
// Resolves input `index` as a constant int32/int64 scalar, widened to int64_t.
Status XlaOpKernelContext::ConstantInputAsIntScalar(
    int index, int64_t* out, xla::ValueInferenceMode mode) {
  xla::Literal scalar;
  TF_RETURN_IF_ERROR(ConstantInput(index, &scalar, mode));
  return LiteralToInt64Scalar(scalar, out);
}

// Name-based variant of ConstantInputAsIntScalar().
Status XlaOpKernelContext::ConstantInputAsIntScalar(
    absl::string_view name, int64_t* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInputAsIntScalar(input_index, out, mode);
}

// Resolves input `index` as a constant float32/float64 scalar, widened to
// double.
Status XlaOpKernelContext::ConstantInputAsFloatScalar(
    int index, double* out, xla::ValueInferenceMode mode) {
  xla::Literal scalar;
  TF_RETURN_IF_ERROR(ConstantInput(index, &scalar, mode));
  return LiteralToFloat64Scalar(scalar, out);
}
// Appends the elements of a rank-1 PRED literal to `*out`. Returns
// InvalidArgument if the literal is not rank 1 or not of PRED type.
static Status LiteralToPredVector(const xla::LiteralSlice& literal,
                                  std::vector<bool>* out) {
  const xla::Shape& shape = literal.shape();
  if (shape.rank() != 1) {
    return errors::InvalidArgument("value is not 1D, rank: ", shape.rank());
  }
  if (shape.element_type() != xla::PRED) {
    return errors::InvalidArgument("value is not PRED");
  }
  const int64_t num_elements = xla::ShapeUtil::ElementsIn(shape);
  for (int64_t i = 0; i < num_elements; ++i) {
    out->push_back(literal.Get<bool>({i}));
  }
  return Status::OK();
}
// Reports in `*out` whether the scalar input `index` is dynamic. If dynamism
// cannot be resolved, the value is conservatively treated as dynamic.
// Returns InvalidArgument if the input is not a scalar.
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
  XlaExpression expression = InputExpression(index);
  auto* client = compiler() ? compiler()->client() : nullptr;
  StatusOr<Tensor> dynamism_or = expression.ResolveDynamism(client);
  if (!dynamism_or.ok()) {
    // Resolution can fail when the input depends on ops (e.g. custom-call)
    // that dynamism analysis does not support; assume dynamic.
    //
    // TODO(b/176993339): Support resolving dynamism across computations so
    // resolving dynamism will not fail in those cases.
    *out = true;
    return Status::OK();
  }

  Tensor dynamism = dynamism_or.ValueOrDie();
  TensorShape scalar_shape({});
  Tensor scalar(dynamism.dtype());
  if (!scalar.CopyFrom(dynamism, scalar_shape)) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(), " which is not a R0 ", scalar_shape);
  }
  TF_ASSIGN_OR_RETURN(xla::Literal literal, HostTensorToLiteral(scalar));
  *out = literal.Get<bool>({});
  return Status::OK();
}
// Name-based variant of ResolveInputDynamismIntoPredVector().
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    absl::string_view name, std::vector<bool>* out) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ResolveInputDynamismIntoPredVector(input_index, out);
}

// Name-based variant of ResolveInputDynamismIntoPred().
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
                                                        bool* out) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ResolveInputDynamismIntoPred(input_index, out);
}
// Resolves which elements of input `index` are dynamic and returns the answer
// as a PRED literal reshaped to `new_dims`. If dynamism cannot be resolved,
// every element is conservatively reported as dynamic. Returns
// InvalidArgument if the dynamism tensor cannot be reshaped to `new_dims`.
Status XlaOpKernelContext::ResolveInputDynamismReshaped(
    int index, absl::Span<const int64_t> new_dims,
    xla::Literal* dynamism_literal) {
  XlaExpression expression = InputExpression(index);
  auto* client = compiler() ? compiler()->client() : nullptr;
  StatusOr<Tensor> dynamism_or = expression.ResolveDynamism(client);
  if (!dynamism_or.ok()) {
    // Resolution can fail when the input depends on ops (e.g. custom-call)
    // that dynamism analysis does not support; mark everything dynamic.
    const xla::Shape all_dynamic_shape =
        xla::ShapeUtil::MakeShape(xla::PRED, new_dims);
    *dynamism_literal = xla::LiteralUtil::CreateR0<bool>(true)
                            .Broadcast(all_dynamic_shape, {})
                            .ValueOrDie();
    return Status::OK();
  }

  Tensor dynamism = dynamism_or.ValueOrDie();
  const TensorShape target_shape(new_dims);
  Tensor reshaped(dynamism.dtype());
  if (!reshaped.CopyFrom(dynamism, target_shape)) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        target_shape.DebugString());
  }
  TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(reshaped));
  return Status::OK();
}
// Flattens input `index` to rank 1 and appends one dynamism flag per element
// to `*out`.
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    int index, std::vector<bool>* out) {
  const int64_t num_elements = InputShape(index).num_elements();
  xla::Literal flattened;
  TF_RETURN_IF_ERROR(
      ResolveInputDynamismReshaped(index, {num_elements}, &flattened));
  return LiteralToPredVector(flattened, out);
}
// Appends the elements of a rank-1 S32 or S64 literal to `*out`, widened to
// int64_t. Returns InvalidArgument if the literal is not rank 1 or has any
// other element type.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                   std::vector<int64_t>* out) {
  const xla::Shape& shape = literal.shape();
  if (shape.rank() != 1) {
    return errors::InvalidArgument("value is not 1D, rank: ", shape.rank());
  }
  const int64_t num_elements = xla::ShapeUtil::ElementsIn(shape);
  switch (shape.element_type()) {
    case xla::S32:
      for (int64_t i = 0; i < num_elements; ++i) {
        out->push_back(literal.Get<int32>({i}));
      }
      break;
    case xla::S64:
      for (int64_t i = 0; i < num_elements; ++i) {
        out->push_back(literal.Get<int64_t>({i}));
      }
      break;
    default:
      return errors::InvalidArgument("value must be either int32 or int64");
  }
  return Status::OK();
}
// Resolves the rank-1 input `index` as a constant int32/int64 vector and
// appends its values (as int64_t) to `*out`.
Status XlaOpKernelContext::ConstantInputAsIntVector(
    int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
  xla::Literal vector_literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &vector_literal, mode));
  return LiteralToInt64Vector(vector_literal, out);
}

// Name-based variant of ConstantInputAsIntVector().
Status XlaOpKernelContext::ConstantInputAsIntVector(
    absl::string_view name, std::vector<int64_t>* out,
    xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInputAsIntVector(input_index, out, mode);
}
// Resolves input `index` as a constant int32/int64 tensor of any rank,
// flattens it to rank 1, and appends its values (as int64_t) to `*out`.
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInputReshaped(
      index, {InputShape(index).num_elements()}, &literal, mode));
  return LiteralToInt64Vector(literal, out);
}

// Name-based variant of ConstantInputReshapedToIntVector(). Delegates to the
// index overload, like the other name-based overloads, instead of duplicating
// its body.
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    absl::string_view name, std::vector<int64_t>* out,
    xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputReshapedToIntVector(index, out, mode);
}
// Resolves input `index` as a constant literal and returns it with S64
// elements: S32 constants are widened element-wise, S64 constants are moved
// through unchanged. Any other element type yields InvalidArgument.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(
    int index, xla::Literal* out, xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  switch (literal.shape().element_type()) {
    case xla::S32: {
      *out = xla::Literal(
          xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64));
      // Fetch both spans once; data<T>() performs a shape-index lookup per
      // call, so it should not be called inside the element loop.
      const auto src_data = literal.data<int32>();
      auto dst_data = out->data<int64_t>();
      for (size_t i = 0; i < src_data.size(); ++i) {
        dst_data[i] = src_data[i];
      }
      return Status::OK();
    }
    case xla::S64:
      *out = std::move(literal);
      return Status::OK();
    default:
      return errors::InvalidArgument(
          "Invalid argument to ConstantInputAsInt64Literal: ",
          xla::ShapeUtil::HumanString(literal.shape()));
  }
}

// Name-based variant of ConstantInputAsInt64Literal().
Status XlaOpKernelContext::ConstantInputAsInt64Literal(
    absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name));
  return ConstantInputAsInt64Literal(index, out, mode);
}
// Resolves the rank-1 int32/int64 input `index` as a constant and interprets
// it as the dimension sizes of a fully-defined TensorShape.
//
// Invalid dimensions (e.g. negative sizes, or a total element count that
// overflows) are reported as an error Status rather than being passed
// straight to the TensorShape constructor.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
                                                xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  std::vector<int64_t> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  return TensorShapeUtils::MakeShape(dims, shape);
}
// Resolves input `index` as a constant and interprets it as a
// PartialTensorShape:
//  - a scalar -1 means "unknown rank";
//  - a rank-1 int32/int64 vector lists the dimension sizes, where -1 marks an
//    unknown dimension.
// Any other scalar value, or an invalid dimension such as -2, yields an error
// Status.
Status XlaOpKernelContext::ConstantInputAsPartialShape(
    int index, PartialTensorShape* shape) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
  // If `literal` is a scalar its value must be -1.
  if (literal.shape().rank() == 0) {
    int64_t shape_val;
    TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val));
    if (shape_val != -1) {
      return errors::InvalidArgument(
          "Cannot convert value to PartialTensorShape: ", shape_val);
    }
    *shape = PartialTensorShape();  // Shape with unknown rank.
    return Status::OK();
  }
  std::vector<int64_t> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  // MakeShape validates the dimensions and reports bad ones as a Status.
  return TensorShapeUtils::MakeShape(dims, shape);
}
// Collects the list-valued input `name` as XlaOps (emitted into builder())
// together with each element's TensorFlow shape. Both output vectors are
// cleared first.
Status XlaOpKernelContext::InputList(absl::string_view name,
                                     std::vector<xla::XlaOp>* handles,
                                     std::vector<TensorShape>* shapes) {
  OpInputList inputs;
  TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
  handles->clear();
  shapes->clear();
  xla::XlaBuilder* const xla_builder = builder();
  for (const Tensor& tensor : inputs) {
    const XlaExpression* expression =
        XlaExpression::CastExpressionFromTensor(tensor);
    handles->push_back(expression->AsXlaOp(xla_builder));
    shapes->push_back(tensor.shape());
  }
  return Status::OK();
}
// Resolves every element of the list-valued input `name` as a compile-time
// constant. On success `*outputs` holds exactly one literal per list element,
// in list order.
Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
                                             std::vector<xla::Literal>* outputs,
                                             xla::ValueInferenceMode mode) {
  int start, stop;
  TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
  outputs->resize(stop - start);
  for (int i = start; i < stop; ++i) {
    // `i` is an absolute input index, but `outputs` only holds this list's
    // `stop - start` entries, so it must be rebased to zero.
    TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i - start], mode));
  }
  return Status::OK();
}
namespace {

// Reads the variable resource wrapped by `tensor` and produces its value as
// an XlaOp in `*value` (and, if `shape` is non-null, its TensorFlow shape).
//
// If the variable has not been overwritten and its expression carries a
// constant value, that constant is emitted instead. Otherwise the variable's
// current value is returned, reshaped back to the variable's shape when its
// representation shape differs.
//
// Returns FailedPrecondition for an uninitialized variable and
// InvalidArgument when its dtype differs from `type`.
Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                               const XlaOpKernelContext* ctx,
                               TensorShape* shape, xla::XlaOp* value) {
  const XlaExpression* const expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* const variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);

  if (!variable->initialized()) {
    return errors::FailedPrecondition(
        "Read variable failure ", variable->name(),
        ". It could mean the variable is uninitialized or the variable is on "
        "another device ");
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Type mismatch for read of variable ", variable->name(), ". Expected ",
        DataTypeString(type), "; got ", DataTypeString(variable->type()));
  }
  if (shape != nullptr) {
    *shape = variable->shape();
  }

  // Fast path: an unmodified variable with a known constant value.
  if (!variable->IsOverwritten() && expression->constant_value()) {
    TF_ASSIGN_OR_RETURN(xla::Literal constant_literal,
                        HostTensorToLiteral(*expression->constant_value()));
    *value = xla::ConstantLiteral(ctx->builder(), constant_literal);
    return Status::OK();
  }

  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      ctx->compiler()->options().shape_representation_fn(
                          variable->shape(), variable->type(),
                          /*use_fast_memory=*/false));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(variable->type(), variable->shape(), &xla_shape));
  *value = xla::ShapeUtil::Compatible(xla_shape, representation_shape)
               ? variable->value()
               : xla::Reshape(variable->value(), variable->shape().dim_sizes());
  return Status::OK();
}

}  // namespace
// Reads the variable passed as input `index`; see ReadVariableInputTensor.
Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
                                             TensorShape* shape,
                                             xla::XlaOp* value) {
  const Tensor& variable_tensor = context_->input(index);
  return ReadVariableInputTensor(variable_tensor, type, this, shape, value);
}

// Reads the variable passed as input `name`; see ReadVariableInputTensor.
Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
                                             DataType type, TensorShape* shape,
                                             xla::XlaOp* value) {
  const Tensor& variable_tensor = GetInputTensorByName(name);
  return ReadVariableInputTensor(variable_tensor, type, this, shape, value);
}
// Reports the dtype and shape of the variable passed as input `index`.
// Returns InvalidArgument if the variable has not been initialized.
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                   TensorShape* shape) const {
  XlaResource* variable =
      XlaExpression::CastExpressionFromTensor(context_->input(index))
          ->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (variable->initialized()) {
    *type = variable->type();
    *shape = variable->shape();
    return Status::OK();
  }
  return errors::InvalidArgument(
      "Read variable failure ", variable->name(),
      ". It could mean the variable is uninitialized or the variable is on "
      "another device ");
}
void XlaOpKernelContext::SetOutputExpression(int index,
const XlaExpression& expression) {
Status status = [&] {
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
// Provides a special behavior for DT_VARIANT and other types that are not
// trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
if (!DataTypeCanUseMemcpy(expression.dtype())) {
// tensor_data() is not supported for tensors that cannot be copied via
// memcpy, as the copy logic might try to inspect the stored data (e.g.
// a std::string). This is likely to fail, as the data is invalid given
// that it actually encodes an XlaExpression. Using a uint8 tensor is
// always safe, so simply do that.
// TODO(jpienaar): This should be refactored to stop masquerading
// XlaExpressions as Tensors.
Tensor output;
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(
context_->allocate_temp(DT_UINT8, tensor_shape, &output));
context_->set_output(index, output);
} else {
Tensor* output = nullptr;
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
}
XlaExpression::AssignExpressionToTensor(expression,
context_->mutable_output(index));
return Status::OK();
}();
if (!status.ok()) {
SetStatus(status);
}
}
// Converts the expected dtype of output `index` to an XLA PrimitiveType. On
// failure the error is recorded via SetStatus() and PRIMITIVE_TYPE_INVALID is
// returned.
xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
  xla::PrimitiveType xla_type;
  const Status conversion =
      DataTypeToPrimitiveType(expected_output_dtype(index), &xla_type);
  if (conversion.ok()) {
    return xla_type;
  }
  SetStatus(conversion);
  return xla::PRIMITIVE_TYPE_INVALID;
}
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
SetOutputExpression(
index,
XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
}
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
SetOutputExpression(index, XlaExpression::Constant(constant));
}
void XlaOpKernelContext::SetTensorListOutput(int index,
const xla::XlaOp& handle) {
SetOutputExpression(index, XlaExpression::TensorList(handle));
}
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
SetOutputExpression(index, XlaExpression::Resource(resource));
}
// Returns in `*resource` the XlaResource carried by input `index`; the input
// must hold a resource expression.
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  TF_RET_CHECK(expression->resource() != nullptr);
  *resource = expression->resource();
  return Status::OK();
}
namespace {

// Stores `handle` as the new value of the variable resource wrapped by
// `tensor`. The variable's dtype and shape are updated to `type` and the
// handle's shape. If the compiler's shape_representation_fn picks an
// incompatible representation, the value is reshaped to that representation
// before being stored.
Status AssignVariableTensor(const Tensor& tensor, DataType type,
                            const XlaOpKernelContext* ctx, xla::XlaOp handle,
                            xla::XlaBuilder* builder) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);

  TF_ASSIGN_OR_RETURN(xla::Shape handle_shape, builder->GetShape(handle));
  TensorShape shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(handle_shape, &shape));
  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));

  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      ctx->compiler()->options().shape_representation_fn(
                          shape, type,
                          /*use_fast_memory=*/false));
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));

  xla::XlaOp stored_value = handle;
  if (!xla::ShapeUtil::Compatible(xla_shape, representation_shape)) {
    stored_value = xla::Reshape(
        handle, xla::AsInt64Slice(representation_shape.dimensions()));
  }
  variable->SetRepresentationShape(representation_shape);
  return variable->SetValue(stored_value);
}

}  // namespace
// Assigns `handle` to the variable passed as input `input_index`.
Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  const Tensor& variable_tensor = context_->input(input_index);
  return AssignVariableTensor(variable_tensor, type, this, handle, builder());
}

// Assigns `handle` to the variable passed as input `name`.
Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  const Tensor& variable_tensor = GetInputTensorByName(name);
  return AssignVariableTensor(variable_tensor, type, this, handle, builder());
}
// Appends the op's definition stack trace to INVALID_ARGUMENT statuses; every
// other status is returned unchanged.
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  std::string message = absl::StrCat(s.error_message(), "\n", ctx->StackTrace());
  return Status{s.code(), message};
}
void XlaOpKernelContext::CtxFailure(const Status& s) {
context_->CtxFailure(GetStatusWithStackTrace(s, this));
}
void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
}
void XlaOpKernelContext::CtxFailure(const char* file, int line,
const Status& s) {
context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
}
void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
}
// Reduction-computation accessors; each forwards to the XlaContext, which
// owns the returned computation.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
    const DataType type) {
  XlaContext* const ctx = xla_context();
  return ctx->GetOrCreateMax(type);
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
    const DataType type) {
  XlaContext* const ctx = xla_context();
  return ctx->GetOrCreateMin(type);
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
    const DataType type) {
  XlaContext* const ctx = xla_context();
  return ctx->GetOrCreateAdd(type);
}

const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
    const DataType type) {
  XlaContext* const ctx = xla_context();
  return ctx->GetOrCreateMul(type);
}
// Returns the input tensor bound to argument `name`. Aborts the process if the
// lookup fails. TF_CHECK_OK is used so the lookup's error message appears in
// the crash report.
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
  const Tensor* tensor = nullptr;
  TF_CHECK_OK(context_->input(name, &tensor));
  return *tensor;
}
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}

// Adapts the generic OpKernel::Compute entry point: wraps the TensorFlow
// context in an XlaOpKernelContext and delegates to Compile().
void XlaOpKernel::Compute(OpKernelContext* context) {
  XlaOpKernelContext xla_op_context(context);
  Compile(&xla_op_context);
}
std::string XlaOpKernelContext::StackTrace() const {
if (const AbstractStackTrace* stack_trace =
xla_context()->StackTraceForNodeName(op_kernel().name())) {
AbstractStackTrace::TracePrintingOptions opts;
opts.show_line_contents = true;
opts.filter_common_prefix = true;
opts.drop_internal_frames = true;
return absl::StrCat("\nStack trace for op definition: \n",
stack_trace->ToString(opts), "\n");
} else {
return "";
}
}
} // namespace tensorflow