test: add correct test for stateful flag in numpy_function
diff --git a/RELEASE.md b/RELEASE.md
index 79ebe00..76fd24a 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -42,16 +42,16 @@
     for the migration.
 
 * TF Core:
-    *   `tf.Graph.get_name_scope()` now always returns a string, as documented.
+    * `tf.Graph.get_name_scope()` now always returns a string, as documented.
         Previously, when called within `name_scope("")` or `name_scope(None)`
         contexts, it returned None; now it returns the empty string.
-    *   `tensorflow/core/ir/` contains a new MLIR-based Graph dialect that is
+    * `tensorflow/core/ir/` contains a new MLIR-based Graph dialect that is
         isomorphic to GraphDef and will be used to replace GraphDef-based (e.g.,
         Grappler) optimizations.
-    *   Deprecated and removed attrs() function in shape inference. All
+    * Deprecated and removed attrs() function in shape inference. All
         attributes should be queried by name now (rather than range returned)
         to enable changing the underlying storage there.
-    *   The following Python symbols were accidentally added in earlier versions
+    * The following Python symbols were accidentally added in earlier versions
         of TensorFlow and now are removed. Each symbol has a replacement that
         should be used instead, but note the replacement's argument names are
         different.
@@ -63,10 +63,13 @@
         2.6): Use `tf.raw_ops.SparseSegmentSumGrad` instead. Directly calling
         this op is typically not necessary, as it is automatically used when
         computing the gradient of `tf.sparse.segment_sum`.
-    *   Renaming of tensorflow::int64 to int_64_t in numerous places (the former
+    * Renaming of tensorflow::int64 to int_64_t in numerous places (the former
         is an alias for the latter) which could result in needing to regenerate
         selective op registration headers else execution would fail with
         unregistered kernels error.
+    * Adding a flag `stateful` to `numpy_function`, allowing to give the
+        guarantee to the runtime that the function call is stateless,
+        which allows for more optimizations in the graph.
 
 ## Known Caveats
 
diff --git a/tensorflow/python/ops/script_ops_test.py b/tensorflow/python/ops/script_ops_test.py
index 7d25016..f38eb58 100644
--- a/tensorflow/python/ops/script_ops_test.py
+++ b/tensorflow/python/ops/script_ops_test.py
@@ -48,29 +48,38 @@
       return a + b
 
     @def_function.function
-    def tensor_double_plus_stateless(a, b):
-      sum1 = numpy_function(plus, [a, b], dtypes.int32, stateful=False)
-      sum2 = numpy_function(plus, [a, b], dtypes.int32, stateful=False)
+    def numpy_func_stateless(a, b):
+      return numpy_function(plus, [a, b], dtypes.int32, stateful=False)
+
+    @def_function.function
+    def func_stateless(a, b):
+      sum1 = numpy_func_stateless(a, b)
+      sum2 = numpy_func_stateless(a, b)
       return sum1 + sum2
 
-    # different argument
-    _ = tensor_double_plus_stateless(  # executing empty
+    _ = func_stateless(
       constant_op.constant(1),
       constant_op.constant(2),
     )
-    self.assertEqual(call_count, 1)  # +1 as only the first encounter was executed
+
+    self.assertEqual(call_count, 1)  # the second call should be eliminated
+    call_count = 0  # reset
 
     @def_function.function
-    def tensor_double_plus_stateful(a, b):
-      sum1 = numpy_function(plus, [a, b], dtypes.int32, stateful=True)
-      sum2 = numpy_function(plus, [a, b], dtypes.int32, stateful=True)
+    def numpy_func_stateful(a, b):
+      return numpy_function(plus, [a, b], dtypes.int32, stateful=True)
+
+    @def_function.function
+    def func_stateful(a, b):
+      sum1 = numpy_func_stateful(a, b)
+      sum2 = numpy_func_stateful(a, b)
       return sum1 + sum2
 
-    _ = tensor_double_plus_stateful( # executing empty
-      constant_op.constant(3),
-      constant_op.constant(4),
-                          )
-    self.assertEqual(call_count, 3)  # +2 as it is stateful, func was both times executed
+    _ = func_stateful(
+      constant_op.constant(1),
+      constant_op.constant(2),
+    )
+    self.assertEqual(call_count, 2)  # 2 as it is stateful, func was both times executed
 
 
 if __name__ == "__main__":