test: add correct test for stateful flag in numpy_function
diff --git a/RELEASE.md b/RELEASE.md
index 79ebe00..76fd24a 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -42,16 +42,16 @@
for the migration.
* TF Core:
- * `tf.Graph.get_name_scope()` now always returns a string, as documented.
+ * `tf.Graph.get_name_scope()` now always returns a string, as documented.
Previously, when called within `name_scope("")` or `name_scope(None)`
contexts, it returned None; now it returns the empty string.
- * `tensorflow/core/ir/` contains a new MLIR-based Graph dialect that is
+ * `tensorflow/core/ir/` contains a new MLIR-based Graph dialect that is
isomorphic to GraphDef and will be used to replace GraphDef-based (e.g.,
Grappler) optimizations.
- * Deprecated and removed attrs() function in shape inference. All
+ * Deprecated and removed attrs() function in shape inference. All
attributes should be queried by name now (rather than range returned)
to enable changing the underlying storage there.
- * The following Python symbols were accidentally added in earlier versions
+ * The following Python symbols were accidentally added in earlier versions
of TensorFlow and now are removed. Each symbol has a replacement that
should be used instead, but note the replacement's argument names are
different.
@@ -63,10 +63,13 @@
2.6): Use `tf.raw_ops.SparseSegmentSumGrad` instead. Directly calling
this op is typically not necessary, as it is automatically used when
computing the gradient of `tf.sparse.segment_sum`.
- * Renaming of tensorflow::int64 to int_64_t in numerous places (the former
+ * Renaming of tensorflow::int64 to int_64_t in numerous places (the former
is an alias for the latter) which could result in needing to regenerate
selective op registration headers else execution would fail with
unregistered kernels error.
+ * Adding a flag `stateful` to `numpy_function`, allowing to give the
+ guarantee to the runtime that the function call is stateless,
+ which allows for more optimizations in the graph.
## Known Caveats
diff --git a/tensorflow/python/ops/script_ops_test.py b/tensorflow/python/ops/script_ops_test.py
index 7d25016..f38eb58 100644
--- a/tensorflow/python/ops/script_ops_test.py
+++ b/tensorflow/python/ops/script_ops_test.py
@@ -48,29 +48,38 @@
return a + b
@def_function.function
- def tensor_double_plus_stateless(a, b):
- sum1 = numpy_function(plus, [a, b], dtypes.int32, stateful=False)
- sum2 = numpy_function(plus, [a, b], dtypes.int32, stateful=False)
+ def numpy_func_stateless(a, b):
+ return numpy_function(plus, [a, b], dtypes.int32, stateful=False)
+
+ @def_function.function
+ def func_stateless(a, b):
+ sum1 = numpy_func_stateless(a, b)
+ sum2 = numpy_func_stateless(a, b)
return sum1 + sum2
- # different argument
- _ = tensor_double_plus_stateless( # executing empty
+ _ = func_stateless(
constant_op.constant(1),
constant_op.constant(2),
)
- self.assertEqual(call_count, 1) # +1 as only the first encounter was executed
+
+ self.assertEqual(call_count, 1) # the second call should be eliminated
+ call_count = 0 # reset
@def_function.function
- def tensor_double_plus_stateful(a, b):
- sum1 = numpy_function(plus, [a, b], dtypes.int32, stateful=True)
- sum2 = numpy_function(plus, [a, b], dtypes.int32, stateful=True)
+ def numpy_func_stateful(a, b):
+ return numpy_function(plus, [a, b], dtypes.int32, stateful=True)
+
+ @def_function.function
+ def func_stateful(a, b):
+ sum1 = numpy_func_stateful(a, b)
+ sum2 = numpy_func_stateful(a, b)
return sum1 + sum2
- _ = tensor_double_plus_stateful( # executing empty
- constant_op.constant(3),
- constant_op.constant(4),
- )
- self.assertEqual(call_count, 3) # +2 as it is stateful, func was both times executed
+ _ = func_stateful(
+ constant_op.constant(1),
+ constant_op.constant(2),
+ )
+ self.assertEqual(call_count, 2) # 2 as it is stateful, func was both times executed
if __name__ == "__main__":