Add DiscretizingCombiner to KPL.
PiperOrigin-RevId: 339468010
Change-Id: Ic2a5033eae9ff305ce311e4022c1bc8ab64ad230
diff --git a/RELEASE.md b/RELEASE.md
index ed1c789..1011610 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -28,6 +28,9 @@
* <SIMILAR TO ABOVE SECTION, BUT FOR OTHER IMPORTANT CHANGES / BUG FIXES>
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
+* `tf.keras`:
+ * Improvements to Keras preprocessing layers:
+ * Discretization combiner implemented, with additional arg `epsilon` controlling the accuracy of the computed bin boundaries.
* `tf.data`:
* Exposing `tf.data.experimental.ExternalStatePolicy`, which can be used
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index 3d41c60..d131001 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -47,6 +47,7 @@
name = "discretization",
srcs = [
"discretization.py",
+ "discretization_v1.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -54,6 +55,7 @@
"//tensorflow/python:boosted_trees_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:resources",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tf_export",
@@ -461,6 +463,7 @@
size = "small",
srcs = ["discretization_test.py"],
python_version = "PY3",
+ shard_count = 4,
tags = ["no_rocm"],
deps = [
":discretization",
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
index 0935d86..7a965bf 100644
--- a/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/BUILD
@@ -103,6 +103,16 @@
],
)
+tf_py_test(
+ name = "discretization_adapt_benchmark",
+ srcs = ["discretization_adapt_benchmark.py"],
+ python_version = "PY3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python/keras/layers/preprocessing:discretization",
+ ],
+)
+
cuda_py_test(
name = "image_preproc_benchmark",
srcs = ["image_preproc_benchmark.py"],
diff --git a/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py b/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
new file mode 100644
index 0000000..d0fb194
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/benchmarks/discretization_adapt_benchmark.py
@@ -0,0 +1,120 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for Keras discretization preprocessing layer's adapt method."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl import flags
+import numpy as np
+
+from tensorflow.python import keras
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import test
+
+FLAGS = flags.FLAGS
+EPSILON = 0.1
+
+v2_compat.enable_v2_behavior()
+
+
+def reduce_fn(state, values, epsilon=EPSILON):
+ """tf.data.Dataset-friendly implementation of mean and variance."""
+
+ state_, = state
+ summary = discretization.summarize(values, epsilon)
+ if np.sum(state_[:, 0]) == 0:
+ return (summary,)
+ return (discretization.merge_summaries(state_, summary, epsilon),)
+
+
+class BenchmarkAdapt(benchmark.Benchmark):
+ """Benchmark adapt."""
+
+ def run_dataset_implementation(self, num_elements, batch_size):
+ input_t = keras.Input(shape=(1,))
+ layer = discretization.Discretization()
+ _ = layer(input_t)
+
+ num_repeats = 5
+ starts = []
+ ends = []
+ for _ in range(num_repeats):
+ ds = dataset_ops.Dataset.range(num_elements)
+ ds = ds.map(
+ lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+ ds = ds.batch(batch_size)
+
+ starts.append(time.time())
+ # Benchmarked code begins here.
+ state = ds.reduce((np.zeros((1, 2)),), reduce_fn)
+
+ bins = discretization.get_bucket_boundaries(state, 100)
+ layer.set_weights([bins])
+ # Benchmarked code ends here.
+ ends.append(time.time())
+
+ avg_time = np.mean(np.array(ends) - np.array(starts))
+ return avg_time
+
+ def bm_adapt_implementation(self, num_elements, batch_size):
+ """Test the KPL adapt implementation."""
+ input_t = keras.Input(shape=(1,), dtype=dtypes.float32)
+ layer = discretization.Discretization()
+ _ = layer(input_t)
+
+ num_repeats = 5
+ starts = []
+ ends = []
+ for _ in range(num_repeats):
+ ds = dataset_ops.Dataset.range(num_elements)
+ ds = ds.map(
+ lambda x: array_ops.expand_dims(math_ops.cast(x, dtypes.float32), -1))
+ ds = ds.batch(batch_size)
+
+ starts.append(time.time())
+ # Benchmarked code begins here.
+ layer.adapt(ds)
+ # Benchmarked code ends here.
+ ends.append(time.time())
+
+ avg_time = np.mean(np.array(ends) - np.array(starts))
+ name = "discretization_adapt|%s_elements|batch_%s" % (num_elements,
+ batch_size)
+ baseline = self.run_dataset_implementation(num_elements, batch_size)
+ extras = {
+ "tf.data implementation baseline": baseline,
+ "delta seconds": (baseline - avg_time),
+ "delta percent": ((baseline - avg_time) / baseline) * 100
+ }
+ self.report_benchmark(
+ iters=num_repeats, wall_time=avg_time, extras=extras, name=name)
+
+  def benchmark_num_elements_by_batch(self):
+    for num_elements in [100, 1000, 10000, 100000, 1000000]:
+      for batch in [64 * 2048]:
+        self.bm_adapt_implementation(num_elements, batch)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization.py b/tensorflow/python/keras/layers/preprocessing/discretization.py
index cdbc8f9..4ffe35c 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization.py
@@ -17,24 +17,125 @@
from __future__ import division
from __future__ import print_function
+import collections
+import json
+
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer
+from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.parallel_for import control_flow_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
+from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export
+_BINS_NAME = "bins"
+
+
+def summarize(values, epsilon):
+ """Reduce a 1D sequence of values to a summary.
+
+  This algorithm is based on numpy.quantile but modified to allow for
+  intermediate steps between multiple data sets. It first finds the target
+  number of bins as the reciprocal of epsilon and then takes the individual
+  values spaced at appropriate intervals to arrive at that target. The final
+  step is to return the corresponding counts between those values. If the
+  target num_bins is larger than the size of values, the whole array is
+  returned (with weights of 1).
+
+ Arguments:
+ values: 1-D `np.ndarray` to be summarized.
+    epsilon: A `'float32'` that determines the desired approximation precision.
+
+ Returns:
+ A 2-D `np.ndarray` that is a summary of the inputs. First column is the
+ interpolated partition values, the second is the weights (counts).
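+
+  For example, a minimal illustration (values chosen arbitrarily): with
+  `epsilon=0.5` the target is two bins, so the value at every second sorted
+  position becomes a partition point carrying a count of 2:
+
+  >>> summarize(np.array([4., 1., 3., 2.]), epsilon=0.5)
+  array([[3., 2.]])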
+ """
+
+ num_bins = 1.0 / epsilon
+ value_shape = values.shape
+  n = np.prod([(1 if dim is None else dim) for dim in value_shape])
+ if num_bins >= n:
+ return np.hstack((np.expand_dims(np.sort(values), 1), np.ones((n, 1))))
+ step_size = int(n / num_bins)
+ partition_indices = np.arange(step_size, n, step_size, np.int64)
+
+ part = np.partition(values, partition_indices)[partition_indices]
+
+ return np.hstack((np.expand_dims(part, 1),
+ step_size * np.ones((np.prod(part.shape), 1))))
+
+
+def compress(summary, epsilon):
+ """Compress a summary to within `epsilon` accuracy.
+
+ The compression step is needed to keep the summary sizes small after merging,
+ and also used to return the final target boundaries. It finds the new bins
+ based on interpolating cumulative weight percentages from the large summary.
+ Taking the difference of the cumulative weights from the previous bin's
+ cumulative weight will give the new weight for that bin.
+
+ Arguments:
+ summary: 2-D `np.ndarray` summary to be compressed.
+    epsilon: A `'float32'` that determines the desired approximation precision.
+
+ Returns:
+ A 2-D `np.ndarray` that is a compressed summary. First column is the
+ interpolated partition values, the second is the weights (counts).
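+
+  For example (illustrative): compressing four unit-weight values with
+  `epsilon=0.5` interpolates the boundaries at the 50th and 100th weight
+  percentiles and accumulates the weights between them:
+
+  >>> compress(np.array([[1., 1.], [2., 1.], [3., 1.], [4., 1.]]), 0.5)
+  array([[2., 2.],
+         [4., 2.]])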
+ """
+ if np.prod(summary[:, 0].shape) * epsilon < 1:
+ return summary
+
+ percents = epsilon + np.arange(0.0, 1.0, epsilon)
+ cum_weights = summary[:, 1].cumsum()
+ cum_weight_percents = cum_weights / cum_weights[-1]
+ new_bins = np.interp(percents, cum_weight_percents, summary[:, 0])
+ cum_weights = np.interp(percents, cum_weight_percents, cum_weights)
+ new_weights = cum_weights - np.concatenate((np.array([0]), cum_weights[:-1]))
+
+ return np.hstack((np.expand_dims(new_bins, 1),
+ np.expand_dims(new_weights, 1)))
+
+
+def merge_summaries(prev_summary, next_summary, epsilon):
+ """Weighted merge sort of summaries.
+
+ Given two summaries of distinct data, this function merges (and compresses)
+ them to stay within `epsilon` error tolerance.
+
+ Arguments:
+ prev_summary: 2-D `np.ndarray` summary to be merged with `next_summary`.
+ next_summary: 2-D `np.ndarray` summary to be merged with `prev_summary`.
+    epsilon: A `'float32'` that determines the desired approximation precision.
+
+ Returns:
+ A 2-D `np.ndarray` that is a merged summary. First column is the
+ interpolated partition values, the second is the weights (counts).
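+
+  For example (illustrative): merging two unit-weight summaries interleaves
+  their partition values; here `epsilon=0.25` targets four bins, so all four
+  merged entries survive the compression step:
+
+  >>> merge_summaries(np.array([[1., 1.], [3., 1.]]),
+  ...                 np.array([[2., 1.], [4., 1.]]), 0.25)
+  array([[1., 1.],
+         [2., 1.],
+         [3., 1.],
+         [4., 1.]])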
+ """
+ merged = np.concatenate((prev_summary, next_summary))
+ merged = merged[merged[:, 0].argsort()]
+ if np.prod(merged.shape) * epsilon < 1:
+ return merged
+ return compress(merged, epsilon)
+
+
+def get_bucket_boundaries(summary, num_bins):
+  """Compress `summary` to `num_bins` bins and return the interior boundaries."""
+  return compress(summary, 1.0 / num_bins)[:-1, 0]
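+
+# Illustrative example: a summary of four unit-weight values compressed to
+# two bins yields one interior boundary:
+#   get_bucket_boundaries(
+#       np.array([[1., 1.], [2., 1.], [3., 1.], [4., 1.]]), num_bins=2)
+#   => array([2.])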
+
+
@keras_export("keras.layers.experimental.preprocessing.Discretization")
-class Discretization(base_preprocessing_layer.PreprocessingLayer):
+class Discretization(base_preprocessing_layer.CombinerPreprocessingLayer):
"""Buckets data into discrete ranges.
This layer will place each element of its input data into one of several
@@ -48,9 +149,15 @@
Same as input shape.
Attributes:
- bins: Optional boundary specification. Bins exclude the left boundary and
- include the right boundary, so `bins=[0., 1., 2.]` generates bins
+    bins: Optional boundary specification, or the number of bins to compute
+      when passed an `int`. Bins exclude the left boundary and include the
+      right boundary, so `bins=[0., 1., 2.]` generates bins
       `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`, and `[2., +inf)`.
+      That example yields 4 buckets, i.e. it corresponds to `bins=4`.
+    epsilon: Error tolerance, typically a small fraction close to zero (e.g.
+      0.01). Higher values of epsilon increase the quantile approximation
+      error, and hence result in more unequal buckets, but may improve
+      performance and reduce resource consumption.
Examples:
@@ -62,19 +169,47 @@
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[0, 1, 3, 1],
[0, 3, 2, 0]], dtype=int32)>
+
+ Bucketize float values based on a number of buckets to compute.
+ >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
+ >>> layer = tf.keras.layers.experimental.preprocessing.Discretization(
+ ... bins=4, epsilon=0.01)
+ >>> layer.adapt(input)
+ >>> layer(input)
+ <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
+ array([[0, 2, 3, 1],
+ [0, 3, 2, 0]], dtype=int32)>
"""
- def __init__(self, bins, **kwargs):
- super(Discretization, self).__init__(**kwargs)
- base_preprocessing_layer._kpl_gauge.get_cell("V2").set("Discretization")
- # The bucketization op requires a final rightmost boundary in order to
- # correctly assign values higher than the largest left boundary.
- # This should not impact intended buckets even if a max value is provided.
- self.bins = np.append(bins, [np.Inf])
+ def __init__(self,
+ bins,
+ epsilon=0.01,
+ **kwargs):
+ super(Discretization, self).__init__(
+ combiner=Discretization.DiscretizingCombiner(
+ epsilon, bins if isinstance(bins, int) else 1),
+ **kwargs)
+ if bins is not None and not isinstance(bins, int):
+ self.bins = np.append(bins, [np.Inf])
+ else:
+ self.bins = np.zeros(bins)
+    # Keep the original `bins` argument so that `get_config` can return it.
+ self.input_bins = bins
+ self.epsilon = epsilon
+
+ def build(self, input_shape):
+ self.bins = self._add_state_variable(
+ name=_BINS_NAME,
+ shape=(self.bins.size,),
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(self.bins))
+ super(Discretization, self).build(input_shape)
def get_config(self):
config = {
- "bins": self.bins,
+ "bins": None if self.input_bins is None else (
+ K.get_value(self.input_bins)),
+ "epsilon": self.epsilon,
}
base_config = super(Discretization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -129,3 +264,82 @@
control_flow_ops.vectorized_map(
_bucketize_op(array_ops.squeeze(self.bins)), reshaped),
array_ops.constant([-1] + input_shape.as_list()[1:]))
+
+ class DiscretizingCombiner(Combiner):
+ """Combiner for the Discretization preprocessing layer.
+
+    This class encapsulates the computations for finding the quantile
+    boundaries of a set of data in a stable and numerically correct way. Its
+    associated accumulator is a namedtuple with a single `summaries` field,
+    holding summarizations of the data used to generate the boundaries.
+
+ Attributes:
+ epsilon: Error tolerance.
+ num_bins: The desired number of buckets.
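+
+    A minimal sketch of the accumulate-and-extract flow (illustrative values;
+    `epsilon` and `num_bins` here are chosen arbitrarily):
+
+    >>> combiner = Discretization.DiscretizingCombiner(epsilon=0.01, num_bins=2)
+    >>> acc = combiner.compute(np.array([[1.], [2.], [3.], [4.]]))
+    >>> combiner.extract(acc)
+    {'bins': array([ 2., inf])}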
+ """
+
+    def __init__(self, epsilon, num_bins):
+ self.epsilon = epsilon
+ self.num_bins = num_bins
+
+ # TODO(mwunder): Implement elementwise per-column discretization.
+
+ def compute(self, values, accumulator=None):
+ """Compute a step in this computation, returning a new accumulator."""
+
+ if isinstance(values, sparse_tensor.SparseTensor):
+ values = values.values
+ if tf_utils.is_ragged(values):
+ values = values.flat_values
+ flattened_input = np.reshape(values, newshape=(-1, 1))
+
+ summaries = [summarize(v, self.epsilon) for v in flattened_input.T]
+
+ if accumulator is None:
+ return self._create_accumulator(summaries)
+ else:
+ return self._create_accumulator(
+ [merge_summaries(prev_summ, summ, self.epsilon)
+ for prev_summ, summ in zip(accumulator.summaries, summaries)])
+
+ def merge(self, accumulators):
+ """Merge several accumulators to a single accumulator."""
+ # Combine accumulators and return the result.
+
+ merged = accumulators[0].summaries
+ for accumulator in accumulators[1:]:
+ merged = [merge_summaries(prev, summary, self.epsilon)
+ for prev, summary in zip(merged, accumulator.summaries)]
+
+ return self._create_accumulator(merged)
+
+ def extract(self, accumulator):
+ """Convert an accumulator into a dict of output values."""
+
+ boundaries = [np.append(get_bucket_boundaries(summary, self.num_bins),
+ [np.Inf])
+ for summary in accumulator.summaries]
+ return {
+ _BINS_NAME: np.squeeze(np.vstack(boundaries))
+ }
+
+ def restore(self, output):
+ """Create an accumulator based on 'output'."""
+ raise NotImplementedError(
+ "Discretization does not restore or support streaming updates.")
+
+ def serialize(self, accumulator):
+ """Serialize an accumulator for a remote call."""
+ output_dict = {
+ _BINS_NAME: [summary.tolist() for summary in accumulator.summaries]
+ }
+ return compat.as_bytes(json.dumps(output_dict))
+
+ def deserialize(self, encoded_accumulator):
+ """Deserialize an accumulator received from 'serialize()'."""
+ value_dict = json.loads(compat.as_text(encoded_accumulator))
+ return self._create_accumulator(np.array(value_dict[_BINS_NAME]))
+
+ def _create_accumulator(self, summaries):
+ """Represent the accumulator as one or more summaries of the dataset."""
+ return collections.namedtuple("Accumulator", ["summaries"])(summaries)
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_test.py b/tensorflow/python/keras/layers/preprocessing/discretization_test.py
index 9d04ccc..0226355 100644
--- a/tensorflow/python/keras/layers/preprocessing/discretization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_test.py
@@ -18,19 +18,32 @@
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.keras.layers.preprocessing import discretization_v1
from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.platform import test
+def get_layer_class():
+ if context.executing_eagerly():
+ return discretization.Discretization
+ else:
+ return discretization_v1.Discretization
+
+
@keras_parameterized.run_all_keras_modes
class DiscretizationTest(keras_parameterized.TestCase,
preprocessing_test_utils.PreprocessingLayerTest):
@@ -106,7 +119,6 @@
layer = discretization.Discretization(bins=[-.5, 0.5, 1.5])
bucket_data = layer(input_data)
self.assertAllEqual(expected_output_shape, bucket_data.shape.as_list())
-
model = keras.Model(inputs=input_data, outputs=bucket_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)
@@ -125,6 +137,110 @@
self.assertAllEqual(indices, output_dataset.indices)
self.assertAllEqual(expected_output, output_dataset.values)
+ @parameterized.named_parameters([
+ {
+ "testcase_name": "2d_single_element",
+ "adapt_data": np.array([[1.], [2.], [3.], [4.], [5.]]),
+ "test_data": np.array([[1.], [2.], [3.]]),
+ "use_dataset": True,
+ "expected": np.array([[0], [1], [2]]),
+ "num_bins": 5,
+ "epsilon": 0.01
+ }, {
+ "testcase_name": "2d_multi_element",
+ "adapt_data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.],
+ [5., 10.]]),
+ "test_data": np.array([[1., 10.], [2., 6.], [3., 8.]]),
+ "use_dataset": True,
+ "expected": np.array([[0, 4], [0, 2], [1, 3]]),
+ "num_bins": 5,
+ "epsilon": 0.01
+ }, {
+ "testcase_name": "1d_single_element",
+ "adapt_data": np.array([3., 2., 1., 5., 4.]),
+ "test_data": np.array([1., 2., 3.]),
+ "use_dataset": True,
+ "expected": np.array([0, 1, 2]),
+ "num_bins": 5,
+ "epsilon": 0.01
+ }, {
+ "testcase_name": "300_batch_1d_single_element_1",
+ "adapt_data": np.arange(300),
+ "test_data": np.arange(300),
+ "use_dataset": True,
+ "expected":
+ np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
+ "num_bins": 3,
+ "epsilon": 0.01
+ }, {
+ "testcase_name": "300_batch_1d_single_element_2",
+ "adapt_data": np.arange(300) ** 2,
+ "test_data": np.arange(300) ** 2,
+ "use_dataset": True,
+ "expected":
+ np.concatenate([np.zeros(101), np.ones(99), 2 * np.ones(100)]),
+ "num_bins": 3,
+ "epsilon": 0.01
+ }, {
+ "testcase_name": "300_batch_1d_single_element_large_epsilon",
+ "adapt_data": np.arange(300),
+ "test_data": np.arange(300),
+ "use_dataset": True,
+ "expected": np.concatenate([np.zeros(137), np.ones(163)]),
+ "num_bins": 2,
+ "epsilon": 0.1
+ }])
+ def test_layer_computation(self, adapt_data, test_data, use_dataset,
+ expected, num_bins=5, epsilon=0.01):
+
+ input_shape = tuple(list(test_data.shape)[1:])
+ np.random.shuffle(adapt_data)
+ if use_dataset:
+ # Keras APIs expect batched datasets
+ adapt_data = dataset_ops.Dataset.from_tensor_slices(adapt_data).batch(
+ test_data.shape[0] // 2)
+ test_data = dataset_ops.Dataset.from_tensor_slices(test_data).batch(
+ test_data.shape[0] // 2)
+
+ cls = get_layer_class()
+ layer = cls(epsilon=epsilon, bins=num_bins)
+ layer.adapt(adapt_data)
+
+ input_data = keras.Input(shape=input_shape)
+ output = layer(input_data)
+ model = keras.Model(input_data, output)
+ model._run_eagerly = testing_utils.should_run_eagerly()
+ output_data = model.predict(test_data)
+ self.assertAllClose(expected, output_data)
+
+ @parameterized.named_parameters(
+ {
+ "num_bins": 5,
+ "data": np.array([[1.], [2.], [3.], [4.], [5.]]),
+ "expected": {
+ "bins": np.array([1., 2., 3., 4., np.Inf])
+ },
+ "testcase_name": "2d_single_element_all_bins"
+ }, {
+ "num_bins": 5,
+ "data": np.array([[1., 6.], [2., 7.], [3., 8.], [4., 9.], [5., 10.]]),
+ "expected": {
+ "bins": np.array([2., 4., 6., 8., np.Inf])
+ },
+ "testcase_name": "2d_multi_element_all_bins",
+ }, {
+ "num_bins": 3,
+ "data": np.array([[0.], [1.], [2.], [3.], [4.], [5.]]),
+ "expected": {
+ "bins": np.array([1., 3., np.Inf])
+ },
+ "testcase_name": "2d_single_element_3_bins"
+ })
+ def test_combiner_computation(self, num_bins, data, expected):
+ epsilon = 0.01
+ combiner = discretization.Discretization.DiscretizingCombiner(epsilon,
+ num_bins)
+ self.validate_accumulator_extract(combiner, data, expected)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/keras/layers/preprocessing/discretization_v1.py b/tensorflow/python/keras/layers/preprocessing/discretization_v1.py
new file mode 100644
index 0000000..6daea9b
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/discretization_v1.py
@@ -0,0 +1,28 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tensorflow V1 version of the Discretization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.engine.base_preprocessing_layer_v1 import CombinerPreprocessingLayer
+from tensorflow.python.keras.layers.preprocessing import discretization
+from tensorflow.python.util.tf_export import keras_export
+
+
+@keras_export(v1=['keras.layers.experimental.preprocessing.Discretization'])
+class Discretization(discretization.Discretization, CombinerPreprocessingLayer):
+ pass
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
new file mode 100644
index 0000000..088d450
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
+tf_class {
+ is_instance: "<type \'type\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
new file mode 100644
index 0000000..2f75c15
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute"
+ argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "deserialize"
+ argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "extract"
+ argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "merge"
+ argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "restore"
+ argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "serialize"
+ argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
index 628f76c..87c0e79 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -9,6 +10,10 @@
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
+ name: "DiscretizingCombiner"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
@@ -130,7 +135,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
}
member_method {
name: "adapt"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
new file mode 100644
index 0000000..088d450
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner.__metaclass__"
+tf_class {
+ is_instance: "<type \'type\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
new file mode 100644
index 0000000..2f75c15
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.-discretizing-combiner.pbtxt
@@ -0,0 +1,34 @@
+path: "tensorflow.keras.layers.experimental.preprocessing.Discretization.DiscretizingCombiner"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization.DiscretizingCombiner\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.Combiner\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'epsilon\', \'num_bins\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "compute"
+ argspec: "args=[\'self\', \'values\', \'accumulator\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "deserialize"
+ argspec: "args=[\'self\', \'encoded_accumulator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "extract"
+ argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "merge"
+ argspec: "args=[\'self\', \'accumulators\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "restore"
+ argspec: "args=[\'self\', \'output\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "serialize"
+ argspec: "args=[\'self\', \'accumulator\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
index 628f76c..87c0e79 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.preprocessing.-discretization.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.keras.layers.experimental.preprocessing.Discretization"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.preprocessing.discretization.Discretization\'>"
+ is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.CombinerPreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_preprocessing_layer.PreprocessingLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
@@ -9,6 +10,10 @@
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
+ name: "DiscretizingCombiner"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
@@ -130,7 +135,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'bins\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'bins\', \'epsilon\'], varargs=None, keywords=kwargs, defaults=[\'0.01\'], "
}
member_method {
name: "adapt"