Add some tests
Refactor to limit kwargs and scalar tensor check
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index d8fa489..ceac40ac 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -249,6 +249,17 @@
def _tf_tensor_list_len(s):
return list_ops.tensor_list_length(s)
+def _tf_is_scalar(s):
+ shape = array_ops.shape(s)
+
+ assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+ if shape.shape[0] == 0:
+ return True
+ else:
+ raise ValueError(
+ 'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
def _tf_tensor_len(s):
"""Overload of len_ for Tensor arguments."""
@@ -322,23 +333,28 @@
else:
_py_print(*objects, **kwargs)
-def max_(s1, s2=UNSPECIFIED):
- if any(tensor_util.is_tf_type(s) for s in (s1,s2)):
- return _tf_max(s1, s2)
- return _py_max(s1, s2)
+def _py_print(*objects, **kwargs):
+ print(*objects, **kwargs)
-def _tf_max(s1, s2):
- if s2 is UNSPECIFIED:
- return math_ops.reduce_max(s1)
- else:
- # TODO (bhack) How much things we need to handle here?
- return constant_op.constant(True)
+def max_(*args, **kwargs):
+ if any(tensor_util.is_tf_type(s) for s in args):
+ return _tf_max(*args, **kwargs)
+ return _py_max(*args, **kwargs)
-def _py_max(s1, s2):
- if s2 is UNSPECIFIED:
- return max(s1)
+def _tf_max(*args, **kwargs):
+ if len(kwargs):
+ kwargs_tuple = tuple(set(kwargs.keys()))
+ raise ValueError('These keyword arguments are '
+ 'currently not supported: {}'.format(kwargs_tuple))
+ elif len(args) == 1:
+ return math_ops.reduce_max(*args, axis=0)
else:
- return max(s1, s2)
+ if all(_tf_is_scalar(arg) for arg in args):
+ s= array_ops.concat([args], axis=0)
+ return math_ops.reduce_max(s, axis=0)
+
+def _py_max(*args, **kwargs):
+ return max(*args, **kwargs)
def _tf_py_func_print(objects, kwargs):
"""Overload of print_ as a py_func implementation."""
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
index 9c5e27e..d8fb974 100644
--- a/tensorflow/python/autograph/operators/py_builtins_test.py
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -204,10 +204,10 @@
def test_max_tensor(self):
r = py_builtins.max_(constant_op.constant([1, 3, 2]))
- self.assertAllEqual(self.evaluate(r),3)
- r = py_builtins.max_(constant_op.constant([1, 5, 2]),[4])
- # TODO (bhack) this is just a dummy check
- self.assertTrue(self.evaluate(r))
+ self.assertAllEqual(self.evaluate(r), 3)
+ r = py_builtins.max_(constant_op.constant(6),constant_op.constant(4),
+ constant_op.constant(8))
+ self.assertAllEqual(self.evaluate(r), 8)
def test_range(self):
self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])