Extracting out tf.bin_values_fixed_width from tf.histogram_fixed_width. (#13330)
* Extracting out tf.bin_values_fixed_width from tf.histogram_fixed_width.
* Fixing final reshape.
* Replacing .get_shape() with tf.shape() call.
* Renaming bin_values_fixed_width to histogram_fixed_width_bins.
* Undoing inadvertent merge.
* Removing added newline.
* Fixing newline.
* Updating goldens manually.
* Updating goldens.
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index 51e4be9..4313b79 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -17,6 +17,7 @@
Please see @{$python/histogram_ops} guide.
+@@histogram_fixed_width_bins
@@histogram_fixed_width
"""
@@ -32,6 +33,70 @@
from tensorflow.python.ops import math_ops
+def histogram_fixed_width_bins(values,
+ value_range,
+ nbins=100,
+ dtype=dtypes.int32,
+ name=None):
+ """Bins the given values for use in a histogram.
+
+ Given the tensor `values`, this operation returns a rank 1 `Tensor`
+ representing the indices of a histogram into which each element
+ of `values` would be binned. The bins are equal width and
+ determined by the arguments `value_range` and `nbins`.
+
+ Args:
+ values: Numeric `Tensor`.
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`.
+ values <= value_range[0] will be mapped to hist[0],
+ values >= value_range[1] will be mapped to hist[-1].
+ nbins: Scalar `int32 Tensor`. Number of histogram bins.
+ dtype: dtype for returned histogram.
+ name: A name for this operation (defaults to 'histogram_fixed_width').
+
+ Returns:
+ A `Tensor` holding the indices of the binned values whose shape matches
+ `values`.
+
+ Examples:
+
+ ```python
+ # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ nbins = 5
+ value_range = [0.0, 5.0]
+ new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+
+ with tf.get_default_session() as sess:
+ indices = tf.histogram_fixed_width_bins(new_values, value_range, nbins=5)
+ variables.global_variables_initializer().run()
+ sess.run(indices) => [0, 0, 1, 2, 4]
+ ```
+ """
+ with ops.name_scope(name, 'histogram_fixed_width_bins',
+ [values, value_range, nbins]) as scope:
+ values = ops.convert_to_tensor(values, name='values')
+ shape = array_ops.shape(values)
+
+ values = array_ops.reshape(values, [-1])
+ value_range = ops.convert_to_tensor(value_range, name='value_range')
+ nbins = ops.convert_to_tensor(nbins, dtype=dtypes.int32, name='nbins')
+ nbins_float = math_ops.cast(nbins, values.dtype)
+
+ # Map tensor values that fall within value_range to [0, 1].
+ scaled_values = math_ops.truediv(values - value_range[0],
+ value_range[1] - value_range[0],
+ name='scaled_values')
+
+ # map tensor values within the open interval value_range to {0,.., nbins-1},
+ # values outside the open interval will be zero or less, or nbins or more.
+ indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+ # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+ indices = math_ops.cast(
+ clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)
+ return array_ops.reshape(indices, shape)
+
+
def histogram_fixed_width(values,
value_range,
nbins=100,
diff --git a/tensorflow/python/ops/histogram_ops_test.py b/tensorflow/python/ops/histogram_ops_test.py
index 19ad6cd..80ee090 100644
--- a/tensorflow/python/ops/histogram_ops_test.py
+++ b/tensorflow/python/ops/histogram_ops_test.py
@@ -21,11 +21,64 @@
import numpy as np
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.platform import test
+class BinValuesFixedWidth(test.TestCase):
+
+ def test_empty_input_gives_all_zero_counts(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = []
+ expected_bins = []
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(values, value_range, nbins=5)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+ def test_1d_values_int32_output(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+ expected_bins = [0, 0, 1, 2, 4, 4]
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5, dtype=dtypes.int64)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+ def test_1d_float64_values_int32_output(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = np.float64([0.0, 5.0])
+ values = np.float64([-1.0, 0.0, 1.5, 2.0, 5.0, 15])
+ expected_bins = [0, 0, 1, 2, 4, 4]
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+ def test_2d_values(self):
+ # Bins will be:
+ # (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+ value_range = [0.0, 5.0]
+ values = constant_op.constant(
+ [[-1.0, 0.0, 1.5], [2.0, 5.0, 15]],
+ shape=(2, 3))
+ expected_bins = [[0, 0, 1], [2, 4, 4]]
+ with self.test_session():
+ bins = histogram_ops.histogram_fixed_width_bins(
+ values, value_range, nbins=5)
+ self.assertEqual(dtypes.int32, bins.dtype)
+ self.assertAllClose(expected_bins, bins.eval())
+
+
class HistogramFixedWidthTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 35917e9..db1ed42 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1161,6 +1161,10 @@
argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"<dtype: \'int32\'>\", \'None\'], "
}
member_method {
+ name: "histogram_fixed_width_bins"
+ argspec: "args=[\'values\', \'value_range\', \'nbins\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'100\', \"<dtype: \'int32\'>\", \'None\'], "
+ }
+ member_method {
name: "identity"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}