blob: ad8e5a029c8372453b1dfc8e67b11d6b298b53aa [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
namespace tensorflow {
namespace {
// TODO: This is only a dummy kernel
class DenseBincountOp : public XlaOpKernel {
private:
bool binary_output_;
public:
explicit DenseBincountOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_output", &binary_output_));
}
void Compile(XlaOpKernelContext* ctx) override {
// Dumb implementation for the simplest test case
xla::XlaOp input = ctx->Input(0);
xla::XlaOp weights = ctx->Input(2);
StatusOr<xla::Shape> weights_shape_or = ctx->builder()->GetShape(weights);
OP_REQUIRES_OK(ctx, weights_shape_or.status());
auto weights_shape = weights_shape_or.ValueOrDie();
auto weights_size = weights_shape.dimensions(0);
auto input_xla_type = ctx->input_xla_type(0);
xla::PrimitiveType dtype;
bool has_weight;
if (weights_size){
has_weight = true;
dtype = ctx->input_xla_type(2);
}
else {
has_weight = false;
dtype = input_xla_type;
}
int64_t output_size;
ctx->ConstantInputAsIntScalar("size", &output_size);
StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
OP_REQUIRES_OK(ctx, input_shape_or.status());
auto input_shape = input_shape_or.ValueOrDie();
auto size = input_shape.dimensions(0);
auto dim = 1;
auto rank = input_shape.rank();
auto counter_shape = xla::ShapeUtil::MakeShape(xla::S32, {});
const xla::Shape data_shape = xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()});
xla::Shape output_shape = xla::ShapeUtil::MakeShape(dtype, {output_size});
if (rank == 2) {
output_shape = xla::ShapeUtil::MakeShape(dtype, {rank, output_size});
dim = input_shape.dimensions(1);
}
auto loop_shape = xla::ShapeUtil::MakeTupleShape(
{counter_shape, data_shape, output_shape, weights_shape});
// Create a computation for the condition
xla::XlaComputation condition;
{
std::unique_ptr<xla::XlaBuilder> builder =
ctx->builder()->CreateSubBuilder("condition");
auto param = xla::Parameter(builder.get(), 0, loop_shape, "param");
auto counter = xla::GetTupleElement(param, 0);
xla::Gt(xla::ConstantR0<int32_t>(builder.get(), size*dim), counter);
condition = builder->Build().ConsumeValueOrDie();
}
// Create a computation for the body
xla::XlaComputation body;
{
std::unique_ptr<xla::XlaBuilder> builder =
ctx->builder()->CreateSubBuilder("body");
auto param = Parameter(builder.get(), 0, loop_shape, "param");
auto counter = xla::GetTupleElement(param, 0);
auto data_stack = xla::GetTupleElement(param, 1);
auto accum_stack = xla::GetTupleElement(param, 2);
auto weights = xla::GetTupleElement(param, 3);
auto accum_shape = xla::ShapeUtil::MakeShape(dtype, {});
if (rank == 1) {
auto data = xla::DynamicSlice(data_stack, {counter}, {1});
auto accum = xla::DynamicSlice(accum_stack, {data}, {1});
accum = xla::Reshape(accum, {0}, {});
accum = xla::ConvertElementType(accum, dtype);
auto data_scalar = xla::Reshape(data, {0}, {});
auto condition_shape = xla::ShapeUtil::MakeTupleShape(
{counter_shape, counter_shape, accum_shape, output_shape, weights_shape});
xla::XlaComputation update;
{
std::unique_ptr<xla::XlaBuilder> true_builder =
ctx->builder()->CreateSubBuilder("update");
auto param = Parameter(true_builder.get(), 0, condition_shape, "param");
auto data_scalar = xla::GetTupleElement(param, 0);
auto counter = xla::GetTupleElement(param, 1);
auto accum = xla::GetTupleElement(param, 2);
auto accum_stack = xla::GetTupleElement(param, 3);
auto weights = xla::GetTupleElement(param, 4);
if (binary_output_){
accum = xla::One(true_builder.get(), dtype);
}
else if (! has_weight) {
accum = accum + xla::One(true_builder.get(), dtype);
}
else {
auto weight = xla::DynamicSlice(weights, {counter}, {1});
weight = xla::Reshape(weight, {0}, {});
accum = accum + weight;
}
accum_stack = xla::DynamicUpdateSlice(
accum_stack, xla::Reshape(accum, {1}), {data_scalar});
xla::Tuple(true_builder.get(), {accum, accum_stack});
update = true_builder->Build().ValueOrDie();
}
xla::XlaComputation no_update;
{
std::unique_ptr<xla::XlaBuilder> false_builder =
ctx->builder()->CreateSubBuilder("no_update");
auto param = Parameter(false_builder.get(), 0, condition_shape, "param");
auto data = xla::GetTupleElement(param, 0);
auto count = xla::GetTupleElement(param, 1);
auto accum = xla::GetTupleElement(param, 2);
auto accum_stack = xla::GetTupleElement(param, 3);
xla::Tuple(false_builder.get(), {accum, accum_stack});
no_update = false_builder->Build().ValueOrDie();
}
auto output_size_xla = xla::ConstantR0<int32_t>(builder.get(), output_size);
auto pred = xla::Lt(data_scalar, output_size_xla);
auto tuple = xla::Tuple(builder.get(), {data_scalar, counter, accum, accum_stack, weights});
auto cond = xla::Conditional(pred, tuple, update, tuple, no_update);
accum = xla::GetTupleElement(cond, 0);
accum_stack = xla::GetTupleElement(cond, 1);
}
else {
auto condition_shape = xla::ShapeUtil::MakeTupleShape(
{counter_shape, counter_shape, counter_shape, output_shape,
accum_shape, weights_shape});
auto dim_xla = xla::ConstantR0<int32_t>(builder.get(), dim);
auto idx_1 = xla::Div(counter, dim_xla);
auto idx_2 = counter % dim_xla;
auto data = xla::DynamicSlice(data_stack, {idx_1, idx_2}, {1, 1});
auto data_scalar = xla::Reshape(data, {0,1}, {});
auto accum = xla::DynamicSlice(accum_stack, {idx_1, data_scalar}, {1, 1});
accum = xla::Reshape(accum, {0,1}, {});
accum = xla::ConvertElementType(accum, dtype);
xla::XlaComputation update;
{
std::unique_ptr<xla::XlaBuilder> true_builder =
builder->CreateSubBuilder("update_rank2");
auto param = Parameter(true_builder.get(), 0, condition_shape, "param");
auto data_scalar = xla::GetTupleElement(param, 0);
auto idx_1 = xla::GetTupleElement(param, 1);
auto idx_2 = xla::GetTupleElement(param, 2);
auto accum_stack = xla::GetTupleElement(param, 3);
auto accum = xla::GetTupleElement(param, 4);
auto weights = xla::GetTupleElement(param, 5);
if (binary_output_){
accum = xla::One(true_builder.get(), dtype);
}
else if (! has_weight) {
accum = accum + xla::One(true_builder.get(), dtype);
}
else {
auto weight = xla::DynamicSlice(weights, {idx_1, idx_2}, {1, 1});
auto weigth_scalar = xla::Reshape(weight, {0,1}, {});
accum = accum + weigth_scalar;
}
accum_stack = xla::DynamicUpdateSlice(
accum_stack, xla::Reshape(accum, {1, 1}), {idx_1, data_scalar});
xla::Tuple(true_builder.get(), {accum, accum_stack});
update = true_builder->Build().ValueOrDie();
}
xla::XlaComputation no_update;
{
std::unique_ptr<xla::XlaBuilder> false_builder =
builder->CreateSubBuilder("no_update_rank2");
auto param = Parameter(false_builder.get(), 0, condition_shape, "param");
auto accum_stack = xla::GetTupleElement(param, 3);
auto accum = xla::GetTupleElement(param, 4);
xla::Tuple(false_builder.get(), {accum, accum_stack});
no_update = false_builder->Build().ValueOrDie();
}
auto limit = xla::ConstantR0<int32_t>(builder.get(), output_size);
auto pred = xla::Lt(data_scalar, limit);
auto tuple = xla::Tuple(builder.get(), {data_scalar, idx_1, idx_2, accum_stack, accum, weights});
auto cond = xla::Conditional(pred, tuple, update, tuple, no_update);
accum = xla::GetTupleElement(cond, 0);
accum_stack = xla::GetTupleElement(cond, 1);
}
counter = counter + xla::One(builder.get(), xla::S32);
xla::Tuple(builder.get(), {counter, data_stack, accum_stack, weights});
body = builder->Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
auto zero = xla::Zero(ctx->builder(), xla::S32);
auto zero_out = xla::Zero(ctx->builder(), dtype);
auto zero_broadcast = xla::Broadcast(zero_out, {output_shape.dimensions()});
auto init = xla::Tuple(ctx->builder(), {zero, input, zero_broadcast, weights});
auto result = xla::While(condition, body, init);
auto output = xla::GetTupleElement(result,2);
ctx->SetOutput(0, output);
}
};
// Registers DenseBincountOp as the XLA kernel for the "DenseBincount" op.
// The "size" input is declared a compile-time constant, matching Compile(),
// which reads it through ConstantInputAsIntScalar("size", ...).
REGISTER_XLA_OP(Name("DenseBincount").CompileTimeConstantInput("size"), DenseBincountOp);
} // namespace
} // namespace tensorflow