Add SaveOptions/CheckpointOptions to keras.Models.save_weights and keras_call_backs.ModelCheckpoint.
PiperOrigin-RevId: 316973333
Change-Id: I43f5b59ece4b862db41ab0e99f3c8df0a0d3b901
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 1bca541..1fae5ab 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -54,7 +54,9 @@
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler import profiler_v2 as profiler
+from tensorflow.python.saved_model import save_options as save_options_lib
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import keras_export
@@ -1115,6 +1117,9 @@
epochs, the monitored metric may potentially be less reliable (it
could reflect as little as 1 batch, since the metrics get reset every
epoch). Defaults to `'epoch'`.
+ options: Optional `tf.train.CheckpointOptions` object if
+ `save_weights_only` is true or optional `tf.saved_model.SavedOptions`
+ object if `save_weights_only` is false.
**kwargs: Additional arguments for backwards compatibility. Possible key
is `period`.
"""
@@ -1127,6 +1132,7 @@
save_weights_only=False,
mode='auto',
save_freq='epoch',
+ options=None,
**kwargs):
super(ModelCheckpoint, self).__init__()
self._supports_tf_logs = True
@@ -1140,6 +1146,20 @@
self._batches_seen_since_last_saving = 0
self._last_batch_seen = 0
+ if save_weights_only:
+ if options is None or isinstance(
+ options, checkpoint_options_lib.CheckpointOptions):
+ self._options = options or checkpoint_options_lib.CheckpointOptions()
+ else:
+ raise TypeError('If save_weights_only is True, then `options` must be'
+ 'either None or a tf.train.CheckpointOptions')
+ else:
+ if options is None or isinstance(options, save_options_lib.SaveOptions):
+ self._options = options or save_options_lib.SaveOptions()
+ else:
+ raise TypeError('If save_weights_only is False, then `options` must be'
+ 'either None or a tf.saved_model.SaveOptions')
+
# Deprecated field `load_weights_on_restart` is for loading the checkpoint
# file from `filepath` at the start of `model.fit()`
# TODO(rchao): Remove the arg during next breaking release.
@@ -1269,9 +1289,10 @@
self.best, current, filepath))
self.best = current
if self.save_weights_only:
- self.model.save_weights(filepath, overwrite=True)
+ self.model.save_weights(
+ filepath, overwrite=True, options=self._options)
else:
- self.model.save(filepath, overwrite=True)
+ self.model.save(filepath, overwrite=True, options=self._options)
else:
if self.verbose > 0:
print('\nEpoch %05d: %s did not improve from %0.5f' %
@@ -1280,9 +1301,10 @@
if self.verbose > 0:
print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
if self.save_weights_only:
- self.model.save_weights(filepath, overwrite=True)
+ self.model.save_weights(
+ filepath, overwrite=True, options=self._options)
else:
- self.model.save(filepath, overwrite=True)
+ self.model.save(filepath, overwrite=True, options=self._options)
self._maybe_remove_file()
except IOError as e:
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index 28f8530..d180e85 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -49,9 +49,11 @@
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import save_options as save_options_lib
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import adam
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training.saving import checkpoint_options as checkpoint_options_lib
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -666,6 +668,38 @@
mode=mode,
save_freq=3)
+ # Case 9: `ModelCheckpoint` with valid and invalid `options` argument.
+ with self.assertRaisesRegexp(TypeError, 'tf.train.CheckpointOptions'):
+ keras.callbacks.ModelCheckpoint(
+ filepath,
+ monitor=monitor,
+ save_best_only=save_best_only,
+ save_weights_only=True,
+ mode=mode,
+ options=save_options_lib.SaveOptions())
+ with self.assertRaisesRegexp(TypeError, 'tf.saved_model.SaveOptions'):
+ keras.callbacks.ModelCheckpoint(
+ filepath,
+ monitor=monitor,
+ save_best_only=save_best_only,
+ save_weights_only=False,
+ mode=mode,
+ options=checkpoint_options_lib.CheckpointOptions())
+ keras.callbacks.ModelCheckpoint(
+ filepath,
+ monitor=monitor,
+ save_best_only=save_best_only,
+ save_weights_only=True,
+ mode=mode,
+ options=checkpoint_options_lib.CheckpointOptions())
+ keras.callbacks.ModelCheckpoint(
+ filepath,
+ monitor=monitor,
+ save_best_only=save_best_only,
+ save_weights_only=False,
+ mode=mode,
+ options=save_options_lib.SaveOptions())
+
def _get_dummy_resource_for_model_checkpoint_testing(self):
def get_input_datasets():
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 5567e17..ccd184a 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1979,7 +1979,11 @@
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
signatures, options)
- def save_weights(self, filepath, overwrite=True, save_format=None):
+ def save_weights(self,
+ filepath,
+ overwrite=True,
+ save_format=None,
+ options=None):
"""Saves all layer weights.
Either saves in HDF5 or in TensorFlow format based on the `save_format`
@@ -2032,6 +2036,8 @@
save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
'.keras' will default to HDF5 if `save_format` is `None`. Otherwise
`None` defaults to 'tf'.
+ options: Optional `tf.train.CheckpointOptions` object that specifies
+ options for saving weights.
Raises:
ImportError: If h5py is not available when attempting to save in HDF5
@@ -2093,7 +2099,7 @@
'the TensorFlow format the optimizer\'s state will not be '
'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.')
% (optimizer,))
- self._trackable_saver.save(filepath, session=session)
+ self._trackable_saver.save(filepath, session=session, options=options)
# Record this checkpoint so it's visible from tf.train.latest_checkpoint.
checkpoint_management.update_checkpoint_state_internal(
save_dir=os.path.dirname(filepath),
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index b62814e..6318e57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -302,7 +302,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index 7485a0b..9b7b773 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -320,7 +320,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
index 5fb646e..e6cc7ae 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
@@ -5,7 +5,7 @@
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
+ argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
}
member_method {
name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
index bf980e5..976eb49 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -303,7 +303,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
index c214a5c..500aa28 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -303,7 +303,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index 86868c9..ad0edc6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -302,7 +302,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index 05aa19a..b38c669 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -320,7 +320,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index b62814e..6318e57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -302,7 +302,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 7485a0b..9b7b773 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -320,7 +320,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
index 5fb646e..e6cc7ae 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.callbacks.-model-checkpoint.pbtxt
@@ -5,7 +5,7 @@
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\'], "
+ argspec: "args=[\'self\', \'filepath\', \'monitor\', \'verbose\', \'save_best_only\', \'save_weights_only\', \'mode\', \'save_freq\', \'options\'], varargs=None, keywords=kwargs, defaults=[\'val_loss\', \'0\', \'False\', \'False\', \'auto\', \'epoch\', \'None\'], "
}
member_method {
name: "on_batch_begin"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
index bf980e5..976eb49 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-linear-model.pbtxt
@@ -303,7 +303,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
index c214a5c..500aa28 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-wide-deep-model.pbtxt
@@ -303,7 +303,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index 86868c9..ad0edc6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -302,7 +302,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index 05aa19a..b38c669 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -320,7 +320,7 @@
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'None\', \'None\'], "
}
member_method {
name: "set_weights"