import io

import torch
from torch.legacy import nn


class Module(object):

    def __init__(self):
        self.gradInput = torch.Tensor()
        self.output = torch.Tensor()
        self._type = self.output.type()
        self._backend = nn._backends.THNNDoubleBackend
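
    # parameters() returns a pair (params, gradParams) of lists of tensors, or
    # None when the module has no learnable parameters. Container modules
    # override it to collect the parameters of their children.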
    def parameters(self):
        # Guard with hasattr: a plain Module has no weight/bias attributes.
        has_weight = hasattr(self, 'weight') and self.weight is not None
        has_bias = hasattr(self, 'bias') and self.bias is not None
        if has_weight and has_bias:
            return [self.weight, self.bias], [self.gradWeight, self.gradBias]
        elif has_weight:
            return [self.weight], [self.gradWeight]
        elif has_bias:
            return [self.bias], [self.gradBias]
        else:
            return
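
    # updateOutput/updateGradInput/accGradParameters are the hooks that
    # subclasses override; forward/backward below are the public entry points
    # that call them.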
    def updateOutput(self, input):
        return self.output

    def forward(self, input):
        return self.updateOutput(input)

    def backward(self, input, gradOutput, scale=1):
        self.updateGradInput(input, gradOutput)
        self.accGradParameters(input, gradOutput, scale)
        return self.gradInput
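
    # backwardUpdate fuses the backward pass with a parameter update: it computes
    # gradInput and then folds a step of size lr directly into the weights via
    # accUpdateGradParameters below.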
    def backwardUpdate(self, input, gradOutput, lr):
        self.updateGradInput(input, gradOutput)
        self.accUpdateGradParameters(input, gradOutput, lr)
        return self.gradInput

    def updateGradInput(self, input, gradOutput):
        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
        pass

    def accUpdateGradParameters(self, input, gradOutput, lr):
        # Temporarily alias the gradients to the parameters so that
        # accGradParameters(-lr) writes the update directly into the weights.
        has_weight = hasattr(self, 'gradWeight')
        has_bias = hasattr(self, 'gradBias')
        if has_weight:
            gradWeight = self.gradWeight
            self.gradWeight = self.weight
        if has_bias:
            gradBias = self.gradBias
            self.gradBias = self.bias
        self.accGradParameters(input, gradOutput, -lr)
        if has_weight:
            self.gradWeight = gradWeight
        if has_bias:
            self.gradBias = gradBias
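
    # Safer variant installed by share(): when tensors are shared, the aliasing
    # trick above is not safe, so the gradients are recomputed from scratch and
    # applied through updateParameters instead.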
    def sharedAccUpdateGradParameters(self, input, gradOutput, lr):
        if self.parameters():
            self.zeroGradParameters()
            self.accGradParameters(input, gradOutput, 1)
            self.updateParameters(lr)

    def zeroGradParameters(self):
        params = self.parameters()
        if params is not None:
            for grad in params[1]:
                grad.zero_()

    def updateParameters(self, learningRate):
        params = self.parameters()
        if params is not None:
            parameters, gradParameters = params
            for p, gp in zip(parameters, gradParameters):
                p.add_(-learningRate, gp)
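
    # Switch between training and evaluation mode; modules such as Dropout check
    # the self.train flag to change their behaviour.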
    def training(self):
        self.train = True

    def evaluate(self):
        self.train = False
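
    # share() makes the attributes named in *arg (e.g. 'weight', 'bias',
    # 'gradWeight', 'gradBias') point at the corresponding tensors of mlp.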
    # TODO
    def share(self, mlp, *arg):
        for v in arg:
            if getattr(self, v, None) is not None:
                getattr(self, v).set_(getattr(mlp, v))
                self.accUpdateGradParameters = self.sharedAccUpdateGradParameters
                mlp.accUpdateGradParameters = mlp.sharedAccUpdateGradParameters
        return self
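
    # clone() deep-copies the module; any attribute names passed in *arg are
    # then re-shared with the original (clone-then-share).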
    def clone(self, *arg):
        # Deep-copy via a serialization round-trip (the Lua version used
        # torch.MemoryFile); assumes torch.save/torch.load can handle the module.
        f = io.BytesIO()
        torch.save(self, f)
        f.seek(0)
        clone = torch.load(f)
        f.close()
        if len(arg) > 0:
            clone.share(self, *arg)
        return clone
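
    # type() converts every tensor attribute to the given type string. The
    # tensorCache dict keeps tensors that were shared before the conversion
    # shared after it as well.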
    def type(self, type=None, tensorCache=None):
        if type is None:
            return self._type
        tensorCache = tensorCache or {}
        # find all tensors and convert them
        for key, param in self.__dict__.items():
            setattr(self, key, nn.utils.recursiveType(param, type, tensorCache))
        self._type = type
        return self

    def float(self, *args):
        return self.type('torch.FloatTensor', *args)

    def double(self, *args):
        return self.type('torch.DoubleTensor', *args)

    def cuda(self, *args):
        return self.type('torch.cuda.FloatTensor', *args)
    def reset(self):
        pass

    def write(self, f):
        raise NotImplementedError

    def read(self, f):
        raise NotImplementedError
    # This function is not easy to understand. It works as follows:
    #
    # - gather all parameter tensors of this module (and its children);
    #   count all parameter values (floats)
    # - create one ginormous memory area (Storage object) with room for all
    #   parameters
    # - remap each parameter tensor to point to an area within the ginormous
    #   Storage, and copy it there
    #
    # It has the effect of making all parameters point to the same memory area,
    # which is then returned.
    #
    # The purpose is to allow operations over all parameters (such as momentum
    # updates and serialization), but it assumes that all parameters are of
    # the same type (and, in the case of CUDA, on the same device), which
    # is not always true. Use for_each() to iterate over this module and its
    # children instead.
    #
    # Module._flattenTensorBuffer can be used by other packages (e.g. cunn)
    # to specify the type of temporary buffers. For example, the temporary
    # buffers for CudaTensor could be FloatTensor, to avoid GPU memory usage.
    #
    # TODO: This logically belongs to torch.Tensor, not nn.
    _flattenTensorBuffer = {}
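
    # A minimal usage sketch (hypothetical example; assumes the legacy Sequential
    # and Linear containers from this package):
    #
    #   net = nn.Sequential().add(nn.Linear(10, 10))
    #   flatParams, flatGradParams = net.flattenParameters()
    #   flatParams.add_(-0.01, flatGradParams)  # one plain SGD step over all parameters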
    def _flatten(self, parameters=[]):

        # returns True if tensor occupies a contiguous region of memory (no holes)
        def isCompact(tensor):
            # TODO: wut, does it really need to create this tensor?
            # isn't it enough to check if strides == size.cumprod(0)?
            sortedStride, perm = torch.sort(torch.LongTensor(list(tensor.stride())), 0, True)
            sortedSize = torch.LongTensor(list(tensor.size())).index_select(0, perm)
            nRealDim = int(torch.clamp(sortedStride, 0, 1).sum())
            sortedStride = sortedStride.narrow(0, 0, nRealDim).clone()
            sortedSize = sortedSize.narrow(0, 0, nRealDim).clone()
            t = tensor.new().set_(tensor.storage(), 0,
                                  torch.Size(sortedSize.tolist()),
                                  tuple(sortedStride.tolist()))
            return t.is_contiguous()
        if not parameters:
            return torch.Tensor()

        Tensor = parameters[0].new
        BufferTensor = Module._flattenTensorBuffer.get(torch.typename(parameters[0]), Tensor)

        # 1. construct the set of all unique storages referenced by parameter tensors
        storages = {}
        num_parameters = 0
        parameterMeta = []
        for param in parameters:
            storage = param.storage()
            key = storage._cdata
            if key not in storages:
                storages[key] = (storage, num_parameters)
                num_parameters = num_parameters + storage.size()
            parameterMeta.append({
                'storageOffset': param.storage_offset() + storages[key][1],
                'size': param.size(),
                'stride': param.stride()
            })
        # 2. construct a single tensor that will hold all the parameters
        flatParameters = BufferTensor(num_parameters).zero_()

        # 3. determine if there are elements in the storage that none of the
        #    parameter tensors reference ('holes')
        tensorsCompact = True
        for meta in parameterMeta:
            # TODO: reuse one Tensor
            tmp = BufferTensor().set_(flatParameters.storage(), meta['storageOffset'],
                                      meta['size'], meta['stride'])
            tmp.fill_(1)
            tensorsCompact = tensorsCompact and isCompact(tmp)

        maskParameters = flatParameters.byte().clone()
        compactOffsets = flatParameters.long().cumsum(0)
        used_parameters = int(compactOffsets[-1])
        # 4. copy storages into the flattened parameter tensor
        for storage, offset in storages.values():
            # TODO: reuse Tensor
            flatParameters[slice(offset, offset + storage.size())].copy_(Tensor().set_(storage))

        # 5. allow garbage collection
        storages = None
        for param in parameters:
            param.set_()
        # 6. compact the flattened parameters if there were holes
        if used_parameters != num_parameters:
            assert tensorsCompact

            flatParameters = BufferTensor(used_parameters).copy_(
                flatParameters.masked_select(maskParameters))
            for meta in parameterMeta:
                # the cumulative sum is inclusive, so subtract 1 to get a 0-based offset
                meta['storageOffset'] = int(compactOffsets[meta['storageOffset']]) - 1

        if BufferTensor != Tensor:
            flatParameters = Tensor(flatParameters.nelement()).copy_(flatParameters)

        # 7. fix up the parameter tensors to point at the flattened parameters
        for param, meta in zip(parameters, parameterMeta):
            param.set_(flatParameters.storage(),
                       meta['storageOffset'],
                       meta['size'],
                       meta['stride'])

        return flatParameters
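
    # flattenParameters() flattens both the parameters and their gradients with
    # _flatten above and returns the two flat tensors.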
    def flattenParameters(self):
        _params = self.parameters()
        if _params is None:
            return
        parameters, gradParameters = _params
        p, g = self._flatten(parameters), self._flatten(gradParameters)

        assert p.nelement() == g.nelement()
        if parameters:
            for param, grad in zip(parameters, gradParameters):
                assert param.storage_offset() == grad.storage_offset()

        return p, g
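
    # apply() calls callback on this module and, recursively, on every child module.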
    def apply(self, callback):
        callback(self)
        if hasattr(self, 'modules'):
            for module in self.modules:
                module.apply(callback)
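
    # findModules() walks the module tree and returns every module whose type
    # name matches `typename`, together with the container that holds each match.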
    def findModules(self, typename, container=None):
        nodes = []
        containers = []
        mod_type = torch.typename(self)
        if mod_type == typename:
            nodes.append(self)
            containers.append(container)

        # Recurse on nodes with 'modules'
        if hasattr(self, 'modules'):
            for child in self.modules:
                child_nodes, child_containers = child.findModules(typename, self)
                assert len(child_nodes) == len(child_containers)
                # add the list items from our child to our list (i.e. return a
                # flattened list of the matching nodes)
                nodes.extend(child_nodes)
                containers.extend(child_containers)

        return nodes, containers
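
    # listModules() returns a flat list containing this module and all of its
    # descendants, in depth-first order.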
    def listModules(self):
        # include self first
        modules = [self]
        if hasattr(self, 'modules'):
            for child in self.modules:
                modules.extend(child.listModules())
        return modules
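
    # clearState() releases the output and gradInput buffers (e.g. before
    # serializing a trained model) to save memory.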
    def clearState(self):
        return nn.utils.clear(self, 'output', 'gradInput')
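
    # replace() substitutes each module in the tree with whatever callback
    # returns for it, recursing into self.modules (see the TODO about out.modules).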
    def replace(self, callback):
        out = callback(self)
        # TODO: not out.modules?
        if hasattr(self, 'modules'):
            for i, module in enumerate(self.modules):
                self.modules[i] = module.replace(callback)
        return out