Chain exceptions when attempting to restore tensors to variables with incompatible shapes.
PiperOrigin-RevId: 435137964
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 36e3cdb..3db0eae 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -118,14 +118,22 @@
super(ResourceVariableSaveable, self).__init__(var, [spec], name)
def restore(self, restored_tensors, restored_shapes):
+ """Restores tensors. Raises ValueError if incompatible shape found."""
restored_tensor = restored_tensors[0]
if restored_shapes is not None:
restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
# Copy the restored tensor to the variable's device.
with ops.device(self._var_device):
restored_tensor = array_ops.identity(restored_tensor)
- return resource_variable_ops.shape_safe_assign_variable_handle(
- self.handle_op, self._var_shape, restored_tensor)
+ try:
+ assigned_variable = resource_variable_ops.shape_safe_assign_variable_handle(
+ self.handle_op, self._var_shape, restored_tensor)
+ except ValueError as e:
+ raise ValueError(
+ f"Received incompatible tensor with shape {restored_tensor.shape} "
+ f"when attempting to restore variable with shape {self._var_shape} "
+ f"and name {self.name}.") from e
+ return assigned_variable
def _tensor_comes_from_variable(v):
diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py
index 7fbafc8..587ec18 100644
--- a/tensorflow/python/training/tracking/util_test.py
+++ b/tensorflow/python/training/tracking/util_test.py
@@ -1093,6 +1093,15 @@
# https://docs.python.org/3/library/sys.html#sys.getrefcount
self.assertEqual(sys.getrefcount(ref.deref()), 2)
+ def test_restore_incompatible_shape(self):
+ v = variables_lib.Variable([1.0, 1.0])
+ w = variables_lib.Variable([1.0])
+ ckpt = trackable_utils.Checkpoint(v=v)
+ save_path = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+ with self.assertRaisesRegex(ValueError, "incompatible tensor with shape"):
+ trackable_utils.Checkpoint(v=w).restore(save_path)
+
class TemplateTests(parameterized.TestCase, test.TestCase):