erase and minor
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorMapReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorMapReplace.pbtxt
deleted file mode 100644
index 80c49cb..0000000
--- a/tensorflow/core/api_def/base_api/api_def_TensorMapReplace.pbtxt
+++ /dev/null
@@ -1,10 +0,0 @@
-op {
- graph_op_name: "TensorMapReplace"
- summary: "Returns a map that is the 'input_handle' after replacing the existing key value with the given value."
- description: <<END
-input_handle: the original map
-output_handle: the map with key and value inserted
-key: the key whose value will be replaced
-value: the value to replace the original value
-END
-}
\ No newline at end of file
diff --git a/tensorflow/core/ops/map_ops.cc b/tensorflow/core/ops/map_ops.cc
index 072c116..d6142f3 100644
--- a/tensorflow/core/ops/map_ops.cc
+++ b/tensorflow/core/ops/map_ops.cc
@@ -1,4 +1,4 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/tensorflow/python/kernel_tests/map_ops_test.py b/tensorflow/python/kernel_tests/map_ops_test.py
index 7fda6fb..bdf06fc 100644
--- a/tensorflow/python/kernel_tests/map_ops_test.py
+++ b/tensorflow/python/kernel_tests/map_ops_test.py
@@ -24,13 +24,10 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import map_ops
from tensorflow.python.platform import test
-from tensorflow.python.util.lazy_loader import LazyLoader
-control_flow_ops = LazyLoader("control_flow_ops", globals(),
- "tensorflow.python.ops.control_flow_ops")
-
@test_util.run_all_in_graph_and_eager_modes
class MapOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
@@ -189,14 +186,34 @@
tape.watch(v2)
m = map_ops.tensor_map_insert(m, k, v)
l = map_ops.tensor_map_lookup(m, k, v.dtype)
+ self.assertAllClose(l, v)
g = tape.gradient(l * 5, v)
self.assertAllClose(g, 5)
m = map_ops.tensor_map_insert(m, k, v2)
l2 = map_ops.tensor_map_lookup(m, k, v2.dtype)
+ self.assertAllClose(l2, v2)
g2 = tape.gradient(l2 * 6, v)
g3 = tape.gradient(l2 * 7, v2)
self.assertAllClose(g2, array_ops.zeros_like(v))
self.assertAllClose(g3, 7)
+ def testEraseGrad(self):
+ with backprop.GradientTape(persistent=True) as tape:
+ m = map_ops.empty_tensor_map()
+ k = constant_op.constant(1.0)
+ v = constant_op.constant(2.0)
+ tape.watch(v)
+ k2 = constant_op.constant(12.0)
+ v2 = constant_op.constant(22.0)
+ tape.watch(v2)
+ m = map_ops.tensor_map_insert(m, k, v)
+ m = map_ops.tensor_map_insert(m, k2, v2)
+ m, e = map_ops.tensor_map_erase(m, k2, v2.dtype)
+ l = map_ops.tensor_map_lookup(m, k, v.dtype)
+ self.assertAllClose(l, v)
+ self.assertAllClose(e, v2)
+ g = tape.gradient(l * 5, v)
+ self.assertAllClose(g, 5)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/ops/map_ops.py b/tensorflow/python/ops/map_ops.py
index 443bb0b..c28bc57 100644
--- a/tensorflow/python/ops/map_ops.py
+++ b/tensorflow/python/ops/map_ops.py
@@ -21,14 +21,11 @@
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_map_ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_map_ops
from tensorflow.python.ops.gen_map_ops import *
-from tensorflow.python.util.lazy_loader import LazyLoader
-control_flow_ops = LazyLoader("control_flow_ops", globals(),
- "tensorflow.python.ops.control_flow_ops")
-
ops.NotDifferentiable("EmptyTensorMap")
def empty_tensor_map():
@@ -68,3 +65,10 @@
lambda: tensor_map_erase(dmap, k, v.dtype)[0],
lambda: dmap)
return map_grad, key_grad, value_grad
+
+@ops.RegisterGradient("TensorMapErase")
+def EraseGrad(op, dmap, dval):
+ _, k = op.inputs
+ key_grad = None
+ map_grad = dmap
+ return map_grad, key_grad