Handle weights and binary_output
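
Read the binary_output attribute and the optional weights input in the XLA
DenseBincount kernel. The weights tensor is threaded through the while-loop
state; each update either writes 1 (binary_output), adds 1 (no weights), or
adds the matching weight, and the output dtype follows the weights dtype when
weights are present. Adds binary, weighted, and 1-D cases to
bincount_ops_test.py.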
diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
index a945e84..ad8e5a0 100644
--- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
@@ -17,6 +17,7 @@
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@@ -27,14 +28,32 @@
// TODO: This is only a dummy kernel
class DenseBincountOp : public XlaOpKernel {
+ private:
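+  // If true, the output marks presence (0/1) instead of accumulating counts.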
+ bool binary_output_;
public:
explicit DenseBincountOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
- DataType dtype;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_output", &binary_output_));
}
void Compile(XlaOpKernelContext* ctx) override {
// Dumb implementation for the simplest test case
xla::XlaOp input = ctx->Input(0);
+ xla::XlaOp weights = ctx->Input(2);
+ StatusOr<xla::Shape> weights_shape_or = ctx->builder()->GetShape(weights);
+ OP_REQUIRES_OK(ctx, weights_shape_or.status());
+ auto weights_shape = weights_shape_or.ValueOrDie();
+ auto weights_size = weights_shape.dimensions(0);
+ auto input_xla_type = ctx->input_xla_type(0);
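+ // An empty weights input (size 0) means no weights were supplied. With
+ // weights, the output uses the weights dtype; otherwise the input dtype.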
+ xla::PrimitiveType dtype;
+ bool has_weight;
+ if (weights_size) {
+   has_weight = true;
+   dtype = ctx->input_xla_type(2);
+ } else {
+   has_weight = false;
+   dtype = input_xla_type;
+ }
int64_t output_size;
ctx->ConstantInputAsIntScalar("size", &output_size);
StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
@@ -44,16 +63,16 @@
auto dim = 1;
auto rank = input_shape.rank();
auto counter_shape = xla::ShapeUtil::MakeShape(xla::S32, {});
- const xla::Shape data_shape = xla::ShapeUtil::MakeShape(xla::S32, {input_shape.dimensions()});
+ const xla::Shape data_shape = xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()});
- xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::S32, {output_size});
+ xla::Shape output_shape = xla::ShapeUtil::MakeShape(dtype, {output_size});
if (rank == 2) {
- output_shape = xla::ShapeUtil::MakeShape(xla::S32, {rank, output_size});
+ output_shape = xla::ShapeUtil::MakeShape(dtype, {rank, output_size});
dim = input_shape.dimensions(1);
}
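+ // Loop state: (iteration counter, input data, output accumulator, weights).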
auto loop_shape = xla::ShapeUtil::MakeTupleShape(
- {counter_shape, data_shape, output_shape});
+ {counter_shape, data_shape, output_shape, weights_shape});
// Create a computation for the condition
xla::XlaComputation condition;
@@ -75,15 +94,18 @@
auto counter = xla::GetTupleElement(param, 0);
auto data_stack = xla::GetTupleElement(param, 1);
auto accum_stack = xla::GetTupleElement(param, 2);
-
+ auto weights = xla::GetTupleElement(param, 3);
+ auto accum_shape = xla::ShapeUtil::MakeShape(dtype, {});
+
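+ // Rank 1: each iteration reads input[counter] and updates output at that value.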
if (rank == 1) {
auto data = xla::DynamicSlice(data_stack, {counter}, {1});
auto accum = xla::DynamicSlice(accum_stack, {data}, {1});
accum = xla::Reshape(accum, {0}, {});
+ accum = xla::ConvertElementType(accum, dtype);
auto data_scalar = xla::Reshape(data, {0}, {});
auto condition_shape = xla::ShapeUtil::MakeTupleShape(
- {counter_shape, counter_shape, output_shape});
+ {counter_shape, counter_shape, accum_shape, output_shape, weights_shape});
xla::XlaComputation update;
{
@@ -91,9 +113,21 @@
ctx->builder()->CreateSubBuilder("update");
auto param = Parameter(true_builder.get(), 0, condition_shape, "param");
auto data_scalar = xla::GetTupleElement(param, 0);
- auto accum = xla::GetTupleElement(param, 1);
- auto accum_stack = xla::GetTupleElement(param, 2);
- accum = accum + xla::One(true_builder.get(), xla::S32);
+ auto counter = xla::GetTupleElement(param, 1);
+ auto accum = xla::GetTupleElement(param, 2);
+ auto accum_stack = xla::GetTupleElement(param, 3);
+ auto weights = xla::GetTupleElement(param, 4);
+ if (binary_output_) {
+   // Binary mode: record presence, not the count.
+   accum = xla::One(true_builder.get(), dtype);
+ } else if (!has_weight) {
+   accum = accum + xla::One(true_builder.get(), dtype);
+ } else {
+   // Weighted mode: add the weight at the current input position.
+   auto weight = xla::DynamicSlice(weights, {counter}, {1});
+   weight = xla::Reshape(weight, {0}, {});
+   accum = accum + weight;
+ }
accum_stack = xla::DynamicUpdateSlice(
accum_stack, xla::Reshape(accum, {1}), {data_scalar});
xla::Tuple(true_builder.get(), {accum, accum_stack});
@@ -106,22 +140,25 @@
ctx->builder()->CreateSubBuilder("no_update");
auto param = Parameter(false_builder.get(), 0, condition_shape, "param");
auto data = xla::GetTupleElement(param, 0);
- auto accum = xla::GetTupleElement(param, 1);
- auto accum_stack = xla::GetTupleElement(param, 2);
+ auto accum = xla::GetTupleElement(param, 2);
+ auto accum_stack = xla::GetTupleElement(param, 3);
xla::Tuple(false_builder.get(), {accum, accum_stack});
no_update = false_builder->Build().ValueOrDie();
}
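+ // Only in-range values (data < output_size) are counted; anything else
+ // takes the no_update branch and leaves the accumulator unchanged.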
auto output_size_xla = xla::ConstantR0<int32_t>(builder.get(), output_size);
auto pred = xla::Lt(data_scalar, output_size_xla);
- auto tuple = xla::Tuple(builder.get(), {data_scalar, accum, accum_stack});
+ auto tuple = xla::Tuple(builder.get(), {data_scalar, counter, accum, accum_stack, weights});
auto cond = xla::Conditional(pred, tuple, update, tuple, no_update);
accum = xla::GetTupleElement(cond, 0);
accum_stack = xla::GetTupleElement(cond, 1);
}
else {
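+ // Rank 2: counter scans the flattened [rows, dim] input; idx_1 selects the
+ // row and idx_2 the column, and each row gets its own bincount in the output.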
auto condition_shape = xla::ShapeUtil::MakeTupleShape(
- {counter_shape, counter_shape, output_shape, counter_shape});
+ {counter_shape, counter_shape, counter_shape, output_shape,
+ accum_shape, weights_shape});
auto dim_xla = xla::ConstantR0<int32_t>(builder.get(), dim);
auto idx_1 = xla::Div(counter, dim_xla);
@@ -130,7 +167,7 @@
auto data_scalar = xla::Reshape(data, {0,1}, {});
auto accum = xla::DynamicSlice(accum_stack, {idx_1, data_scalar}, {1, 1});
accum = xla::Reshape(accum, {0,1}, {});
-
+ accum = xla::ConvertElementType(accum, dtype);
xla::XlaComputation update;
{
std::unique_ptr<xla::XlaBuilder> true_builder =
@@ -139,9 +176,21 @@
auto data_scalar = xla::GetTupleElement(param, 0);
auto idx_1 = xla::GetTupleElement(param, 1);
- auto accum_stack = xla::GetTupleElement(param, 2);
- auto accum = xla::GetTupleElement(param, 3);
- accum = accum + xla::One(true_builder.get(), xla::S32);
+ auto idx_2 = xla::GetTupleElement(param, 2);
+ auto accum_stack = xla::GetTupleElement(param, 3);
+ auto accum = xla::GetTupleElement(param, 4);
+ auto weights = xla::GetTupleElement(param, 5);
+ if (binary_output_) {
+   // Binary mode: record presence, not the count.
+   accum = xla::One(true_builder.get(), dtype);
+ } else if (!has_weight) {
+   accum = accum + xla::One(true_builder.get(), dtype);
+ } else {
+   // Weighted mode: add the weight at the same [row, column] position.
+   auto weight = xla::DynamicSlice(weights, {idx_1, idx_2}, {1, 1});
+   auto weight_scalar = xla::Reshape(weight, {0, 1}, {});
+   accum = accum + weight_scalar;
+ }
accum_stack = xla::DynamicUpdateSlice(
accum_stack, xla::Reshape(accum, {1, 1}), {idx_1, data_scalar});
xla::Tuple(true_builder.get(), {accum, accum_stack});
@@ -153,30 +202,29 @@
std::unique_ptr<xla::XlaBuilder> false_builder =
builder->CreateSubBuilder("no_update_rank2");
auto param = Parameter(false_builder.get(), 0, condition_shape, "param");
- auto data_scalar = xla::GetTupleElement(param, 0);
- auto idx_1 = xla::GetTupleElement(param, 1);
- auto accum_stack = xla::GetTupleElement(param, 2);
- auto accum = xla::GetTupleElement(param, 3);
+ auto accum_stack = xla::GetTupleElement(param, 3);
+ auto accum = xla::GetTupleElement(param, 4);
xla::Tuple(false_builder.get(), {accum, accum_stack});
no_update = false_builder->Build().ValueOrDie();
}
auto limit = xla::ConstantR0<int32_t>(builder.get(), output_size);
auto pred = xla::Lt(data_scalar, limit);
- auto tuple = xla::Tuple(builder.get(), {data_scalar, idx_1, accum_stack, accum});
+ auto tuple = xla::Tuple(builder.get(), {data_scalar, idx_1, idx_2, accum_stack, accum, weights});
auto cond = xla::Conditional(pred, tuple, update, tuple, no_update);
accum = xla::GetTupleElement(cond, 0);
accum_stack = xla::GetTupleElement(cond, 1);
}
counter = counter + xla::One(builder.get(), xla::S32);
- xla::Tuple(builder.get(), {counter, data_stack, accum_stack});
+ xla::Tuple(builder.get(), {counter, data_stack, accum_stack, weights});
body = builder->Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
auto zero = xla::Zero(ctx->builder(), xla::S32);
- auto zero_broadcast = xla::Broadcast(zero, {output_shape.dimensions()});
- auto init = xla::Tuple(ctx->builder(), {zero, input, zero_broadcast});
+ auto zero_out = xla::Zero(ctx->builder(), dtype);
+ auto zero_broadcast = xla::Broadcast(zero_out, {output_shape.dimensions()});
+ auto init = xla::Tuple(ctx->builder(), {zero, input, zero_broadcast, weights});
auto result = xla::While(condition, body, init);
auto output = xla::GetTupleElement(result,2);
ctx->SetOutput(0, output);
diff --git a/tensorflow/python/ops/bincount_ops_test.py b/tensorflow/python/ops/bincount_ops_test.py
index 8c482db..92ee1ee 100644
--- a/tensorflow/python/ops/bincount_ops_test.py
+++ b/tensorflow/python/ops/bincount_ops_test.py
@@ -165,10 +165,6 @@
@parameterized.named_parameters(
{
- "testcase_name": "_baseline",
- "x": np.array([1, 1, 2, 3, 2, 4, 4, 5], dtype=np.int32),
- "expected_values": [0, 2, 2, 1, 2, 1]
- }, {
"testcase_name": "_no_maxlength",
"x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
"expected_values": [[0, 1, 1, 1, 0, 0],[0, 0, 0, 0, 2, 1]]
@@ -189,7 +185,65 @@
"minlength": 3,
"expected_values": [[0, 1, 1, 1, 0, 0, 0, 1],
[1, 0, 0, 0, 2, 0, 0, 1]]
- })
+ }, {
+ "testcase_name": "_no_maxlength_binary",
+ "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
+ "expected_values": [[0, 1, 1, 1, 0, 0],
+ [0, 0, 0, 0, 1, 1]],
+ "binary_output": True,
+ }, {
+ "testcase_name": "_maxlength_binary",
+ "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
+ "maxlength": 7,
+ "expected_values": [[0, 1, 1, 1, 0, 0, 0],
+ [1, 0, 0, 0, 1, 0, 0]],
+ "binary_output": True,
+ }, {
+ "testcase_name": "_minlength_binary",
+ "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
+ "minlength": 9,
+ "expected_values": [[0, 1, 1, 1, 0, 0, 0, 1, 0],
+ [1, 0, 0, 0, 1, 0, 0, 1, 0]],
+ "binary_output": True,
+ }, {
+ "testcase_name": "_minlength_larger_values_binary",
+ "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
+ "minlength": 3,
+ "expected_values": [[0, 1, 1, 1, 0, 0, 0, 1],
+ [1, 0, 0, 0, 1, 0, 0, 1]],
+ "binary_output": True,
+ }, {
+ "testcase_name": "_no_maxlength_weights",
+ "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
+ "expected_values": [[0., 2., 1., 0.5, 0., 0.],
+                     [0., 0., 0., 0., 9., 3.]],
+ "weights": [[0.5, 1, 2], [3, 4, 5]]
+ }, {
+ "testcase_name": "_1d",
+ "x": np.array([3, 2, 1, 1], dtype=np.int32),
+ "expected_values": [0, 2, 1, 1]
+ }, {
+ "testcase_name": "_1d_binary",
+ "x": np.array([3, 2, 1, 1], dtype=np.int32),
+ "expected_values": [0, 1, 1, 1],
+ "binary_output": True
+ }, {
+ "testcase_name": "_1d_no_maxlength_weights",
+ "x": np.array([3, 2, 1, 5, 4, 4], dtype=np.int32),
+ "weights": [0.5, 1, 2, 3, 4, 5],
+ "expected_values": [0., 2., 1., 0.5, 9., 3.]
+ }, #{
+ # This case does not yet compile: the XLA bridge reports
+ # INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph...
+ # Bincount (No registered 'Bincount' OpKernel for XLA_CPU_JIT devices compatible
+ # with node {{node bincount/Bincount}})
+ #
+ # "testcase_name": "_all_axes",
+ # "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
+ # "expected_values": [0, 4, 4, 5],
+ # "axis": None
+ #}
+ )
def test_compiled_dense(self,
x,
expected_values,