Chain exceptions when attempting to restore tensors to variables with incompatible shapes.

PiperOrigin-RevId: 435137964
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 36e3cdb..3db0eae 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -118,14 +118,22 @@
     super(ResourceVariableSaveable, self).__init__(var, [spec], name)
 
   def restore(self, restored_tensors, restored_shapes):
+    """Restores tensors. Raises ValueError if incompatible shape found."""
     restored_tensor = restored_tensors[0]
     if restored_shapes is not None:
       restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
     # Copy the restored tensor to the variable's device.
     with ops.device(self._var_device):
       restored_tensor = array_ops.identity(restored_tensor)
-      return resource_variable_ops.shape_safe_assign_variable_handle(
-          self.handle_op, self._var_shape, restored_tensor)
+      try:
+        assigned_variable = resource_variable_ops.shape_safe_assign_variable_handle(
+            self.handle_op, self._var_shape, restored_tensor)
+      except ValueError as e:
+        raise ValueError(
+            f"Received incompatible tensor with shape {restored_tensor.shape} "
+            f"when attempting to restore variable with shape {self._var_shape} "
+            f"and name {self.name}.") from e
+      return assigned_variable
 
 
 def _tensor_comes_from_variable(v):
diff --git a/tensorflow/python/training/tracking/util_test.py b/tensorflow/python/training/tracking/util_test.py
index 7fbafc8..587ec18 100644
--- a/tensorflow/python/training/tracking/util_test.py
+++ b/tensorflow/python/training/tracking/util_test.py
@@ -1093,6 +1093,15 @@
     # https://docs.python.org/3/library/sys.html#sys.getrefcount
     self.assertEqual(sys.getrefcount(ref.deref()), 2)
 
+  def test_restore_incompatible_shape(self):
+    v = variables_lib.Variable([1.0, 1.0])
+    w = variables_lib.Variable([1.0])
+    ckpt = trackable_utils.Checkpoint(v=v)
+    save_path = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
+
+    with self.assertRaisesRegex(ValueError, "incompatible tensor with shape"):
+      trackable_utils.Checkpoint(v=w).restore(save_path)
+
 
 class TemplateTests(parameterized.TestCase, test.TestCase):