| import torch |
| from glob import glob |
| import os.path |
| from typing import List, Any |
| |
| _storages : List[Any] = [ |
| torch.DoubleStorage, |
| torch.FloatStorage, |
| torch.LongStorage, |
| torch.IntStorage, |
| torch.ShortStorage, |
| torch.CharStorage, |
| torch.ByteStorage, |
| torch.BoolStorage, |
| ] |
| _dtype_to_storage = { |
| data_type(0).dtype: data_type for data_type in _storages |
| } |
| |
| # because get_storage_from_record returns a tensor!? |
| class _HasStorage(object): |
| def __init__(self, storage): |
| self._storage = storage |
| |
| def storage(self): |
| return self._storage |
| |
| |
| class MockZipReader(object): |
| def __init__(self, directory): |
| self.directory = directory |
| |
| def get_record(self, name): |
| filename = f'{self.directory}/{name}' |
| with open(filename, 'rb') as f: |
| return f.read() |
| |
| def get_storage_from_record(self, name, numel, dtype): |
| storage = _dtype_to_storage[dtype] |
| filename = f'{self.directory}/{name}' |
| return _HasStorage(storage.from_file(filename=filename, size=numel)) |
| |
| def get_all_records(self, ): |
| files = [] |
| for filename in glob(f'{self.directory}/**', recursive=True): |
| if not os.path.isdir(filename): |
| files.append(filename[len(self.directory) + 1:]) |
| return files |