Copy the feature_column_v2._StateManagerImplV2 to keras since its only used by keras. Will remove the class in tf in a followup cl.

PiperOrigin-RevId: 360261708
Change-Id: I9471aea1b17349c18a3a488ffc6fb02af2421099
diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD
index a64f88b..dc382e5 100644
--- a/tensorflow/python/keras/feature_column/BUILD
+++ b/tensorflow/python/keras/feature_column/BUILD
@@ -66,6 +66,7 @@
         ":dense_features",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python/feature_column:feature_column_v2",
+        "//tensorflow/python/keras/utils:tf_contextlib",
         "//tensorflow/python/util:tf_export",
     ],
 )
diff --git a/tensorflow/python/keras/feature_column/dense_features_v2.py b/tensorflow/python/keras/feature_column/dense_features_v2.py
index ae1294c..5f54120 100644
--- a/tensorflow/python/keras/feature_column/dense_features_v2.py
+++ b/tensorflow/python/keras/feature_column/dense_features_v2.py
@@ -22,6 +22,8 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.keras.feature_column import base_feature_layer as kfc
 from tensorflow.python.keras.feature_column import dense_features
+from tensorflow.python.keras.utils import tf_contextlib
+from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util.tf_export import keras_export
 
 
@@ -85,7 +87,7 @@
         trainable=trainable,
         name=name,
         **kwargs)
-    self._state_manager = fc._StateManagerImplV2(self, self.trainable)  # pylint: disable=protected-access
+    self._state_manager = _StateManagerImplV2(self, self.trainable)
 
   def build(self, _):
     for column in self._feature_columns:
@@ -94,3 +96,65 @@
     # We would like to call Layer.build and not _DenseFeaturesHelper.build.
     # pylint: disable=protected-access
     super(kfc._BaseFeaturesLayer, self).build(None)  # pylint: disable=bad-super-call
+
+
+class _StateManagerImplV2(fc._StateManagerImpl):  # pylint: disable=protected-access
+  """Manages the state of DenseFeatures."""
+
+  def create_variable(self,
+                      feature_column,
+                      name,
+                      shape,
+                      dtype=None,
+                      trainable=True,
+                      use_resource=True,
+                      initializer=None):
+    if name in self._cols_to_vars_map[feature_column]:
+      raise ValueError('Variable already exists.')
+
+    # We explicitly track these variables since `name` is not guaranteed to be
+    # unique and disable manual tracking that the add_weight call does.
+    with no_manual_dependency_tracking_scope(self._layer):
+      var = self._layer.add_weight(
+          name=name,
+          shape=shape,
+          dtype=dtype,
+          initializer=initializer,
+          trainable=self._trainable and trainable,
+          use_resource=use_resource)
+    if isinstance(var, trackable.Trackable):
+      self._layer._track_trackable(var, feature_column.name + '/' + name)  # pylint: disable=protected-access
+    self._cols_to_vars_map[feature_column][name] = var
+    return var
+
+
+@tf_contextlib.contextmanager
+def no_manual_dependency_tracking_scope(obj):
+  """A context that disables manual dependency tracking for the given `obj`.
+
+  Sometimes library methods might track objects on their own and we might want
+  to disable that and do the tracking on our own. One can then use this context
+  manager to disable the tracking the library method does and do your own
+  tracking.
+
+  For example:
+
+  class TestLayer(tf.keras.Layer):
+    def build():
+      with no_manual_dependency_tracking_scope(self):
+        var = self.add_variable("name1")  # Creates a var and doesn't track it
+      self._track_trackable("name2", var)  # We track variable with name `name2`
+
+  Args:
+    obj: A trackable object.
+
+  Yields:
+    a scope in which the object doesn't track dependencies manually.
+  """
+  # pylint: disable=protected-access
+  previous_value = getattr(obj, '_manual_tracking', True)
+  obj._manual_tracking = False
+  try:
+    yield
+  finally:
+    obj._manual_tracking = previous_value