| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the 'License'); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an 'AS IS' BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ====================================== |
| """Hook for asynchronous checkpointing. |
| |
| This hook dispatches checkpoint writing operations in a separate thread to |
| allow execution to continue on the main thread. |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import os |
| import threading |
| import time |
| |
| from tensorflow.core.util import event_pb2 |
| from tensorflow.python.framework import meta_graph |
| from tensorflow.python.framework import ops |
| from tensorflow.python.platform import tf_logging as logging |
| from tensorflow.python.training import basic_session_run_hooks |
| from tensorflow.python.training import training_util |
| from tensorflow.python.training.session_run_hook import SessionRunArgs |
| from tensorflow.python.training.summary_io import SummaryWriterCache |
| |
| |
| class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): |
| """Saves checkpoints every N steps or seconds.""" |
| |
| def __init__(self, |
| checkpoint_dir, |
| save_secs=None, |
| save_steps=None, |
| saver=None, |
| checkpoint_basename="model.ckpt", |
| scaffold=None, |
| listeners=None): |
| """Initializes a `CheckpointSaverHook`. |
| |
| Args: |
| checkpoint_dir: `str`, base directory for the checkpoint files. |
| save_secs: `int`, save every N secs. |
| save_steps: `int`, save every N steps. |
| saver: `Saver` object, used for saving. |
| checkpoint_basename: `str`, base name for the checkpoint files. |
| scaffold: `Scaffold`, use to get saver object. |
| listeners: List of `CheckpointSaverListener` subclass instances. Used for |
| callbacks that run immediately before or after this hook saves the |
| checkpoint. |
| |
| Raises: |
| ValueError: One of `save_steps` or `save_secs` should be set. |
| ValueError: At most one of `saver` or `scaffold` should be set. |
| """ |
| save_path = os.path.join(checkpoint_dir, checkpoint_basename) |
| logging.info("Create AsyncCheckpointSaverHook saving to path\n%s\n" |
| "with %d listener(s).", save_path, len(listeners)) |
| if saver is not None and scaffold is not None: |
| raise ValueError("You cannot provide both saver and scaffold.") |
| self._saver = saver |
| self._save_thread = None |
| self._write_graph_thread = None |
| self._checkpoint_dir = checkpoint_dir |
| self._save_path = save_path |
| self._scaffold = scaffold |
| self._timer = basic_session_run_hooks.SecondOrStepTimer( |
| every_secs=save_secs, every_steps=save_steps) |
| self._listeners = listeners or [] |
| self._steps_per_run = 1 |
| self._summary_writer = None |
| self._global_step_tensor = None |
| |
| self._last_checkpoint_step = None |
| |
| def _set_steps_per_run(self, steps_per_run): |
| self._steps_per_run = steps_per_run |
| |
| def begin(self): |
| self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) |
| self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access |
| if self._global_step_tensor is None: |
| raise RuntimeError( |
| "Global step should be created to use CheckpointSaverHook.") |
| for l in self._listeners: |
| l.begin() |
| |
| def after_create_session(self, session, coord): |
| global_step = session.run(self._global_step_tensor) |
| |
| # We do write graph and saver_def at the first call of before_run. |
| # We cannot do this in begin, since we let other hooks to change graph and |
| # add variables in begin. Graph is finalized after all begin calls. |
| def _write_graph_fn(self): |
| training_util.write_graph( |
| ops.get_default_graph().as_graph_def(add_shapes=True), |
| self._checkpoint_dir, "graph.pbtxt") |
| self._write_graph_thread = threading.Thread(target=_write_graph_fn, |
| args=[self]) |
| self._write_graph_thread.start() |
| |
| saver_def = self._get_saver().saver_def if self._get_saver() else None |
| graph = ops.get_default_graph() |
| meta_graph_def = meta_graph.create_meta_graph_def( |
| graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) |
| self._summary_writer.add_graph(graph) |
| self._summary_writer.add_meta_graph(meta_graph_def) |
| # The checkpoint saved here is the state at step "global_step". |
| self._save(session, global_step) |
| self._timer.update_last_triggered_step(global_step) |
| |
| def before_run(self, run_context): # pylint: disable=unused-argument |
| return SessionRunArgs(self._global_step_tensor) |
| |
| def after_run(self, run_context, run_values): |
| global_step = run_context.session.run(self._global_step_tensor) |
| if self._timer.should_trigger_for_step(global_step): |
| self._timer.update_last_triggered_step(global_step) |
| logging.info("Triggering checkpoint. %s", global_step) |
| if self._save(run_context.session, global_step): |
| run_context.request_stop() |
| |
| def end(self, session): |
| if self._save_thread: |
| logging.info("Waiting for any pending checkpoints to finish.") |
| self._save_thread.join() |
| if self._write_graph_thread: |
| logging.info("Waiting for any pending write_graph to finish.") |
| self._write_graph_thread.join() |
| |
| last_step = session.run(self._global_step_tensor) |
| |
| if self._last_checkpoint_step != last_step: |
| self._save(session, last_step, asynchronous=False) |
| |
| for l in self._listeners: |
| l.end(session, last_step) |
| |
| def _save(self, session, step, asynchronous=True): |
| """Saves the latest checkpoint, returns should_stop.""" |
| |
| def _save_fn(): |
| """Run the saver process.""" |
| logging.info("Saving checkpoints for %d into %s.", step, self._save_path) |
| |
| start_time = time.time() |
| for l in self._listeners: |
| l.before_save(session, step) |
| |
| self._get_saver().save(session, self._save_path, global_step=step) |
| self._summary_writer.add_session_log( |
| event_pb2.SessionLog( |
| status=event_pb2.SessionLog.CHECKPOINT, |
| checkpoint_path=self._save_path), step) |
| |
| for l in self._listeners: |
| l.after_save(session, step) |
| |
| end_time = time.time() |
| logging.info("Checkpoint actual writing time: (%.3f sec)", |
| end_time - start_time) |
| logging.info("Checkpoint finished for %d into %s.", step, self._save_path) |
| |
| if not asynchronous: |
| self._last_checkpoint_step = step |
| _save_fn() |
| return |
| |
| if self._save_thread is not None: |
| self._save_thread.join(timeout=0.1) |
| if self._save_thread.is_alive(): |
| logging.info("Saver thread still in progress, skipping checkpoint.") |
| return |
| |
| self._last_checkpoint_step = step |
| self._save_thread = threading.Thread(target=_save_fn) |
| self._save_thread.start() |
| |
| def _get_saver(self): |
| if self._saver is not None: |
| return self._saver |
| elif self._scaffold is not None: |
| return self._scaffold.saver |
| |
| # Get saver from the SAVERS collection if present. |
| collection_key = ops.GraphKeys.SAVERS |
| savers = ops.get_collection(collection_key) |
| if not savers: |
| raise RuntimeError( |
| "No items in collection {}. Please add a saver to the collection " |
| "or provide a saver or scaffold.".format(collection_key)) |
| elif len(savers) > 1: |
| raise RuntimeError( |
| "More than one item in collection {}. " |
| "Please indicate which one to use by passing it to the constructor." |
| .format(collection_key)) |
| |
| self._saver = savers[0] |
| return savers[0] |