import torch
from torch.legacy import nn
# tensorCache maintains a dict of all tensors and storages that have been
# converted (recursively) by calls to recursiveType() and type().
# It caches conversions in order to preserve sharing semantics,
# i.e. if two tensors share a common storage, then type conversion
# should preserve that.
#
# You can preserve sharing semantics across multiple networks by
# passing tensorCache between the calls to type, e.g.
#
# > tensorCache = {}
# > net1.type('torch.cuda.FloatTensor', tensorCache)
# > net2.type('torch.cuda.FloatTensor', tensorCache)
# > nn.utils.recursiveType(anotherTensor, 'torch.cuda.FloatTensor', tensorCache)
def recursiveType(param, type, tensorCache=None):
    if tensorCache is None:
        tensorCache = {}
    if isinstance(param, list):
        for i, p in enumerate(param):
            param[i] = recursiveType(p, type, tensorCache)
    elif isinstance(param, nn.Module) or isinstance(param, nn.Criterion):
        param.type(type, tensorCache)
    elif torch.isTensor(param):
        if torch.typename(param) != type:
            key = param._cdata
            if key in tensorCache:
                newparam = tensorCache[key]
            else:
                newparam = torch.Tensor().type(type)
                storageType = type.replace('Tensor', 'Storage')
                param_storage = param.storage()
                if param_storage:
                    storage_key = param_storage._cdata
                    if storage_key not in tensorCache:
                        tensorCache[storage_key] = torch._import_dotted_name(storageType)(
                            param_storage.size()).copy(param_storage)
                    newparam.set(
                        tensorCache[storage_key],
                        param.storageOffset(),
                        param.size(),
                        param.stride()
                    )
                tensorCache[key] = newparam
            param = newparam
    return param

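# Illustrative sketch only (not from the original docstrings), assuming the
# legacy camelCase tensor API used throughout this file: conversions are
# cached by the tensor's _cdata key, so passing the same tensor twice yields
# a single converted tensor and sharing is preserved.
#
# > t = torch.DoubleTensor(2, 3)
# > out = recursiveType([t, [t]], 'torch.FloatTensor')
# > torch.typename(out[0])      # 'torch.FloatTensor'
# > out[0] is out[1][0]         # True -- both entries hit the same cache slot
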
def recursiveResizeAs(t1, t2):
    if isinstance(t2, list):
        t1 = t1 if isinstance(t1, list) else [t1]
        if len(t1) < len(t2):
            t1 += [None] * (len(t2) - len(t1))
        for i, _ in enumerate(t2):
            t1[i], t2[i] = recursiveResizeAs(t1[i], t2[i])
        t1 = t1[:len(t2)]
    elif torch.isTensor(t2):
        t1 = t1 if torch.isTensor(t1) else t2.new()
        t1.resizeAs(t2)
    else:
        raise RuntimeError("expecting nested tensors or tables. Got " +
                           type(t1).__name__ + " and " + type(t2).__name__ + " instead")
    return t1, t2

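# Usage sketch (an assumption about typical call sites, not original
# documentation): t1 is grown or trimmed to mirror t2's nesting, and each
# tensor in t1 is resized in place to match the corresponding tensor in t2.
#
# > ref = [torch.Tensor(2, 2), [torch.Tensor(5)]]
# > out, ref = recursiveResizeAs(torch.Tensor(), ref)
# > out[0].size()               # 2 x 2, matching ref[0]
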
def recursiveFill(t2, val):
    if isinstance(t2, list):
        t2 = [recursiveFill(x, val) for x in t2]
    elif torch.isTensor(t2):
        t2.fill(val)
    else:
        raise RuntimeError("expecting tensor or table thereof. Got " +
                           type(t2).__name__ + " instead")
    return t2

def recursiveAdd(t1, val=1, t2=None):
    if t2 is None:
        t2 = val
        val = 1
    if isinstance(t2, list):
        t1 = t1 if isinstance(t1, list) else [t1]
        for i, _ in enumerate(t2):
            t1[i], t2[i] = recursiveAdd(t1[i], val, t2[i])
    elif torch.isTensor(t1) and torch.isTensor(t2):
        t1.add(val, t2)
    else:
        raise RuntimeError("expecting nested tensors or tables. Got " +
                           type(t1).__name__ + " and " + type(t2).__name__ + " instead")
    return t1, t2

def recursiveCopy(t1, t2):
    if isinstance(t2, list):
        t1 = t1 if isinstance(t1, list) else [t1]
        for i, _ in enumerate(t2):
            t1[i], t2[i] = recursiveCopy(t1[i], t2[i])
    elif torch.isTensor(t2):
        t1 = t1 if torch.isTensor(t1) else t2.new()
        t1.resizeAs(t2).copy(t2)
    else:
        raise RuntimeError("expecting nested tensors or tables. Got " +
                           type(t1).__name__ + " and " + type(t2).__name__ + " instead")
    return t1, t2

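# Combined sketch for the fill/add/copy helpers above (illustrative only,
# same API assumptions as before); each call walks the nested lists and
# applies the per-tensor operation in place.
#
# > a, b = torch.Tensor(3), torch.Tensor(2)
# > recursiveFill([a, b], 0)                        # zero every tensor
# > recursiveAdd([a, b], 2, [a, b])                 # a.add(2, a), b.add(2, b)
# > recursiveCopy([torch.Tensor(), torch.Tensor()], [a, b])   # resize and copy
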
def addSingletonDimension(*args):
    view = None
    if len(args) < 3:
        t, dim = args
    else:
        view, t, dim = args
        assert torch.isTensor(view)
    assert torch.isTensor(t)

    if view is None:
        view = t.new()

    size = torch.LongStorage(t.dim() + 1)
    stride = torch.LongStorage(t.dim() + 1)
    for d in range(dim):
        size[d] = t.size(d)
        stride[d] = t.stride(d)
    size[dim] = 1
    stride[dim] = 1
    for d in range(dim + 1, t.dim() + 1):
        size[d] = t.size(d - 1)
        stride[d] = t.stride(d - 1)

    view.set(t.storage(), t.storageOffset(), size, stride)
    return view

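# Sketch of the resulting shape (illustrative, same API assumptions): the
# returned view shares t's storage and has an extra size-1 dimension
# inserted at position dim.
#
# > t = torch.Tensor(4, 5)
# > v = addSingletonDimension(t, 1)
# > v.size()                    # 4 x 1 x 5, no data is copied
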
def contiguousView(output, input, *args):
    if output is None:
        output = input.new()
    if input.isContiguous():
        output.view(input, *args)
    else:
        output.resizeAs(input)
        output.copy(input)
        output.view(output, *args)
    return output

# go over specified fields and clear them. accepts
# nn.utils.clear(self, ['_buffer', '_buffer2']) and
# nn.utils.clear(self, '_buffer', '_buffer2')
def clear(self, *args):
    if len(args) == 1 and isinstance(args[0], list):
        args = args[0]

    def _clear(f):
        if not hasattr(self, f):
            return
        attr = getattr(self, f)
        if torch.isTensor(attr):
            attr.set()
        elif isinstance(attr, list):
            del attr[:]
        else:
            delattr(self, f)

    for key in args:
        _clear(key)
    return self
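
# Typical call site sketch (hypothetical attribute names, not taken from any
# particular module): a module's clearState() can delegate to clear() to
# drop cached buffers before serialization.
#
# > def clearState(self):
# >     return nn.utils.clear(self, '_input', '_gradOutput')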