## @package cached_reader
# Module caffe2.python.cached_reader

import os

from caffe2.python import core
from caffe2.python.db_file_reader import DBFileReader
from caffe2.python.pipeline import pipe
from caffe2.python.task import Cluster, TaskGroup


class CachedReader(DBFileReader):

    default_name_suffix = 'cached_reader'

    """Reader with persistent in-file cache.

    Example usage:
        cached_reader = CachedReader(
            reader,
            db_path='/tmp/cache.db',
            db_type='LevelDB',
        )
        build_cache_step = cached_reader.build_cache_step()
        with LocalSession() as session:
            session.run(build_cache_step)

    Every time a new CachedReader is created, db_path is expected to exist
    before .setup_ex(...) and .read(...) are called.

    If db_path doesn't exist, build_cache_step is expected to be run first
    to build a cache at db_path.

    build_cache_step will check the existence of the provided db_path and,
    in case it's missing, will initialize it by reading data from the
    original reader. All subsequent attempts to read will ignore the
    original reader (i.e. no additional data will be read from it).

    Args:
        original_reader: Reader.
            The original reader used to build the cache file.
        db_path: str.

    Optional Args:
        db_type: str. DB type of the file. A db_type is registered by
            `REGISTER_CAFFE2_DB(<db_type>, <DB Class>)`.
            Defaults to 'LevelDB'.
        name: str or None. Name of the CachedReader.
            Optional name to prepend to blobs that will store the data.
            Defaults to '<db_name>_<default_name_suffix>'.
        batch_size: int.
            How many examples are read each time the read_net is run.
            Defaults to 100.
        loop_over: bool.
            If True, will go through examples in random order endlessly.
            Defaults to False.
    """
    def __init__(
        self,
        original_reader,
        db_path,
        db_type='LevelDB',
        name=None,
        batch_size=100,
        loop_over=False,
    ):
        assert original_reader is not None, "original_reader can't be None"
        self.original_reader = original_reader

        super(CachedReader, self).__init__(
            db_path,
            db_type,
            name,
            batch_size,
            loop_over,
        )

    def _init_reader_schema(self, *args, **kwargs):
        """Prepare the reader schema.

        Since an original reader is given,
        use its schema as the ground truth.

        Returns:
            schema: schema.Struct. Used in Reader.__init__(...).
        """
        return self.original_reader._schema
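
    # Hedged illustration: because the schema is taken from the original
    # reader, downstream code can rely on the two matching, e.g.:
    #
    #     assert cached_reader.schema() == my_reader.schema()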

    def build_cache_step(self, overwrite=False):
        """Build a step for generating the cache DB file.

        If self.db_path exists and overwrite is False, build an empty step.
        Otherwise, build a step as follows:
        pipe the original reader into the _DatasetWriter,
        so that the dataset field blobs are populated,
        then save these blobs into a file.

        Args:
            overwrite: bool. If True, ignore any existing file
                and build a new one, overwriting the existing one.

        Returns:
            build_cache_step: ExecutionStep.
                The step to be run for building a cache DB file.
        """
        if os.path.exists(self.db_path) and not overwrite:
            # Cache already exists; no need to rebuild it.
            return core.execution_step('build_step', [])

        init_net = core.Net('init')
        self._init_field_blobs_as_empty(init_net)
        with Cluster(), core.NameScope(self.name), TaskGroup() as copy_tg:
            pipe(self.original_reader, self.ds.writer(), num_threads=16)
            copy_step = copy_tg.to_task().get_step()

        save_net = core.Net('save')
        self._save_field_blobs_to_db_file(save_net)

        return core.execution_step('build_cache', [init_net, copy_step, save_net])
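
    # Rebuild sketch (hedged: `my_reader` and `session` are illustrative
    # and defined elsewhere):
    #
    #     cached_reader = CachedReader(my_reader, db_path='/tmp/cache.db')
    #     # With the file already present, build_cache_step() returns an
    #     # empty step; pass overwrite=True to force a rebuild.
    #     session.run(cached_reader.build_cache_step(overwrite=True))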

    def _save_field_blobs_to_db_file(self, net):
        """Save dataset field blobs to a DB file at db_path."""
        net.Save(
            self.ds.get_blobs(),
            [],
            db=self.db_path,
            db_type=self.db_type,
            # Store blobs under their schema field names, so they can be
            # loaded back by field name when the cache file is read.
            blob_name_overrides=self.ds.field_names(),
            absolute_path=True,
        )
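
# Standalone sketch of the Save call used above (hedged: names and paths
# are illustrative; 'minidb' is a built-in db_type that avoids the
# LevelDB dependency):
#
#     import numpy as np
#     from caffe2.python import core, workspace
#
#     workspace.FeedBlob('data', np.arange(4, dtype=np.float32))
#     save_net = core.Net('save_example')
#     save_net.Save(
#         ['data'], [],
#         db='/tmp/example.minidb',
#         db_type='minidb',
#         absolute_path=True,
#     )
#     workspace.RunNetOnce(save_net)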