Add SaveOptions/CheckpointOptions to keras.Models.save_weights and keras_call_backs.ModelCheckpoint.

PiperOrigin-RevId: 316973333
Change-Id: I43f5b59ece4b862db41ab0e99f3c8df0a0d3b901
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 1bca541..1fae5ab 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -54,7 +54,9 @@
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.profiler import profiler_v2 as profiler
+from tensorflow.python.saved_model import save_options as save_options_lib
 from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
 from tensorflow.python.util import nest
 from tensorflow.python.util.compat import collections_abc
 from tensorflow.python.util.tf_export import keras_export
@@ -1115,6 +1117,9 @@
         epochs, the monitored metric may potentially be less reliable (it
         could reflect as little as 1 batch, since the metrics get reset every
         epoch). Defaults to `'epoch'`.
+      options: Optional `tf.train.CheckpointOptions` object if
+        `save_weights_only` is true or optional `tf.saved_model.SavedOptions`
+        object if `save_weights_only` is false.
       **kwargs: Additional arguments for backwards compatibility. Possible key
         is `period`.
   """
@@ -1127,6 +1132,7 @@
                save_weights_only=False,
                mode='auto',
                save_freq='epoch',
+               options=None,
                **kwargs):
     super(ModelCheckpoint, self).__init__()
     self._supports_tf_logs = True
@@ -1140,6 +1146,20 @@
     self._batches_seen_since_last_saving = 0
     self._last_batch_seen = 0
 
+    if save_weights_only:
+      if options is None or isinstance(
+          options, checkpoint_options_lib.CheckpointOptions):
+        self._options = options or checkpoint_options_lib.CheckpointOptions()
+      else:
+        raise TypeError('If save_weights_only is True, then `options` must be'
+                        'either None or a tf.train.CheckpointOptions')
+    else:
+      if options is None or isinstance(options, save_options_lib.SaveOptions):
+        self._options = options or save_options_lib.SaveOptions()
+      else:
+        raise TypeError('If save_weights_only is False, then `options` must be'
+                        'either None or a tf.saved_model.SaveOptions')
+
     # Deprecated field `load_weights_on_restart` is for loading the checkpoint
     # file from `filepath` at the start of `model.fit()`
     # TODO(rchao): Remove the arg during next breaking release.
@@ -1269,9 +1289,10 @@
                                                self.best, current, filepath))
               self.best = current
               if self.save_weights_only:
-                self.model.save_weights(filepath, overwrite=True)
+                self.model.save_weights(
+                    filepath, overwrite=True, options=self._options)
               else:
-                self.model.save(filepath, overwrite=True)
+                self.model.save(filepath, overwrite=True, options=self._options)
             else:
               if self.verbose > 0:
                 print('\nEpoch %05d: %s did not improve from %0.5f' %
@@ -1280,9 +1301,10 @@
           if self.verbose > 0:
             print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
           if self.save_weights_only:
-            self.model.save_weights(filepath, overwrite=True)
+            self.model.save_weights(
+                filepath, overwrite=True, options=self._options)
           else:
-            self.model.save(filepath, overwrite=True)
+            self.model.save(filepath, overwrite=True, options=self._options)
 
         self._maybe_remove_file()
       except IOError as e:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 28f8530..d180e85 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -49,9 +49,11 @@
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import save_options as save_options_lib
 from tensorflow.python.summary import summary_iterator
 from tensorflow.python.training import adam
 from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
 
 try:
   import h5py  # pylint:disable=g-import-not-at-top
@@ -666,6 +668,38 @@
         mode=mode,
         save_freq=3)
 
+    # Case 9: `ModelCheckpoint` with valid and invalid `options` argument.
+    with self.assertRaisesRegexp(TypeError, 'tf.train.CheckpointOptions'):
+      keras.callbacks.ModelCheckpoint(
+          filepath,
+          monitor=monitor,
+          save_best_only=save_best_only,
+          save_weights_only=True,
+          mode=mode,
+          options=save_options_lib.SaveOptions())
+    with self.assertRaisesRegexp(TypeError, 'tf.saved_model.SaveOptions'):
+      keras.callbacks.ModelCheckpoint(
+          filepath,
+          monitor=monitor,
+          save_best_only=save_best_only,
+          save_weights_only=False,
+          mode=mode,
+          options=checkpoint_options_lib.CheckpointOptions())
+    keras.callbacks.ModelCheckpoint(
+        filepath,
+        monitor=monitor,
+        save_best_only=save_best_only,
+        save_weights_only=True,
+        mode=mode,
+        options=checkpoint_options_lib.CheckpointOptions())
+    keras.callbacks.ModelCheckpoint(
+        filepath,
+        monitor=monitor,
+        save_best_only=save_best_only,
+        save_weights_only=False,
+        mode=mode,
+        options=save_options_lib.SaveOptions())
+
   def _get_dummy_resource_for_model_checkpoint_testing(self):
 
     def get_input_datasets():
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 5567e17..ccd184a 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1979,7 +1979,11 @@
     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
                     signatures, options)
 
-  def save_weights(self, filepath, overwrite=True, save_format=None):
+  def save_weights(self,
+                   filepath,
+                   overwrite=True,
+                   save_format=None,
+                   options=None):
     """Saves all layer weights.
 
     Either saves in HDF5 or in TensorFlow format based on the `save_format`
@@ -2032,6 +2036,8 @@
         save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
             '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
             `None` defaults to 'tf'.
+        options: Optional `tf.train.CheckpointOptions` object that specifies
+            options for saving weights.
 
     Raises:
         ImportError: If h5py is not available when attempting to save in HDF5
@@ -2093,7 +2099,7 @@
              'the TensorFlow format the optimizer\'s state will not be '
              'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
             % (optimizer,))
-      self._trackable_saver.save(filepath, session=session)
+      self._trackable_saver.save(filepath, session=session, options=options)
       # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
       checkpoint_management.update_checkpoint_state_internal(
           save_dir=os.path.dirname(filepath),
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index b62814e..6318e57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -302,7 +302,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index 7485a0b..9b7b773 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -320,7 +320,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
index 5fb646e..e6cc7ae 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
@@ -5,7 +5,7 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
+    argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
   }
   member_method {
     name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
index bf980e5..976eb49 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -303,7 +303,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
index c214a5c..500aa28 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -303,7 +303,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 86868c9..ad0edc6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -302,7 +302,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 05aa19a..b38c669 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -320,7 +320,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index b62814e..6318e57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -302,7 +302,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 7485a0b..9b7b773 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -320,7 +320,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
index 5fb646e..e6cc7ae 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
@@ -5,7 +5,7 @@
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
+    argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
   }
   member_method {
     name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
index bf980e5..976eb49 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -303,7 +303,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
index c214a5c..500aa28 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -303,7 +303,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 86868c9..ad0edc6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -302,7 +302,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 05aa19a..b38c669 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -320,7 +320,7 @@
   }
   member_method {
     name: "save_weights"
-    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+    argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
   }
   member_method {
     name: "set_weights"