from itertools import repeat
from collections import defaultdict

import torch
from torch._thnn.utils import parse_header, THNN_H_PATH
from torch.autograd.function import Function, InplaceFunction
from torch._thnn import type2backend

from . import _all_functions
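

# This module auto-generates autograd Function classes for most THNN modules
# and criterions: it parses the THNN.h header and wraps each function's
# updateOutput/updateGradInput (and, where present, accGradParameters)
# entry points in forward/backward methods.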
def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters):
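    """Build an autograd Function class wrapping a THNN criterion.

    `update_output` and `update_grad_input` are the parsed THNN declarations
    for `<Name>_updateOutput` and `<Name>_updateGradInput`.
    """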
    weight_arg_idx = -1
    for i, arg in enumerate(update_output.arguments):
        if arg.name.startswith('weight'):
            weight_arg_idx = i
            break

    buffers_idx = []
    additional_arg_idx = 0
    for arg in update_output.arguments[4:]:
        if not arg.name.startswith('weight') and arg.type == 'THTensor*':
            buffers_idx.append(additional_arg_idx)
        additional_arg_idx += 1

    def __init__(self, *args, **kwargs):
        Function.__init__(self)
        self.weight = kwargs.get('weight')
        self.additional_args = list(args)

    def forward(self, input, target):
        self._backend = type2backend[type(input)]
        self.save_for_backward(input, target)
        if weight_arg_idx >= 0:
            insert_idx = weight_arg_idx - 4  # state, input, target, output
            self.additional_args.insert(insert_idx, self.weight)
        for idx in buffers_idx:
            self.additional_args.insert(idx, input.new(1))
        output = input.new(1)
        getattr(self._backend, update_output.name)(self._backend.library_state, input, target,
                                                   output, *self.additional_args)
        return output

    def backward(self, grad_output):
        input, target = self.saved_tensors
        grad_input = grad_output.new().resize_as_(input).zero_()
        getattr(self._backend, update_grad_input.name)(self._backend.library_state, input, target,
                                                       grad_input, *self.additional_args)
        grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
        grad_input.mul_(grad_output_expanded.expand_as(grad_input))
        return grad_input, None

    return type(class_name, (Function,), dict(__init__=__init__, forward=forward, backward=backward))
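

# A generated criterion class is used through the functional autograd
# interface. A minimal sketch (assuming MSECriterion appears in THNN.h with
# sizeAverage as its only extra argument, so that an MSELoss class is
# generated below):
#
#   loss = MSELoss(True)(input, target)  # positional args map to the extra
#                                        # THNN arguments, e.g. sizeAverage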
def _find_buffers(args, ignored_args):
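    """Return (position, name) pairs for THTensor* buffer arguments,
    skipping any argument whose name is in ignored_args."""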
    additional_arg_idx = 0
    buffers = []
    for arg in args:
        if arg.name in ignored_args:
            continue
        if arg.type == 'THTensor*':
            buffers.append((additional_arg_idx, arg.name))
        additional_arg_idx += 1
    return buffers

def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters):
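    """Build an autograd Function class wrapping a THNN module function.

    The class dispatches to the THNN updateOutput/updateGradInput (and,
    when present, accGradParameters) entry points described by the parsed
    header declarations.
    """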
    def has_argument(fn, name):
        for arg in fn.arguments:
            if arg.name == name:
                return True
        return False
    save_output = has_argument(update_grad_input, 'output')

    param_args = {'weight', 'bias'}
    ignored_args = {'weight', 'bias', 'gradWeight', 'gradBias', 'output'}
    expected_params = [arg for arg in update_output.arguments[3:]
                       if arg.name in param_args]
    buffers = {}
    buffers['update_output'] = _find_buffers(update_output.arguments[3:],
                                             ignored_args)
    buffers['update_grad_input'] = _find_buffers(
        update_grad_input.arguments[4:], ignored_args)
    if acc_grad_parameters is not None:
        buffers['acc_grad_parameters'] = _find_buffers(
            acc_grad_parameters.arguments[3:], ignored_args)

    # This and __init__ assume that only the last argument can be
    # an inplace flag
    is_inplace = update_output.arguments[-1].name == 'inplace'

    def __init__(self, *args):
        if is_inplace:
            InplaceFunction.__init__(self, args[-1])
        else:
            Function.__init__(self)
        self.additional_args = list(args)

    def _initialize_buffers(self, fn_name):
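        """Splice the temporary buffer tensors into additional_args at the
        positions recorded for fn_name."""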
        additional_args = self.additional_args
        for idx, name in buffers[fn_name]:
            # TODO: some buffers are necessary only for update output and can be
            # freed right afterwards
            buffer = self.buffers[name]
            additional_args = additional_args[:idx] + [buffer] + additional_args[idx:]
        return tuple(additional_args)

    def forward(self, input, *params):
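        """Run the THNN updateOutput call and save whatever backward needs."""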
        self._backend = type2backend[type(input)]

        for param in params:
            if type(param) != type(input):
                raise RuntimeError("input type ({}) doesn't match the type of "
                                   "a parameter tensor ({})".format(torch.typename(input),
                                                                    torch.typename(param)))

        # Allocate temporary buffers and insert them into additional_args
        self.buffers = defaultdict(type(input))
        additional_args = self._initialize_buffers('update_output')

        # Fill in optional params with None
        args = params
        for i in range(len(params), len(expected_params)):
            param = expected_params[i]
            if param.is_optional:
                args += (None,)
            else:
                raise ValueError("missing required argument '%s'" % param.name)

        args += tuple(additional_args)

        # If the module works in-place, its output will be set to the
        # same storage as input, but its variable won't be dirty.
        if is_inplace and self.inplace:
            self.mark_dirty(input)
            output = input
            self.save_for_backward(input, *params)
        else:
            output = input.new()
            if save_output:
                self.save_for_backward(input, output, *params)
            else:
                self.save_for_backward(input, *params)

        getattr(self._backend, update_output.name)(self._backend.library_state, input, output, *args)
        return output

    def backward(self, grad_output):
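        """Compute grad_input via updateGradInput and parameter gradients
        via accGradParameters, returning None for gradients not needed."""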
        t = self.saved_tensors
        if save_output:
            input, output, params = t[0], t[1], t[2:]
        else:
            input, params = t[0], t[1:]
        grad_params = tuple(None for p in params)
        grad_input_tuple = (None,)

        if self.needs_input_grad[0]:
            additional_args = self._initialize_buffers('update_grad_input')
            if save_output:
                additional_args = (output,) + additional_args

            grad_input = input.new().resize_as_(input).zero_()
            params_without_bias = params if len(params) < 2 else params[:1]
            update_grad_input_fn = getattr(self._backend, update_grad_input.name)
            gi_args = params_without_bias + additional_args
            update_grad_input_fn(self._backend.library_state, input, grad_output, grad_input, *gi_args)
            grad_input_tuple = (grad_input,)

        if acc_grad_parameters and any(self.needs_input_grad[1:]):
            additional_args = self._initialize_buffers('acc_grad_parameters')
            grad_params = tuple(p.new().resize_as_(p).zero_() for p in params)
            appended_grads = len(expected_params) - len(grad_params)
            grad_params += (None,) * appended_grads
            acc_grad_parameters_fn = getattr(self._backend, acc_grad_parameters.name)
            param_args = grad_params + additional_args + (1,)
            acc_grad_parameters_fn(self._backend.library_state, input, grad_output, *param_args)
            if appended_grads:
                grad_params = grad_params[:-appended_grads]

        return grad_input_tuple + grad_params

    base_class = Function if not is_inplace else InplaceFunction
    return type(class_name, (base_class,), dict(__init__=__init__, forward=forward,
                                                backward=backward, _initialize_buffers=_initialize_buffers))
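

# A generated module class is likewise called through autograd. A minimal
# sketch (assuming Threshold is declared in THNN.h, so that a Threshold
# class is generated below; the trailing flag is its inplace argument):
#
#   output = Threshold(0, 0, False)(input)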
def _generate_function_classes(scope_dict):
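    """Parse THNN.h and install a generated class into scope_dict for every
    THNN function that is not handled manually, under its remapped name."""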
    function_list = parse_header(THNN_H_PATH)
    function_by_name = {fn.name: fn for fn in function_list}
    classes_to_generate = {fn.name.partition('_')[0] for fn in function_list}
    exceptions = {
        'Linear',
        'SpatialConvolutionMM',
        'SparseLinear',
        'TemporalConvolution',
        'SpatialMaxPooling',
        'SpatialDilatedMaxPooling',
        'SpatialMaxUnpooling',
        'VolumetricMaxPooling',
        'VolumetricMaxUnpooling',
        'VolumetricConvolution',
        'VolumetricConvolutionMM',
        'TemporalMaxPooling',
        'BatchNormalization',
        'LookupTable',
        'PReLU',
        'RReLU',
        'unfolded',
    }
    name_remap = {
        'TemporalConvolution': 'Conv1d',
        'SpatialDilatedConvolution': 'DilatedConv2d',
        'SpatialAveragePooling': 'AvgPool2d',
        'SpatialMaxUnpooling': 'MaxUnpool2d',
        'SpatialReflectionPadding': 'ReflectionPad2d',
        'SpatialReplicationPadding': 'ReplicationPad2d',
        'SpatialFullConvolution': 'ConvTranspose2d',
        'VolumetricFullConvolution': 'ConvTranspose3d',
        'VolumetricReplicationPadding': 'ReplicationPad3d',
        'VolumetricMaxUnpooling': 'MaxUnpool3d',
        'VolumetricAveragePooling': 'AvgPool3d',
        'SoftMax': 'Softmax',
        'LogSoftMax': 'LogSoftmax',
        'HardTanh': 'Hardtanh',
        'HardShrink': 'Hardshrink',
        'SoftPlus': 'Softplus',
        'SoftShrink': 'Softshrink',
        'MSECriterion': 'MSELoss',
        'AbsCriterion': 'L1Loss',
        'BCECriterion': '_BCELoss',  # TODO: move the glue code into THNN
        'ClassNLLCriterion': 'NLLLoss',
        'DistKLDivCriterion': 'KLDivLoss',
        'SpatialClassNLLCriterion': 'NLLLoss2d',
        'MultiLabelMarginCriterion': 'MultiLabelMarginLoss',
        'MultiMarginCriterion': 'MultiMarginLoss',
        'SmoothL1Criterion': 'SmoothL1Loss',
        'SoftMarginCriterion': 'SoftMarginLoss',
    }
    classes_to_generate -= exceptions
    for fn in classes_to_generate:
        update_output = function_by_name[fn + '_updateOutput']
        update_grad_input = function_by_name[fn + '_updateGradInput']
        acc_grad_parameters = function_by_name.get(fn + '_accGradParameters')
        class_name = name_remap.get(fn, fn)
        # This has to call a function to retain correct references to functions
        if 'Criterion' in fn:
            cls = _make_function_class_criterion(class_name, update_output,
                                                 update_grad_input, acc_grad_parameters)
        else:
            cls = _make_function_class(class_name, update_output,
                                       update_grad_input, acc_grad_parameters)
        scope_dict[class_name] = cls
        if not class_name.startswith('_'):
            _all_functions.append(cls)


_generate_function_classes(locals())