Copy the feature_column_v2._StateManagerImplV2 to keras since its only used by keras. Will remove the class in tf in a followup cl.
PiperOrigin-RevId: 360261708
Change-Id: I9471aea1b17349c18a3a488ffc6fb02af2421099
diff --git a/tensorflow/python/keras/feature_column/BUILD b/tensorflow/python/keras/feature_column/BUILD
index a64f88b..dc382e5 100644
--- a/tensorflow/python/keras/feature_column/BUILD
+++ b/tensorflow/python/keras/feature_column/BUILD
@@ -66,6 +66,7 @@
":dense_features",
"//tensorflow/python:framework_ops",
"//tensorflow/python/feature_column:feature_column_v2",
+ "//tensorflow/python/keras/utils:tf_contextlib",
"//tensorflow/python/util:tf_export",
],
)
diff --git a/tensorflow/python/keras/feature_column/dense_features_v2.py b/tensorflow/python/keras/feature_column/dense_features_v2.py
index ae1294c..5f54120 100644
--- a/tensorflow/python/keras/feature_column/dense_features_v2.py
+++ b/tensorflow/python/keras/feature_column/dense_features_v2.py
@@ -22,6 +22,8 @@
from tensorflow.python.framework import ops
from tensorflow.python.keras.feature_column import base_feature_layer as kfc
from tensorflow.python.keras.feature_column import dense_features
+from tensorflow.python.keras.utils import tf_contextlib
+from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export
@@ -85,7 +87,7 @@
trainable=trainable,
name=name,
**kwargs)
- self._state_manager = fc._StateManagerImplV2(self, self.trainable) # pylint: disable=protected-access
+ self._state_manager = _StateManagerImplV2(self, self.trainable)
def build(self, _):
for column in self._feature_columns:
@@ -94,3 +96,65 @@
# We would like to call Layer.build and not _DenseFeaturesHelper.build.
# pylint: disable=protected-access
super(kfc._BaseFeaturesLayer, self).build(None) # pylint: disable=bad-super-call
+
+
+class _StateManagerImplV2(fc._StateManagerImpl): # pylint: disable=protected-access
+ """Manages the state of DenseFeatures."""
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ use_resource=True,
+ initializer=None):
+ if name in self._cols_to_vars_map[feature_column]:
+ raise ValueError('Variable already exists.')
+
+ # We explicitly track these variables since `name` is not guaranteed to be
+ # unique and disable manual tracking that the add_weight call does.
+ with no_manual_dependency_tracking_scope(self._layer):
+ var = self._layer.add_weight(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ use_resource=use_resource)
+ if isinstance(var, trackable.Trackable):
+ self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
+
+
+@tf_contextlib.contextmanager
+def no_manual_dependency_tracking_scope(obj):
+ """A context that disables manual dependency tracking for the given `obj`.
+
+ Sometimes library methods might track objects on their own and we might want
+ to disable that and do the tracking on our own. One can then use this context
+ manager to disable the tracking the library method does and do your own
+ tracking.
+
+ For example:
+
+ class TestLayer(tf.keras.Layer):
+ def build():
+ with no_manual_dependency_tracking_scope(self):
+ var = self.add_variable("name1") # Creates a var and doesn't track it
+ self._track_trackable("name2", var) # We track variable with name `name2`
+
+ Args:
+ obj: A trackable object.
+
+ Yields:
+ a scope in which the object doesn't track dependencies manually.
+ """
+ # pylint: disable=protected-access
+ previous_value = getattr(obj, '_manual_tracking', True)
+ obj._manual_tracking = False
+ try:
+ yield
+ finally:
+ obj._manual_tracking = previous_value