# blob: dfa5050cd31148d2327fdada7d1f8b24702f0444 [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trackable class registration tests.
For integrated tests, see registration_saving_test.py.
"""
from absl.testing import parameterized
from tensorflow.python.eager import test
from tensorflow.python.saved_model import registration
from tensorflow.python.training.tracking import base
@registration.register_serializable()
class RegisteredClass(base.Trackable):
  """Trackable registered with no arguments.

  The tests in this file expect it to be known as "Custom.RegisteredClass".
  """
@registration.register_serializable(name="Subclass")
class RegisteredSubclass(RegisteredClass):
  """Subclass of `RegisteredClass` registered under an explicit name.

  The tests in this file expect it to be known as "Custom.Subclass".
  """
@registration.register_serializable(package="testing")
class CustomPackage(base.Trackable):
  """Trackable registered under an explicit package.

  The tests in this file expect it to be known as "testing.CustomPackage".
  """
@registration.register_serializable(package="testing", name="name")
class CustomPackageAndName(base.Trackable):
  """Trackable registered under an explicit package and name.

  The tests in this file expect it to be known as "testing.name".
  """
class SerializableRegistrationTest(test.TestCase, parameterized.TestCase):
  """Tests for `registration.register_serializable` and its lookup helpers."""

  @parameterized.parameters([
      (RegisteredClass, "Custom.RegisteredClass"),
      (RegisteredSubclass, "Custom.Subclass"),
      (CustomPackage, "testing.CustomPackage"),
      (CustomPackageAndName, "testing.name"),
  ])
  def test_registration(self, expected_cls, expected_name):
    """Checks the instance -> name and name -> class lookups agree.

    Args:
      expected_cls: A class registered at module import time.
      expected_name: The "<package>.<name>" string it should be known as.
    """
    obj = expected_cls()
    self.assertEqual(registration.get_registered_class_name(obj), expected_name)
    self.assertIs(
        registration.get_registered_class(expected_name), expected_cls)

  def test_get_invalid_name(self):
    """Looking up a name that was never registered returns None."""
    self.assertIsNone(registration.get_registered_class("invalid name"))

  def test_get_unregistered_class(self):
    """An instance of a class that was never registered has no name."""

    class NotRegistered(base.Trackable):
      pass

    # Look up an *instance*, like every other call to
    # `get_registered_class_name` in this file. Passing the class object
    # itself would make the assertion pass for an unrelated reason.
    no_register = NotRegistered()
    self.assertIsNone(registration.get_registered_class_name(no_register))

  def test_duplicate_registration(self):
    """A class may be registered under several names, but names are unique."""

    @registration.register_serializable()
    class Duplicate(base.Trackable):
      pass

    dup = Duplicate()
    self.assertEqual(
        registration.get_registered_class_name(dup), "Custom.Duplicate")

    # Registrations with different names are ok.
    registration.register_serializable(package="duplicate")(Duplicate)

    # Registrations are checked in reverse order.
    self.assertEqual(
        registration.get_registered_class_name(dup), "duplicate.Duplicate")

    # Both names should resolve to the same class.
    self.assertIs(
        registration.get_registered_class("Custom.Duplicate"), Duplicate)
    self.assertIs(
        registration.get_registered_class("duplicate.Duplicate"), Duplicate)

    # Registering under a name that is already taken fails.
    # "testing.CustomPackage" is used by the module-level `CustomPackage`.
    with self.assertRaisesRegex(ValueError, "already been registered"):
      registration.register_serializable(
          package="testing", name="CustomPackage")(
              Duplicate)

  def test_register_non_class_fails(self):
    """Registering an instance instead of a class raises TypeError."""
    obj = RegisteredClass()
    with self.assertRaisesRegex(TypeError, "must be a class"):
      registration.register_serializable()(obj)

  def test_register_bad_predicate_fails(self):
    """Registering with a non-callable predicate raises TypeError."""
    with self.assertRaisesRegex(TypeError, "must be callable"):
      registration.register_serializable(predicate=0)(RegisteredClass)

  def test_predicate(self):
    """A custom predicate decides which instances map to a registered name."""

    class Predicate(base.Trackable):

      def __init__(self, register_this):
        self.register_this = register_this

    # First registration only matches instances with `register_this` set.
    registration.register_serializable(
        name="RegisterThisOnlyTrue",
        predicate=lambda x: isinstance(x, Predicate) and x.register_this)(
            Predicate)

    a = Predicate(True)
    b = Predicate(False)
    self.assertEqual(
        registration.get_registered_class_name(a),
        "Custom.RegisterThisOnlyTrue")
    self.assertIsNone(registration.get_registered_class_name(b))

    # Second registration matches every `Predicate` instance; afterwards both
    # objects are expected to resolve to this newer name.
    registration.register_serializable(
        name="RegisterAllPredicate",
        predicate=lambda x: isinstance(x, Predicate))(
            Predicate)

    self.assertEqual(
        registration.get_registered_class_name(a),
        "Custom.RegisterAllPredicate")
    self.assertEqual(
        registration.get_registered_class_name(b),
        "Custom.RegisterAllPredicate")
class CheckpointSaverRegistrationTest(test.TestCase):
  """Tests for `registration.register_checkpoint_saver` and its lookups."""

  def test_invalid_registration(self):
    """Each malformed argument to `register_checkpoint_saver` is rejected."""
    # Non-string package.
    with self.assertRaisesRegex(TypeError, "must be string"):
      registration.register_checkpoint_saver(
          package=None,
          name="test",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)

    # Non-string name.
    with self.assertRaisesRegex(TypeError, "must be string"):
      registration.register_checkpoint_saver(
          name=None,
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)

    # Name containing a disallowed character.
    with self.assertRaisesRegex(ValueError,
                                "Invalid registered checkpoint saver."):
      registration.register_checkpoint_saver(
          package="package",
          name="t/est",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)

    # Package containing a disallowed character. This case used to be an
    # exact copy of the invalid-name case above and added no coverage.
    # NOTE(review): assumes the package is validated with the same rule as the
    # name -- confirm against `registration.register_checkpoint_saver`.
    with self.assertRaisesRegex(ValueError,
                                "Invalid registered checkpoint saver."):
      registration.register_checkpoint_saver(
          package="pack/age",
          name="test",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=lambda: None)

    # Non-callable predicate.
    with self.assertRaisesRegex(
        TypeError,
        "The predicate registered to a checkpoint saver must be callable"
    ):
      registration.register_checkpoint_saver(
          name="test",
          predicate=None,
          save_fn=lambda: None,
          restore_fn=lambda: None)

    # Non-callable save_fn.
    with self.assertRaisesRegex(TypeError, "The save_fn must be callable"):
      registration.register_checkpoint_saver(
          name="test",
          predicate=lambda: None,
          save_fn=None,
          restore_fn=lambda: None)

    # Non-callable restore_fn.
    with self.assertRaisesRegex(TypeError, "The restore_fn must be callable"):
      registration.register_checkpoint_saver(
          name="test",
          predicate=lambda: None,
          save_fn=lambda: None,
          restore_fn=None)

  def test_registration(self):
    """A registered saver is found by predicate and exposes its functions."""
    registration.register_checkpoint_saver(
        package="Testing",
        name="test_predicate",
        predicate=lambda x: hasattr(x, "check_attr"),
        save_fn=lambda: "save",
        restore_fn=lambda: "restore")

    # Without `check_attr` the predicate does not match.
    x = base.Trackable()
    self.assertIsNone(registration.get_registered_saver_name(x))

    # Adding the attribute makes the same object match the predicate.
    x.check_attr = 1
    saver_name = registration.get_registered_saver_name(x)
    self.assertEqual(saver_name, "Testing.test_predicate")

    self.assertEqual(registration.get_save_function(saver_name)(), "save")
    self.assertEqual(registration.get_restore_function(saver_name)(), "restore")

    # Validation passes for a matching object and a known saver name.
    registration.validate_restore_function(x, "Testing.test_predicate")

    # Unknown saver name.
    with self.assertRaisesRegex(ValueError, "saver cannot be found"):
      registration.validate_restore_function(x, "Invalid.name")

    # Known saver name, but the object does not satisfy its predicate.
    x2 = base.Trackable()
    with self.assertRaisesRegex(ValueError, "saver cannot be used"):
      registration.validate_restore_function(x2, "Testing.test_predicate")
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()