Add DiscretizingCombiner to KPL.

PiperOrigin-RevId: 339468010
Change-Id: Ic2a5033eae9ff305ce311e4022c1bc8ab64ad230
diff --git a/RELEASE.md b/RELEASE.md
index ed1c789..1011610 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -28,6 +28,9 @@
 *   <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
 *   <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
 *   <NOTES SHOULD BE GROUPED PER AREA>
+*   `tf.keras`:
+    *   Improvements to Keras preprocessing layers:
+        *   Discretization combiner implemented, with additional arg `epsilon`.
 
 *   `tf.data`:
     *   Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index 3d41c60..d131001 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -47,6 +47,7 @@
     name = "discretization",
     srcs = [
         "discretization.py",
+        "discretization_v1.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -54,6 +55,7 @@
         "//tensorflow/python:boosted_trees_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:resources",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python:tf_export",
@@ -461,6 +463,7 @@
     size = "small",
     srcs = ["discretization_test.py"],
     python_version = "PY3",
+    shard_count = 4,
     tags = ["no_rocm"],
     deps = [
         ":discretization",
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
index 0935d86..7a965bf 100644
--- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
@@ -103,6 +103,16 @@
     ],
 )
 
+tf_py_test(
+    name = "discretization_adapt_benchmark",
+    srcs = ["discretization_adapt_benchmark.py"],
+    python_version = "PY3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python/keras/layers/preprocessing:discretization",
+    ],
+)
+
 cuda_py_test(
     name = "image_preproc_benchmark",
     srcs = ["image_preproc_benchmark.py"],
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
new file mode 100644
index 0000000..d0fb194
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
@@ -0,0 +1,120 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Keras discretization preprocessing layer's adapt method."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import flags
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+EPSILON = 0.1
+
+v2_compat.enable_v2_behavior()
+
+
+def reduce_fn(state, values, epsilon=EPSILON):
+  """tf.data.Dataset-friendly implementation of mean and variance."""
+
+  state_, = state
+  summary = discretization.summarize(values, epsilon)
+  if np.sum(state_[:, 0]) == 0:
+    return (summary,)
+  return (discretization.merge_summaries(state_, summary, epsilon),)
+
+
+class BenchmarkAdapt(benchmark.Benchmark):
+  """Benchmark adapt."""
+
+  def run_dataset_implementation(self, num_elements, batch_size):
+    input_t = keras.Input(shape=(1,))
+    layer = discretization.Discretization()
+    _ = layer(input_t)
+
+    num_repeats = 5
+    starts = []
+    ends = []
+    for _ in range(num_repeats):
+      ds = dataset_ops.Dataset.range(num_elements)
+      ds = ds.map(
+          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+      ds = ds.batch(batch_size)
+
+      starts.append(time.time())
+      # Benchmarked code begins here.
+      state = ds.reduce((np.zeros((1, 2)),), reduce_fn)
+
+      bins = discretization.get_bucket_boundaries(state, 100)
+      layer.set_weights([bins])
+      # Benchmarked code ends here.
+      ends.append(time.time())
+
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    return avg_time
+
+  def bm_adapt_implementation(self, num_elements, batch_size):
+    """Test the KPL adapt implementation."""
+    input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
+    layer = discretization.Discretization()
+    _ = layer(input_t)
+
+    num_repeats = 5
+    starts = []
+    ends = []
+    for _ in range(num_repeats):
+      ds = dataset_ops.Dataset.range(num_elements)
+      ds = ds.map(
+          lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+      ds = ds.batch(batch_size)
+
+      starts.append(time.time())
+      # Benchmarked code begins here.
+      layer.adapt(ds)
+      # Benchmarked code ends here.
+      ends.append(time.time())
+
+    avg_time = np.mean(np.array(ends) - np.array(starts))
+    name = "discretization_adapt|%s_elements|batch_%s" % (num_elements,
+                                                          batch_size)
+    baseline = self.run_dataset_implementation(num_elements, batch_size)
+    extras = {
+        "tf.data implementation baseline": baseline,
+        "delta seconds": (baseline - avg_time),
+        "delta percent": ((baseline - avg_time) / baseline) * 100
+    }
+    self.report_benchmark(
+        iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
+
+  def benchmark_vocab_size_by_batch(self):
+    for vocab_size in [100, 1000, 10000, 100000, 1000000]:
+      for batch in [64 * 2048]:
+        self.bm_adapt_implementation(vocab_size, batch)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization.py b/tensorflow/python/keras/layers/preprocessing/discretization.py
index cdbc8f9..4ffe35c 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization.py
@@ -17,24 +17,125 @@
 from __future__ import division
 from __future__ import print_function
 
+import collections
+import json
+
 import numpy as np
 
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import backend as K
 from tensorflow.python.keras.engine import base_preprocessing_layer
+from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_boosted_trees_ops
 from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.parallel_for import control_flow_ops
 from tensorflow.python.ops.ragged import ragged_functional_ops
+from tensorflow.python.util import compat
 from tensorflow.python.util.tf_export import keras_export
 
 
+_BINS_NAME = "bins"
+
+
+def summarize(values, epsilon):
+  """Reduce a 1D sequence of values to a summary.
+
+  This algorithm is based on numpy.quantiles but modified to allow for
+  intermediate steps between multiple data sets. It first finds the target
+  number of bins as the reciprocal of epsilon and then takes the individual
+  values spaced at appropriate intervals to arrive at that target.
+  The final step is to return the corresponding counts between those values
+  If the target num_bins is larger than the size of values, the whole array is
+  returned (with weights of 1).
+
+  Arguments:
+      values: 1-D `np.ndarray` to be summarized.
+      epsilon: A `'float32'` that determines the approxmiate desired precision.
+
+  Returns:
+      A 2-D `np.ndarray` that is a summary of the inputs. First column is the
+      interpolated partition values, the second is the weights (counts).
+  """
+
+  num_bins = 1.0 / epsilon
+  value_shape = values.shape
+  n = np.prod([[(1 if dim is None else dim) for dim in value_shape]])
+  if num_bins >= n:
+    return np.hstack((np.expand_dims(np.sort(values), 1), np.ones((n, 1))))
+  step_size = int(n / num_bins)
+  partition_indices = np.arange(step_size, n, step_size, np.int64)
+
+  part = np.partition(values, partition_indices)[partition_indices]
+
+  return np.hstack((np.expand_dims(part, 1),
+                    step_size * np.ones((np.prod(part.shape), 1))))
+
+
+def compress(summary, epsilon):
+  """Compress a summary to within `epsilon` accuracy.
+
+  The compression step is needed to keep the summary sizes small after merging,
+  and also used to return the final target boundaries. It finds the new bins
+  based on interpolating cumulative weight percentages from the large summary.
+  Taking the difference of the cumulative weights from the previous bin's
+  cumulative weight will give the new weight for that bin.
+
+  Arguments:
+      summary: 2-D `np.ndarray` summary to be compressed.
+      epsilon: A `'float32'` that determines the approxmiate desired precision.
+
+  Returns:
+      A 2-D `np.ndarray` that is a compressed summary. First column is the
+      interpolated partition values, the second is the weights (counts).
+  """
+  if np.prod(summary[:, 0].shape) * epsilon < 1:
+    return summary
+
+  percents = epsilon + np.arange(0.0, 1.0, epsilon)
+  cum_weights = summary[:, 1].cumsum()
+  cum_weight_percents = cum_weights / cum_weights[-1]
+  new_bins = np.interp(percents, cum_weight_percents, summary[:, 0])
+  cum_weights = np.interp(percents, cum_weight_percents, cum_weights)
+  new_weights = cum_weights - np.concatenate((np.array([0]), cum_weights[:-1]))
+
+  return np.hstack((np.expand_dims(new_bins, 1),
+                    np.expand_dims(new_weights, 1)))
+
+
+def merge_summaries(prev_summary, next_summary, epsilon):
+  """Weighted merge sort of summaries.
+
+  Given two summaries of distinct data, this function merges (and compresses)
+  them to stay within `epsilon` error tolerance.
+
+  Arguments:
+      prev_summary: 2-D `np.ndarray` summary to be merged with `next_summary`.
+      next_summary: 2-D `np.ndarray` summary to be merged with `prev_summary`.
+      epsilon: A `'float32'` that determines the approxmiate desired precision.
+
+  Returns:
+      A 2-D `np.ndarray` that is a merged summary. First column is the
+      interpolated partition values, the second is the weights (counts).
+  """
+  merged = np.concatenate((prev_summary, next_summary))
+  merged = merged[merged[:, 0].argsort()]
+  if np.prod(merged.shape) * epsilon < 1:
+    return merged
+  return compress(merged, epsilon)
+
+
+def get_bucket_boundaries(summary, num_bins):
+  return compress(summary, 1.0 / num_bins)[:-1, 0]
+
+
 @keras_export("keras.layers.experimental.preprocessing.Discretization")
-class Discretization(base_preprocessing_layer.PreprocessingLayer):
+class Discretization(base_preprocessing_layer.CombinerPreprocessingLayer):
   """Buckets data into discrete ranges.
 
   This layer will place each element of its input data into one of several
@@ -48,9 +149,15 @@
     Same as input shape.
 
   Attributes:
-    bins: Optional boundary specification. Bins exclude the left boundary and
-      include the right boundary, so `bins=[0., 1., 2.]` generates bins
+    bins: Optional boundary specification or number of bins to compute if `int`.
+      Bins exclude the left boundary and include the right boundary,
+      so `bins=[0., 1., 2.]` generates bins
       `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`.
+      This would correspond to bins = 4.
+    epsilon: Error tolerance, typically a small fraction close to zero (e.g.
+      0.01). Higher values of epsilon increase the quantile approximation, and
+      hence result in more unequal buckets, but could improve performance
+      and resource consumption.
 
   Examples:
 
@@ -62,19 +169,47 @@
   <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
   array([[0, 1, 3, 1],
          [0, 3, 2, 0]], dtype=int32)>
+
+  Bucketize float values based on a number of buckets to compute.
+  >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
+  >>> layer = tf.keras.layers.experimental.preprocessing.Discretization(
+  ...          bins=4, epsilon=0.01)
+  >>> layer.adapt(input)
+  >>> layer(input)
+  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
+  array([[0, 2, 3, 1],
+         [0, 3, 2, 0]], dtype=int32)>
   """
 
-  def __init__(self, bins, **kwargs):
-    super(Discretization, self).__init__(**kwargs)
-    base_preprocessing_layer._kpl_gauge.get_cell("V2").set("Discretization")
-    # The bucketization op requires a final rightmost boundary in order to
-    # correctly assign values higher than the largest left boundary.
-    # This should not impact intended buckets even if a max value is provided.
-    self.bins = np.append(bins, [np.Inf])
+  def __init__(self,
+               bins,
+               epsilon=0.01,
+               **kwargs):
+    super(Discretization, self).__init__(
+        combiner=Discretization.DiscretizingCombiner(
+            epsilon, bins if isinstance(bins, int) else 1),
+        **kwargs)
+    if bins is not None and not isinstance(bins, int):
+      self.bins = np.append(bins, [np.Inf])
+    else:
+      self.bins = np.zeros(bins)
+    # Need this to return correct config
+    self.input_bins = bins
+    self.epsilon = epsilon
+
+  def build(self, input_shape):
+    self.bins = self._add_state_variable(
+        name=_BINS_NAME,
+        shape=(self.bins.size,),
+        dtype=dtypes.float32,
+        initializer=init_ops.constant_initializer(self.bins))
+    super(Discretization, self).build(input_shape)
 
   def get_config(self):
     config = {
-        "bins": self.bins,
+        "bins": None if self.input_bins is None else (
+            K.get_value(self.input_bins)),
+        "epsilon": self.epsilon,
     }
     base_config = super(Discretization, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
@@ -129,3 +264,82 @@
           control_flow_ops.vectorized_map(
               _bucketize_op(array_ops.squeeze(self.bins)), reshaped),
           array_ops.constant([-1] + input_shape.as_list()[1:]))
+
+  class DiscretizingCombiner(Combiner):
+    """Combiner for the Discretization preprocessing layer.
+
+    This class encapsulates the computations for finding the quantile boundaries
+    of a set of data in a stable and numerically correct way. Its associated
+    accumulator is a namedtuple('summaries'), representing summarizations of
+    the data used to generate boundaries.
+
+    Attributes:
+      epsilon: Error tolerance.
+      num_bins: The desired number of buckets.
+    """
+
+    def __init__(self, epsilon, num_bins,):
+      self.epsilon = epsilon
+      self.num_bins = num_bins
+
+      # TODO(mwunder): Implement elementwise per-column discretization.
+
+    def compute(self, values, accumulator=None):
+      """Compute a step in this computation, returning a new accumulator."""
+
+      if isinstance(values, sparse_tensor.SparseTensor):
+        values = values.values
+      if tf_utils.is_ragged(values):
+        values = values.flat_values
+      flattened_input = np.reshape(values, newshape=(-1, 1))
+
+      summaries = [summarize(v, self.epsilon) for v in flattened_input.T]
+
+      if accumulator is None:
+        return self._create_accumulator(summaries)
+      else:
+        return self._create_accumulator(
+            [merge_summaries(prev_summ, summ, self.epsilon)
+             for prev_summ, summ in zip(accumulator.summaries, summaries)])
+
+    def merge(self, accumulators):
+      """Merge several accumulators to a single accumulator."""
+      # Combine accumulators and return the result.
+
+      merged = accumulators[0].summaries
+      for accumulator in accumulators[1:]:
+        merged = [merge_summaries(prev, summary, self.epsilon)
+                  for prev, summary in zip(merged, accumulator.summaries)]
+
+      return self._create_accumulator(merged)
+
+    def extract(self, accumulator):
+      """Convert an accumulator into a dict of output values."""
+
+      boundaries = [np.append(get_bucket_boundaries(summary, self.num_bins),
+                              [np.Inf])
+                    for summary in accumulator.summaries]
+      return {
+          _BINS_NAME: np.squeeze(np.vstack(boundaries))
+      }
+
+    def restore(self, output):
+      """Create an accumulator based on 'output'."""
+      raise NotImplementedError(
+          "Discretization does not restore or support streaming updates.")
+
+    def serialize(self, accumulator):
+      """Serialize an accumulator for a remote call."""
+      output_dict = {
+          _BINS_NAME: [summary.tolist() for summary in accumulator.summaries]
+      }
+      return compat.as_bytes(json.dumps(output_dict))
+
+    def deserialize(self, encoded_accumulator):
+      """Deserialize an accumulator received from 'serialize()'."""
+      value_dict = json.loads(compat.as_text(encoded_accumulator))
+      return self._create_accumulator(np.array(value_dict[_BINS_NAME]))
+
+    def _create_accumulator(self, summaries):
+      """Represent the accumulator as one or more summaries of the dataset."""
+      return collections.namedtuple("Accumulator", ["summaries"])(summaries)
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_test.py
index 9d04ccc..0226355 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_test.py
@@ -18,19 +18,32 @@
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
+
 import numpy as np
 
 from tensorflow.python import keras
 
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.keras.layers.preprocessing import discretization_v1
 from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.platform import test
 
 
+def get_layer_class():
+  if context.executing_eagerly():
+    return discretization.Discretization
+  else:
+    return discretization_v1.Discretization
+
+
 @keras_parameterized.run_all_keras_modes
 class DiscretizationTest(keras_parameterized.TestCase,
                          preprocessing_test_utils.PreprocessingLayerTest):
@@ -106,7 +119,6 @@
     layer = discretization.Discretization(bins=[-.5, 0.5, 1.5])
     bucket_data = layer(input_data)
     self.assertAllEqual(expected_output_shape, bucket_data.shape.as_list())
-
     model = keras.Model(inputs=input_data, outputs=bucket_data)
     output_dataset = model.predict(input_array)
     self.assertAllEqual(expected_output, output_dataset)
@@ -125,6 +137,110 @@
     self.assertAllEqual(indices, output_dataset.indices)
     self.assertAllEqual(expected_output, output_dataset.values)
 
+  @parameterized.named_parameters([
+      {
+          "testcase_name": "2d_single_element",
+          "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]]),
+          "test_data": np.array([[1.], [2.], [3.]]),
+          "use_dataset": True,
+          "expected": np.array([[0], [1], [2]]),
+          "num_bins": 5,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "2d_multi_element",
+          "adapt_data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.],
+                                  [5., 10.]]),
+          "test_data": np.array([[1., 10.], [2., 6.], [3., 8.]]),
+          "use_dataset": True,
+          "expected": np.array([[0, 4], [0, 2], [1, 3]]),
+          "num_bins": 5,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "1d_single_element",
+          "adapt_data": np.array([3., 2., 1., 5., 4.]),
+          "test_data": np.array([1., 2., 3.]),
+          "use_dataset": True,
+          "expected": np.array([0, 1, 2]),
+          "num_bins": 5,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "300_batch_1d_single_element_1",
+          "adapt_data": np.arange(300),
+          "test_data": np.arange(300),
+          "use_dataset": True,
+          "expected":
+              np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
+          "num_bins": 3,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "300_batch_1d_single_element_2",
+          "adapt_data": np.arange(300) ** 2,
+          "test_data": np.arange(300) ** 2,
+          "use_dataset": True,
+          "expected":
+              np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
+          "num_bins": 3,
+          "epsilon": 0.01
+      }, {
+          "testcase_name": "300_batch_1d_single_element_large_epsilon",
+          "adapt_data": np.arange(300),
+          "test_data": np.arange(300),
+          "use_dataset": True,
+          "expected": np.concatenate([np.zeros(137), np.ones(163)]),
+          "num_bins": 2,
+          "epsilon": 0.1
+      }])
+  def test_layer_computation(self, adapt_data, test_data, use_dataset,
+                             expected, num_bins=5, epsilon=0.01):
+
+    input_shape = tuple(list(test_data.shape)[1:])
+    np.random.shuffle(adapt_data)
+    if use_dataset:
+      # Keras APIs expect batched datasets
+      adapt_data = dataset_ops.Dataset.from_tensor_slices(adapt_data).batch(
+          test_data.shape[0] // 2)
+      test_data = dataset_ops.Dataset.from_tensor_slices(test_data).batch(
+          test_data.shape[0] // 2)
+
+    cls = get_layer_class()
+    layer = cls(epsilon=epsilon, bins=num_bins)
+    layer.adapt(adapt_data)
+
+    input_data = keras.Input(shape=input_shape)
+    output = layer(input_data)
+    model = keras.Model(input_data, output)
+    model._run_eagerly = testing_utils.should_run_eagerly()
+    output_data = model.predict(test_data)
+    self.assertAllClose(expected, output_data)
+
+  @parameterized.named_parameters(
+      {
+          "num_bins": 5,
+          "data": np.array([[1.], [2.], [3.], [4.], [5.]]),
+          "expected": {
+              "bins": np.array([1., 2., 3., 4., np.Inf])
+          },
+          "testcase_name": "2d_single_element_all_bins"
+      }, {
+          "num_bins": 5,
+          "data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.], [5., 10.]]),
+          "expected": {
+              "bins": np.array([2., 4., 6., 8., np.Inf])
+          },
+          "testcase_name": "2d_multi_element_all_bins",
+      }, {
+          "num_bins": 3,
+          "data": np.array([[0.], [1.], [2.], [3.], [4.], [5.]]),
+          "expected": {
+              "bins": np.array([1., 3., np.Inf])
+          },
+          "testcase_name": "2d_single_element_3_bins"
+      })
+  def test_combiner_computation(self, num_bins, data, expected):
+    epsilon = 0.01
+    combiner = discretization.Discretization.DiscretizingCombiner(epsilon,
+                                                                  num_bins)
+    self.validate_accumulator_extract(combiner, data, expected)
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_v1.py b/tensorflow/python/keras/layers/preprocessing/discretization_v1.py
new file mode 100644
index 0000000..6daea9b
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_v1.py
@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tensorflow V1 version of the Discretization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
+from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export(v1=['keras.layers.experimental.preprocessing.Discretization'])
+class Discretization(discretization.Discretization, CombinerPreprocessingLayer):
+  pass
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
new file mode 100644
index 0000000..088d450
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
+tf_class {
+  is_instance: "<type \'type\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "mro"
+  }
+  member_method {
+    name: "register"
+    argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
new file mode 100644
index 0000000..2f75c15
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute"
+    argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "extract"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "merge"
+    argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "restore"
+    argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
index 628f76c..87c0e79 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -9,6 +10,10 @@
   is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
   is_instance: "<type \'object\'>"
   member {
+    name: "DiscretizingCombiner"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
@@ -130,7 +135,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
   }
   member_method {
     name: "adapt"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
new file mode 100644
index 0000000..088d450
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
+tf_class {
+  is_instance: "<type \'type\'>"
+  member_method {
+    name: "__init__"
+  }
+  member_method {
+    name: "mro"
+  }
+  member_method {
+    name: "register"
+    argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
new file mode 100644
index 0000000..2f75c15
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
+  is_instance: "<type \'object\'>"
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "compute"
+    argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "deserialize"
+    argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "extract"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "merge"
+    argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "restore"
+    argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "serialize"
+    argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
index 628f76c..87c0e79 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
@@ -1,6 +1,7 @@
 path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
+  is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
   is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
   is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -9,6 +10,10 @@
   is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
   is_instance: "<type \'object\'>"
   member {
+    name: "DiscretizingCombiner"
+    mtype: "<type \'type\'>"
+  }
+  member {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
@@ -130,7 +135,7 @@
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
   }
   member_method {
     name: "adapt"