blob: 814d70c811467cd59a850f1f02be8958c813c94c [file] [log] [blame]
import os.path
from glob import glob
from typing import Any, List
import torch
_storages: List[Any] = [
torch.DoubleStorage,
torch.FloatStorage,
torch.LongStorage,
torch.IntStorage,
torch.ShortStorage,
torch.CharStorage,
torch.ByteStorage,
torch.BoolStorage,
]
_dtype_to_storage = {data_type(0).dtype: data_type for data_type in _storages}
# because get_storage_from_record returns a tensor!?
class _HasStorage(object):
def __init__(self, storage):
self._storage = storage
def storage(self):
return self._storage
class MockZipReader(object):
def __init__(self, directory):
self.directory = directory
def get_record(self, name):
filename = f"{self.directory}/{name}"
with open(filename, "rb") as f:
return f.read()
def get_storage_from_record(self, name, numel, dtype):
storage = _dtype_to_storage[dtype]
filename = f"{self.directory}/{name}"
return _HasStorage(storage.from_file(filename=filename, size=numel))
def get_all_records(
self,
):
files = []
for filename in glob(f"{self.directory}/**", recursive=True):
if not os.path.isdir(filename):
files.append(filename[len(self.directory) + 1 :])
return files