Fix v2 compatibility with moving average
PiperOrigin-RevId: 264688574
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 5d5bcff..46025ed 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -28,6 +28,7 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
+from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export
@@ -368,7 +369,7 @@
self._num_updates = num_updates
self._zero_debias = zero_debias
self._name = name
- self._averages = {}
+ self._averages = object_identity.ObjectIdentityDictionary()
@property
def name(self):
@@ -456,7 +457,7 @@
(1.0 + num_updates) / (10.0 + num_updates))
updates = []
for var in var_list:
- zero_debias = self._averages[var] in zero_debias_true
+ zero_debias = any(self._averages[var] is v for v in zero_debias_true)
updates.append(
assign_moving_average(
self._averages[var], var, decay, zero_debias=zero_debias))
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 889d111..3a52d76 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -59,26 +59,25 @@
@test_util.run_in_graph_and_eager_modes
def testAssignMovingAverage(self):
- with self.cached_session():
- var = variables.Variable([0.0, 0.0])
- val = constant_op.constant([1.0, 2.0], dtypes.float32)
- decay = 0.25
- if context.executing_eagerly():
- self.assertAllClose([0.0, 0.0], self.evaluate(var))
- assign = moving_averages.assign_moving_average(var, val, decay)
- self.assertAllClose(
- [1.0 * (1.0 - 0.25) / (1 - 0.25), 2.0 * (1.0 - 0.25) / (1 - 0.25)],
- self.evaluate(var))
- else:
- assign = moving_averages.assign_moving_average(var, val, decay)
- self.evaluate(variables.global_variables_initializer())
- self.assertAllClose([0.0, 0.0], self.evaluate(var))
- assign.op.run()
- self.assertAllClose(
- [1.0 * (1.0 - 0.25) / (1 - 0.25), 2.0 * (1.0 - 0.25) / (1 - 0.25)],
- self.evaluate(var))
+ var = variables.Variable([0.0, 0.0])
+ val = constant_op.constant([1.0, 2.0], dtypes.float32)
+ decay = 0.25
+ if context.executing_eagerly():
+ self.assertAllClose([0.0, 0.0], self.evaluate(var))
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ self.assertAllClose(
+ [1.0 * (1.0 - 0.25) / (1 - 0.25), 2.0 * (1.0 - 0.25) / (1 - 0.25)],
+ self.evaluate(var))
+ else:
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ self.evaluate(variables.global_variables_initializer())
+ self.assertAllClose([0.0, 0.0], self.evaluate(var))
+ assign.op.run()
+ self.assertAllClose(
+ [1.0 * (1.0 - 0.25) / (1 - 0.25), 2.0 * (1.0 - 0.25) / (1 - 0.25)],
+ self.evaluate(var))
- @test_util.run_deprecated_v1
+ @test_util.deprecated_graph_mode_only
def testAssignMovingAverageNewNamingMultipleCalls(self):
with variable_scope.variable_scope("scope1") as vs1:
with variable_scope.variable_scope("scope2"):
@@ -93,7 +92,7 @@
actual_names = [v.name for v in vs1.global_variables()]
self.assertSetEqual(set(expected_names), set(actual_names))
- @test_util.run_deprecated_v1
+ @test_util.deprecated_graph_mode_only
def testAssignMovingAverageNewNamingMultipleCallsWithReuse(self):
with variable_scope.variable_scope("scope1") as vs1:
var = variable_scope.get_variable("Var", shape=[])
@@ -104,7 +103,7 @@
moving_averages.assign_moving_average(var, 0.0, 0.99)
moving_averages.assign_moving_average(var, 0.0, 0.99)
- @test_util.run_deprecated_v1
+ @test_util.deprecated_graph_mode_only
def testWeightedMovingAverage(self):
with self.cached_session() as sess:
decay = 0.5
@@ -130,7 +129,7 @@
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
self.assertAllClose(numerator_2 / denominator_2, wma_array)
- @test_util.run_deprecated_v1
+ @test_util.deprecated_graph_mode_only
def testWeightedMovingAverageBfloat16(self):
bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
with self.cached_session() as sess:
@@ -157,6 +156,7 @@
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array)
+
def _Repeat(value, dim):
if dim == 1:
return value
@@ -188,9 +188,9 @@
self.assertItemsEqual([var0, var1], variables.moving_average_variables())
- self.assertFalse(avg0 in variables.trainable_variables())
- self.assertFalse(avg1 in variables.trainable_variables())
- self.assertFalse(avg2 in variables.trainable_variables())
+ self.assertNotIn(avg0, variables.trainable_variables())
+ self.assertNotIn(avg1, variables.trainable_variables())
+ self.assertNotIn(avg2, variables.trainable_variables())
self.evaluate(variables.global_variables_initializer())
self.assertEqual("v0/ExponentialMovingAverage:0", avg0.name)
@@ -210,7 +210,7 @@
self.assertAllClose(_Repeat(0.0, dim), self.evaluate(avg2))
# Update the averages and check.
- update.run()
+ self.evaluate(update)
dk = actual_decay
expected = _Repeat(10.0 * dk + 10.0 * (1 - dk), dim)
@@ -221,7 +221,7 @@
self.assertAllClose(expected, self.evaluate(avg2))
# Again, update the averages and check.
- update.run()
+ self.evaluate(update)
expected = _Repeat((10.0 * dk + 10.0 * (1 - dk)) * dk + 10.0 * (1 - dk),
dim)
self.assertAllClose(expected, self.evaluate(avg0))
@@ -232,87 +232,76 @@
(10.0 + 30.0) * (1 - dk)) / _Scale(dk, 2), dim)
self.assertAllClose(expected, self.evaluate(avg2))
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNoNumUpdates_Scalar(self):
- with self.cached_session():
- ema = moving_averages.ExponentialMovingAverage(0.25)
- self._CheckDecay(ema, actual_decay=0.25, dim=1)
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ self._CheckDecay(ema, actual_decay=0.25, dim=1)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNoNumUpdates_Scalar_Debias(self):
- with self.cached_session():
- ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
- self._CheckDecay(ema, actual_decay=0.25, dim=1)
+ ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.25, dim=1)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNoNumUpdates_Vector(self):
- with self.cached_session():
- ema = moving_averages.ExponentialMovingAverage(0.25)
- self._CheckDecay(ema, actual_decay=0.25, dim=5)
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ self._CheckDecay(ema, actual_decay=0.25, dim=5)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNoNumUpdates_Vector_Debias(self):
- with self.cached_session():
- ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
- self._CheckDecay(ema, actual_decay=0.25, dim=5)
+ ema = moving_averages.ExponentialMovingAverage(0.25, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.25, dim=5)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNumUpdates_Scalar(self):
- with self.cached_session():
- # With num_updates 1, the decay applied is 0.1818
- ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
- self._CheckDecay(ema, actual_decay=0.181818, dim=1)
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=1)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNumUpdates_Scalar_Debias(self):
- with self.cached_session():
- # With num_updates 1, the decay applied is 0.1818
- ema = moving_averages.ExponentialMovingAverage(
- 0.25, num_updates=1, zero_debias=True)
- self._CheckDecay(ema, actual_decay=0.181818, dim=1)
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(
+ 0.25, num_updates=1, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=1)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNumUpdates_Vector(self):
- with self.cached_session():
- # With num_updates 1, the decay applied is 0.1818
- ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
- self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=5)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNumUpdates_Vector_Debias(self):
- with self.cached_session():
- # With num_updates 1, the decay applied is 0.1818
- ema = moving_averages.ExponentialMovingAverage(
- 0.25, num_updates=1, zero_debias=True)
- self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(
+ 0.25, num_updates=1, zero_debias=True)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=5)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesWithControlDeps(self):
- with self.cached_session() as sess:
- v0 = variables.Variable(0, name="v0")
- add_to_v0 = v0.assign_add(1)
- v1 = variables.Variable([10.0], name="v1")
- assign_to_v1 = v1.assign([20.0])
- ema = moving_averages.ExponentialMovingAverage(0.25)
- with ops.control_dependencies([add_to_v0]):
- ema_op = ema.apply([v1])
- # the moving average of v1 should not have any control inputs
- v1_avg = ema.average(v1)
- self.assertEqual([], v1_avg.initializer.control_inputs)
- self.assertEqual([], v1_avg.value().op.control_inputs)
- self.assertEqual([], v1_avg.value().op.control_inputs)
- # We should be able to initialize v1_avg before v0.
- self.evaluate(v1_avg.initializer)
- self.evaluate(v0.initializer)
- self.assertEqual([10.0], self.evaluate(v1_avg))
- # running ema_op should add to v0 (in addition to updating v1_avg)
- self.evaluate(assign_to_v1)
- self.evaluate(ema_op)
- self.assertEqual(1, self.evaluate(v0))
- self.assertEqual([17.5], self.evaluate(v1_avg))
+ v0 = variables.Variable(0, name="v0")
+ add_to_v0 = v0.assign_add(1)
+ v1 = variables.Variable([10.0], name="v1")
+ assign_to_v1 = v1.assign([20.0])
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ with ops.control_dependencies([add_to_v0]):
+ ema_op = ema.apply([v1])
+ # the moving average of v1 should not have any control inputs
+ v1_avg = ema.average(v1)
+ self.assertEqual([], v1_avg.initializer.control_inputs)
+ self.assertEqual([], v1_avg.value().op.control_inputs)
+ self.assertEqual([], v1_avg.value().op.control_inputs)
+ # We should be able to initialize v1_avg before v0.
+ self.evaluate(v1_avg.initializer)
+ self.evaluate(v0.initializer)
+ self.assertEqual([10.0], self.evaluate(v1_avg))
+ # running ema_op should add to v0 (in addition to updating v1_avg)
+ self.evaluate(assign_to_v1)
+ self.evaluate(ema_op)
+ self.assertEqual(1, self.evaluate(v0))
+ self.assertEqual([17.5], self.evaluate(v1_avg))
- @test_util.run_in_graph_and_eager_modes
- @test_util.run_v1_only("b/120545219")
def testBasicEager(self):
v0 = variables.Variable(1.0)
v1 = variables.Variable(2.0)
@@ -332,130 +321,129 @@
self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
def averageVariablesNamesHelper(self, zero_debias):
- with self.cached_session():
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(30.0, name="v1")
+ # Add a non-trainable variable.
+ v2 = variables.Variable(20.0, name="v2", trainable=False)
+ tensor2 = v0 + v1
+ ema = moving_averages.ExponentialMovingAverage(
+ 0.25, zero_debias=zero_debias, name="foo")
+ self.assertEqual("foo", ema.name)
+ self.assertEqual("v0/foo", ema.average_name(v0))
+ self.assertEqual("v1/foo", ema.average_name(v1))
+ self.assertEqual("add/foo", ema.average_name(tensor2))
+ ema.apply([v0, v1, tensor2])
+ vars_to_restore = ema.variables_to_restore()
+ # vars_to_restore should contain the following:
+ # {v0/foo : v0,
+ # v1/foo : v1,
+ # add/foo : add/foo,
+ # v2 : v2}
+ expected_names = [
+ ema.average_name(v0),
+ ema.average_name(v1),
+ ema.average_name(tensor2), v2.op.name
+ ]
+ if zero_debias:
+ # vars_to_restore should also contain the following:
+ # {add/foo/biased: add/foo/biased,
+ # add/foo/local_step: add/foo/local_step}
+ expected_names += [
+ ema.average_name(tensor2) + "/biased",
+ ema.average_name(tensor2) + "/local_step"
+ ]
+ self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
+ self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
+ self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
+ self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
+
+ @test_util.deprecated_graph_mode_only
+ def testAverageVariablesNames(self):
+ self.averageVariablesNamesHelper(zero_debias=True)
+
+ @test_util.deprecated_graph_mode_only
+ def testAverageVariablesNamesNoDebias(self):
+ self.averageVariablesNamesHelper(zero_debias=False)
+
+ @test_util.deprecated_graph_mode_only
+ def averageVariablesNamesRespectScopeHelper(self, zero_debias):
+ # See discussion on #2740.
+ with variable_scope.variable_scope("scope1"):
v0 = variables.Variable(10.0, name="v0")
v1 = variables.Variable(30.0, name="v1")
# Add a non-trainable variable.
v2 = variables.Variable(20.0, name="v2", trainable=False)
tensor2 = v0 + v1
+ with variable_scope.variable_scope("scope2"):
ema = moving_averages.ExponentialMovingAverage(
0.25, zero_debias=zero_debias, name="foo")
- self.assertEqual("foo", ema.name)
- self.assertEqual("v0/foo", ema.average_name(v0))
- self.assertEqual("v1/foo", ema.average_name(v1))
- self.assertEqual("add/foo", ema.average_name(tensor2))
+ self.assertEqual("scope2/scope1/v0/foo", ema.average_name(v0))
+ self.assertEqual("scope2/scope1/v1/foo", ema.average_name(v1))
+ self.assertEqual("scope2/scope1/add/foo", ema.average_name(tensor2))
ema.apply([v0, v1, tensor2])
vars_to_restore = ema.variables_to_restore()
- # vars_to_restore should contain the following:
- # {v0/foo : v0,
- # v1/foo : v1,
- # add/foo : add/foo,
- # v2 : v2}
+ # `vars_to_restore` should contain the following:
+ # {scope2/scope1/v0/foo : v0,
+ # scope2/scope1/v1/foo : v1,
+ # scope2/scope1/add/foo : add/foo,
+ # scope1/v2 : v2}
expected_names = [
- ema.average_name(v0), ema.average_name(v1), ema.average_name(tensor2),
- v2.op.name
+ ema.average_name(v0),
+ ema.average_name(v1),
+ ema.average_name(tensor2), v2.op.name
]
if zero_debias:
- # vars_to_restore should also contain the following:
- # {add/foo/biased: add/foo/biased,
- # add/foo/local_step: add/foo/local_step}
+ # `vars_to_restore` should also contain the following:
+ # {scope2/scope2/scope1/add/foo/biased: add/foo/biased,
+ # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step}
+ sc = "scope2/"
expected_names += [
- ema.average_name(tensor2) + "/biased",
- ema.average_name(tensor2) + "/local_step"
+ sc + ema.average_name(tensor2) + "/biased",
+ sc + ema.average_name(tensor2) + "/local_step"
]
+
self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
- @test_util.run_v1_only("b/120545219")
- def testAverageVariablesNames(self):
- self.averageVariablesNamesHelper(zero_debias=True)
-
- @test_util.run_v1_only("b/120545219")
- def testAverageVariablesNamesNoDebias(self):
- self.averageVariablesNamesHelper(zero_debias=False)
-
- def averageVariablesNamesRespectScopeHelper(self, zero_debias):
- # See discussion on #2740.
- with self.cached_session():
- with variable_scope.variable_scope("scope1"):
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(30.0, name="v1")
- # Add a non-trainable variable.
- v2 = variables.Variable(20.0, name="v2", trainable=False)
- tensor2 = v0 + v1
- with variable_scope.variable_scope("scope2"):
- ema = moving_averages.ExponentialMovingAverage(
- 0.25, zero_debias=zero_debias, name="foo")
- self.assertEqual("scope2/scope1/v0/foo", ema.average_name(v0))
- self.assertEqual("scope2/scope1/v1/foo", ema.average_name(v1))
- self.assertEqual("scope2/scope1/add/foo", ema.average_name(tensor2))
- ema.apply([v0, v1, tensor2])
- vars_to_restore = ema.variables_to_restore()
- # `vars_to_restore` should contain the following:
- # {scope2/scope1/v0/foo : v0,
- # scope2/scope1/v1/foo : v1,
- # scope2/scope1/add/foo : add/foo,
- # scope1/v2 : v2}
- expected_names = [
- ema.average_name(v0), ema.average_name(v1),
- ema.average_name(tensor2), v2.op.name
- ]
- if zero_debias:
- # `vars_to_restore` should also contain the following:
- # {scope2/scope2/scope1/add/foo/biased: add/foo/biased,
- # scope2/scope2/scope1/add/foo/local_step: add/foo/local_step}
- sc = "scope2/"
- expected_names += [
- sc + ema.average_name(tensor2) + "/biased",
- sc + ema.average_name(tensor2) + "/local_step"
- ]
-
- self.assertEqual(sorted(expected_names), sorted(vars_to_restore.keys()))
- self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
- self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
- self.assertEqual(
- ema.average(tensor2).op.name, ema.average_name(tensor2))
-
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNamesRespectScope(self):
self.averageVariablesNamesRespectScopeHelper(zero_debias=True)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesNamesRespectScopeNoDebias(self):
self.averageVariablesNamesRespectScopeHelper(zero_debias=False)
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testSubsetAverageVariablesNames(self):
- with self.cached_session():
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(30.0, name="v1")
- # Add a non-trainable variable.
- v2 = variables.Variable(20.0, name="v2", trainable=False)
- tensor2 = v0 + v1
- ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
- self.assertEqual("v0/foo_avg", ema.average_name(v0))
- self.assertEqual("v1/foo_avg", ema.average_name(v1))
- self.assertEqual("add/foo_avg", ema.average_name(tensor2))
- vars_to_restore = ema.variables_to_restore([v0, tensor2])
- # vars_to_restore should contain the following:
- # {v0/foo_avg : v0,
- # add/foo_avg : add
- # v1 : v1,
- # v2 : v2}
- self.assertEqual(
- sorted(vars_to_restore.keys()),
- sorted([
- ema.average_name(v0), ema.average_name(tensor2), v1.op.name,
- v2.op.name
- ]))
- ema.apply([v0, v1, tensor2])
- self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
- self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
- self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(30.0, name="v1")
+ # Add a non-trainable variable.
+ v2 = variables.Variable(20.0, name="v2", trainable=False)
+ tensor2 = v0 + v1
+ ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+ self.assertEqual("v0/foo_avg", ema.average_name(v0))
+ self.assertEqual("v1/foo_avg", ema.average_name(v1))
+ self.assertEqual("add/foo_avg", ema.average_name(tensor2))
+ vars_to_restore = ema.variables_to_restore([v0, tensor2])
+ # vars_to_restore should contain the following:
+ # {v0/foo_avg : v0,
+ # add/foo_avg : add
+ # v1 : v1,
+ # v2 : v2}
+ self.assertEqual(
+ sorted(vars_to_restore.keys()),
+ sorted([
+ ema.average_name(v0),
+ ema.average_name(tensor2), v1.op.name, v2.op.name
+ ]))
+ ema.apply([v0, v1, tensor2])
+ self.assertEqual(ema.average(v0).op.name, ema.average_name(v0))
+ self.assertEqual(ema.average(v1).op.name, ema.average_name(v1))
+ self.assertEqual(ema.average(tensor2).op.name, ema.average_name(tensor2))
- @test_util.run_v1_only("b/120545219")
+ @test_util.deprecated_graph_mode_only
def testAverageVariablesDeviceAssignment(self):
with ops.device("/job:dev_v0"):
v0 = variables.Variable(10.0, name="v0")
@@ -486,7 +474,7 @@
_ = saver_lib.import_meta_graph(meta_graph)
return graph_copy
- @test_util.run_deprecated_v1
+ @test_util.deprecated_graph_mode_only
def testImportedGraphVariablesToRestore(self):
g = ops.Graph()
with g.as_default():
@@ -502,7 +490,7 @@
# need to be sure that two variables referring to the same variable don't
# both get added to vars_to_restore.
self.assertEqual(len(vars_to_restore), 1)
- self.assertTrue("v/foo_avg" in vars_to_restore)
+ self.assertIn("v/foo_avg", vars_to_restore)
if __name__ == "__main__":