Handle weights and binary_output
diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
index a945e84..ad8e5a0 100644
--- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/xla/client/lib/comparators.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
@@ -27,14 +28,32 @@
 
 // TODO: This is only a dummy kernel
 class DenseBincountOp : public XlaOpKernel {
+  private:
+  bool binary_output_;
  public:
   explicit DenseBincountOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    DataType dtype;
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("binary_output", &binary_output_));
   }
   
   void Compile(XlaOpKernelContext* ctx) override {
   // Dumb implementation for the simplest test case
     xla::XlaOp input = ctx->Input(0);
+    xla::XlaOp weights = ctx->Input(2);    
+    StatusOr<xla::Shape> weights_shape_or = ctx->builder()->GetShape(weights);
+    OP_REQUIRES_OK(ctx, weights_shape_or.status());
+    auto weights_shape = weights_shape_or.ValueOrDie();
+    auto weights_size = weights_shape.dimensions(0);
+    auto input_xla_type = ctx->input_xla_type(0);
+    xla::PrimitiveType dtype;
+    bool has_weight;
+    if (weights_size){
+      has_weight = true;
+      dtype = ctx->input_xla_type(2);
+    }
+    else {
+      has_weight = false;
+      dtype = input_xla_type;
+    }
     int64_t output_size;
     ctx->ConstantInputAsIntScalar("size", &output_size);
     StatusOr<xla::Shape> input_shape_or = ctx->builder()->GetShape(input);
@@ -44,16 +63,16 @@
     auto dim = 1;
     auto rank = input_shape.rank();
     auto counter_shape = xla::ShapeUtil::MakeShape(xla::S32, {});
-    const xla::Shape data_shape = xla::ShapeUtil::MakeShape(xla::S32, {input_shape.dimensions()});
+    const xla::Shape data_shape = xla::ShapeUtil::MakeShape(input_xla_type, {input_shape.dimensions()});
 
-    xla::Shape output_shape = xla::ShapeUtil::MakeShape(xla::S32, {output_size});
+    xla::Shape output_shape = xla::ShapeUtil::MakeShape(dtype, {output_size});
     if (rank == 2) {
-      output_shape = xla::ShapeUtil::MakeShape(xla::S32, {rank, output_size});
+      output_shape = xla::ShapeUtil::MakeShape(dtype, {rank, output_size});
       dim = input_shape.dimensions(1);
     }
 
     auto loop_shape = xla::ShapeUtil::MakeTupleShape(
-        {counter_shape, data_shape, output_shape});
+        {counter_shape, data_shape, output_shape, weights_shape});
 
     // Create a computation for the condition
     xla::XlaComputation condition;
@@ -75,15 +94,18 @@
       auto counter = xla::GetTupleElement(param, 0);
       auto data_stack = xla::GetTupleElement(param, 1);
       auto accum_stack = xla::GetTupleElement(param, 2);
-      
+      auto weights = xla::GetTupleElement(param, 3);
+      auto accum_shape = xla::ShapeUtil::MakeShape(dtype, {});
+
       if (rank == 1) {
         auto data = xla::DynamicSlice(data_stack, {counter}, {1});
         auto accum = xla::DynamicSlice(accum_stack, {data}, {1});
         accum = xla::Reshape(accum, {0}, {});
+        accum = xla::ConvertElementType(accum, dtype);
         auto data_scalar = xla::Reshape(data, {0}, {});
 
         auto condition_shape = xla::ShapeUtil::MakeTupleShape(
-          {counter_shape, counter_shape, output_shape});
+          {counter_shape, counter_shape, accum_shape, output_shape, weights_shape});
 
         xla::XlaComputation update;
         { 
@@ -91,9 +113,21 @@
               ctx->builder()->CreateSubBuilder("update");
           auto param = Parameter(true_builder.get(), 0, condition_shape, "param");
           auto data_scalar = xla::GetTupleElement(param, 0);
-          auto accum = xla::GetTupleElement(param, 1);
-          auto accum_stack  = xla::GetTupleElement(param, 2);
-          accum = accum + xla::One(true_builder.get(), xla::S32);
+          auto counter = xla::GetTupleElement(param, 1);
+          auto accum = xla::GetTupleElement(param, 2);
+          auto accum_stack  = xla::GetTupleElement(param, 3);
+          auto weights = xla::GetTupleElement(param, 4);
+          if (binary_output_){
+            accum = xla::One(true_builder.get(), dtype);
+          }
+          else if (! has_weight) {
+            accum = accum + xla::One(true_builder.get(), dtype);
+          }
+          else {
+            auto weight = xla::DynamicSlice(weights, {counter}, {1});
+            weight = xla::Reshape(weight, {0}, {});
+            accum = accum + weight;
+          }
           accum_stack = xla::DynamicUpdateSlice(
               accum_stack, xla::Reshape(accum, {1}), {data_scalar});
           xla::Tuple(true_builder.get(), {accum, accum_stack});
@@ -106,22 +140,25 @@
               ctx->builder()->CreateSubBuilder("no_update");
           auto param = Parameter(false_builder.get(), 0, condition_shape, "param");
           auto data = xla::GetTupleElement(param, 0);
-          auto accum = xla::GetTupleElement(param, 1);
-          auto accum_stack  = xla::GetTupleElement(param, 2);
+          auto count = xla::GetTupleElement(param, 1);
+          auto accum = xla::GetTupleElement(param, 2);
+          auto accum_stack  = xla::GetTupleElement(param, 3);
           xla::Tuple(false_builder.get(), {accum, accum_stack});
           no_update = false_builder->Build().ValueOrDie();
         }
 
         auto output_size_xla = xla::ConstantR0<int32_t>(builder.get(), output_size);
         auto pred = xla::Lt(data_scalar, output_size_xla);
-        auto tuple = xla::Tuple(builder.get(), {data_scalar, accum, accum_stack});
+        auto tuple = xla::Tuple(builder.get(), {data_scalar, counter, accum, accum_stack, weights});
         auto cond = xla::Conditional(pred, tuple, update, tuple, no_update);
         accum = xla::GetTupleElement(cond, 0);
         accum_stack = xla::GetTupleElement(cond, 1);
+
       }
       else {
         auto condition_shape = xla::ShapeUtil::MakeTupleShape(
-          {counter_shape, counter_shape, output_shape, counter_shape});
+          {counter_shape, counter_shape, counter_shape, output_shape, 
+           accum_shape, weights_shape});
 
         auto dim_xla = xla::ConstantR0<int32_t>(builder.get(), dim);
         auto idx_1 = xla::Div(counter, dim_xla);
@@ -130,7 +167,7 @@
         auto data_scalar = xla::Reshape(data, {0,1}, {});
         auto accum = xla::DynamicSlice(accum_stack, {idx_1, data_scalar}, {1, 1});
         accum = xla::Reshape(accum, {0,1}, {});
-
+        accum = xla::ConvertElementType(accum, dtype);
         xla::XlaComputation update;
         {
           std::unique_ptr<xla::XlaBuilder> true_builder =
@@ -139,9 +176,21 @@
           
           auto data_scalar = xla::GetTupleElement(param, 0);
           auto idx_1 = xla::GetTupleElement(param, 1);
-          auto accum_stack  = xla::GetTupleElement(param, 2);
-          auto accum = xla::GetTupleElement(param, 3);
-          accum = accum + xla::One(true_builder.get(), xla::S32);
+          auto idx_2 = xla::GetTupleElement(param, 2);
+          auto accum_stack  = xla::GetTupleElement(param, 3);
+          auto accum = xla::GetTupleElement(param, 4);
+          auto weights = xla::GetTupleElement(param, 5);
+          if (binary_output_){
+            accum = xla::One(true_builder.get(), dtype);
+          }
+          else if (! has_weight) {
+            accum = accum + xla::One(true_builder.get(), dtype);
+          }
+          else {
+            auto weight = xla::DynamicSlice(weights, {idx_1, idx_2}, {1, 1});
+            auto weigth_scalar = xla::Reshape(weight, {0,1}, {});
+            accum = accum + weigth_scalar;
+          }
           accum_stack = xla::DynamicUpdateSlice(
               accum_stack, xla::Reshape(accum, {1, 1}), {idx_1, data_scalar});
           xla::Tuple(true_builder.get(), {accum, accum_stack});
@@ -153,30 +202,29 @@
           std::unique_ptr<xla::XlaBuilder> false_builder =
               builder->CreateSubBuilder("no_update_rank2");
           auto param = Parameter(false_builder.get(), 0, condition_shape, "param");
-          auto data_scalar = xla::GetTupleElement(param, 0);
-          auto idx_1 = xla::GetTupleElement(param, 1);
-          auto accum_stack  = xla::GetTupleElement(param, 2);
-          auto accum = xla::GetTupleElement(param, 3);
+          auto accum_stack  = xla::GetTupleElement(param, 3);
+          auto accum = xla::GetTupleElement(param, 4);
           xla::Tuple(false_builder.get(), {accum, accum_stack});
           no_update = false_builder->Build().ValueOrDie();
         }
         auto limit = xla::ConstantR0<int32_t>(builder.get(), output_size);
         
         auto pred = xla::Lt(data_scalar, limit);
-        auto tuple = xla::Tuple(builder.get(), {data_scalar, idx_1, accum_stack, accum});
+        auto tuple = xla::Tuple(builder.get(), {data_scalar, idx_1, idx_2, accum_stack, accum, weights});
         auto cond = xla::Conditional(pred, tuple, update, tuple, no_update);
         accum = xla::GetTupleElement(cond, 0);
         accum_stack = xla::GetTupleElement(cond, 1);
       }
       counter = counter + xla::One(builder.get(), xla::S32);
-      xla::Tuple(builder.get(), {counter, data_stack, accum_stack});
+      xla::Tuple(builder.get(), {counter, data_stack, accum_stack, weights});
       body = builder->Build().ConsumeValueOrDie();
     }
 
     // Create a While node with computations for the condition and the body.
     auto zero = xla::Zero(ctx->builder(), xla::S32);
-    auto zero_broadcast = xla::Broadcast(zero, {output_shape.dimensions()});
-    auto init = xla::Tuple(ctx->builder(), {zero, input, zero_broadcast});
+    auto zero_out = xla::Zero(ctx->builder(), dtype);
+    auto zero_broadcast = xla::Broadcast(zero_out, {output_shape.dimensions()});
+    auto init = xla::Tuple(ctx->builder(), {zero, input, zero_broadcast, weights});
     auto result = xla::While(condition, body, init);
     auto output = xla::GetTupleElement(result,2);
     ctx->SetOutput(0, output);
diff --git a/tensorflow/python/ops/bincount_ops_test.py b/tensorflow/python/ops/bincount_ops_test.py
index 8c482db..92ee1ee 100644
--- a/tensorflow/python/ops/bincount_ops_test.py
+++ b/tensorflow/python/ops/bincount_ops_test.py
@@ -165,10 +165,6 @@
 
   @parameterized.named_parameters(
     {
-        "testcase_name": "_baseline",
-        "x": np.array([1, 1, 2, 3, 2, 4, 4, 5], dtype=np.int32),
-        "expected_values": [0, 2, 2, 1, 2, 1]
-    }, {
         "testcase_name": "_no_maxlength",
         "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
         "expected_values": [[0, 1, 1, 1, 0, 0],[0, 0, 0, 0, 2, 1]]
@@ -189,7 +185,65 @@
         "minlength": 3,
         "expected_values": [[0, 1, 1, 1, 0, 0, 0, 1],
                             [1, 0, 0, 0, 2, 0, 0, 1]]
-    })
+    }, {
+        "testcase_name": "_no_maxlength_binary",
+        "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
+        "expected_values": [[0, 1, 1, 1, 0, 0],
+                            [0, 0, 0, 0, 1, 1]],
+        "binary_output": True,
+    }, {
+        "testcase_name": "_maxlength_binary",
+        "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
+        "maxlength": 7,
+        "expected_values": [[0, 1, 1, 1, 0, 0, 0],
+                            [1, 0, 0, 0, 1, 0, 0]],
+        "binary_output": True,
+    }, {
+        "testcase_name": "_minlength_binary",
+        "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
+        "minlength": 9,
+        "expected_values": [[0, 1, 1, 1, 0, 0, 0, 1, 0],
+                            [1, 0, 0, 0, 1, 0, 0, 1, 0]],
+        "binary_output": True,
+    }, {
+        "testcase_name": "_minlength_larger_values_binary",
+        "x": np.array([[3, 2, 1, 7], [7, 0, 4, 4]], dtype=np.int32),
+        "minlength": 3,
+        "expected_values": [[0, 1, 1, 1, 0, 0, 0, 1],
+                            [1, 0, 0, 0, 1, 0, 0, 1]],
+        "binary_output": True,
+    }, {
+        "testcase_name": "_no_maxlength_weights",
+        "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
+        "expected_values": [[0. , 2. , 1. , 0.5, 0. , 0. ],
+                            [0. , 0. , 0. , 0. , 9. , 3. ]],
+        "weights": [[0.5, 1, 2], [3, 4, 5]]
+    }, {
+        "testcase_name": "_1d",
+        "x": np.array([3, 2, 1, 1], dtype=np.int32),
+        "expected_values": [0, 2, 1, 1]
+    }, {
+        "testcase_name": "_1d_binary",
+        "x": np.array([3, 2, 1, 1], dtype=np.int32),
+        "expected_values": [0, 1, 1, 1],
+        "binary_output": True
+    }, {
+        "testcase_name": "_1d_no_maxlenght_weights",
+        "x": np.array([3, 2, 1, 5, 4, 4], dtype=np.int32),
+        "weights": [0.5, 1, 2, 3, 4, 5],
+        "expected_values": [0. , 2. , 1. , 0.5, 9. , 3. ]
+    }, #{
+       # This is going to fail
+       # INVALID_ARGUMENT: Detected unsupported operations when trying to compile graph...  
+       # Bincount (No registered 'Bincount' OpKernel for XLA_CPU_JIT devices compatible 
+       # with node {{node bincount/Bincount}}){{node bincount/Bincount}}` 
+       #
+       # "testcase_name": "_all_axes",
+       # "x": np.array([[3, 2, 1], [5, 4, 4]], dtype=np.int32),
+       # "expected_values": [0, 4, 4, 5],
+       # "axis": None
+    #}
+    )
   def test_compiled_dense(self,
                        x,
                        expected_values,