import torch
from torch.autograd import Function
from torch._thnn import type2backend
from .thnn.auto import function_by_name
import torch.backends.cudnn as cudnn

_thnn_convs = {}


class ConvNd(Function):

    def __init__(self, stride, padding, dilation, transposed, output_padding,
                 groups):
        super(ConvNd, self).__init__()
        if len(stride) == 1:
            # view 1d convolutions as 2d
            stride = (1,) + stride
            padding = (0,) + padding
            dilation = (1,) + dilation
            output_padding = (0,) + output_padding
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups

    def forward(self, input, weight, bias=None):
        k = input.dim()
        self.save_for_backward(input, weight, bias)
        input = input.contiguous()
        if k == 3:
            input, weight = _view4d(input, weight)
        output = self._update_output(input, weight, bias)
        if k == 3:
            output, = _view3d(output)
        return output

    def backward(self, grad_output):
        k = grad_output.dim()
        grad_output = grad_output.contiguous()
        input, weight, bias = self.saved_tensors
        input = input.contiguous()
        if k == 3:
            grad_output, input, weight = _view4d(grad_output, input, weight)
        grad_input = (self._grad_input(input, weight, grad_output)
                      if self.needs_input_grad[0] else None)
        grad_weight, grad_bias = (
            self._grad_params(input, weight, bias, grad_output)
            if any(self.needs_input_grad[1:]) else (None, None))
        if k == 3:
            grad_input, grad_weight = _view3d(grad_input, grad_weight)
        return grad_input, grad_weight, grad_bias

    def is_dilated(self):
        return self.dilation != (1,) * len(self.dilation)
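
    # Output size formulas used below: with the effective (dilated) kernel
    # extent k_eff = dilation * (k - 1) + 1, a regular convolution produces
    # (in + 2 * pad - k_eff) // stride + 1 elements per spatial dimension,
    # while a transposed convolution produces
    # (in - 1) * stride - 2 * pad + k_eff + output_padding elements.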
    def _output_size(self, input, weight):
        channels = (weight.size(1) * self.groups if self.transposed
                    else weight.size(0))
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = self.padding[d]
            kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1
            stride = self.stride[d]
            if self.transposed:
                out_pad = self.output_padding[d]
                output_size += (
                    (in_size - 1) * stride - (2 * pad) + kernel + out_pad,)
            else:
                output_size += ((in_size + (2 * pad) - kernel) // stride + 1,)
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError("convolution input is too small (output would be {})".format(
                'x'.join(map(str, output_size))))
        return output_size

    def _update_output(self, input, weight, bias):
        self.use_cudnn = cudnn.is_acceptable(input)
        if self.use_cudnn and cudnn.version() < 6000:
            # dilated convolutions require cuDNN 6; fall back to THNN otherwise
            self.use_cudnn = not self.is_dilated()
        if self.use_cudnn:
            output = input.new(*self._output_size(input, weight))
            if self.transposed:
                self._cudnn_info = (
                    torch._C._cudnn_convolution_transpose_full_forward(
                        input, weight, bias, output, self.padding, self.stride,
                        self.dilation, self.groups, cudnn.benchmark))
            else:
                self._cudnn_info = torch._C._cudnn_convolution_full_forward(
                    input, weight, bias, output, self.padding, self.stride,
                    self.dilation, self.groups, cudnn.benchmark)
            if not self.requires_grad:
                del self._cudnn_info
            return output
        self._bufs = [[] for g in range(self.groups)]
        output = self._thnn('update_output', input, weight, bias)
        if not self.requires_grad:
            del self._bufs
        return output

    def _grad_input(self, input, weight, grad_output):
        if self.use_cudnn:
            grad_input = input.new().resize_as_(input)
            if self.transposed:
                # ConvTranspose uses the same kernels as regular convolution
                # but swaps forward and backward calls
                torch._C._cudnn_convolution_forward(
                    grad_output, weight, grad_input, self._cudnn_info,
                    cudnn.benchmark)
            else:
                torch._C._cudnn_convolution_backward_data(
                    grad_output, grad_input, weight, self._cudnn_info,
                    cudnn.benchmark)
            return grad_input
        return self._thnn('grad_input', input, weight, grad_output)

    def _grad_params(self, input, weight, bias, grad_output):
        if self.use_cudnn:
            grad_weight = grad_bias = None
            if self.needs_input_grad[1]:
                grad_weight = weight.new().resize_as_(weight)
                torch._C._cudnn_convolution_backward_filter(
                    grad_output, input, grad_weight, self._cudnn_info,
                    cudnn.benchmark)
            if bias is not None and self.needs_input_grad[2]:
                grad_bias = bias.new().resize_as_(bias)
                torch._C._cudnn_convolution_backward_bias(
                    grad_output, grad_bias, self._cudnn_info)
            return grad_weight, grad_bias
        return self._thnn('grad_params', input, weight, bias, grad_output)

    def thnn_class_name(self, input):
        assert input.dim() == 4 or input.dim() == 5
        if self.transposed:
            if input.dim() == 4:
                return 'SpatialFullConvolution'
            else:
                return 'VolumetricFullConvolution'
        elif self.is_dilated():
            if input.dim() == 4:
                return 'SpatialDilatedConvolution'
            else:
                return 'VolumetricDilatedConvolution'
        elif input.dim() == 4:
            return 'SpatialConvolutionMM'
        elif input.dim() == 5 and input.is_cuda:
            return 'VolumetricConvolution'
        else:
            return 'VolumetricConvolutionMM'

    def _thnn(self, fn_name, input, weight, *args):
        impl = _thnn_convs[self.thnn_class_name(input)]
        if self.groups == 1:
            return impl[fn_name](self, self._bufs[0], input, weight, *args)
        else:
            # Grouped convolution: slice the channel dimension of each tensor
            # into per-group chunks, run the THNN kernel once per group, and
            # concatenate the per-group results afterwards.
            res = []
            for g in range(self.groups):
                def group(tensor, dim=None):
                    if tensor is None:
                        return None
                    if dim is None:
                        dim = 0 if tensor.dim() == 1 else 1
                    n = tensor.size(dim) // self.groups
                    return tensor.narrow(dim, n * g, n).contiguous()
                grouped_args = [group(input, 1), group(weight, 0)]
                grouped_args += [group(t) for t in args]
                res.append(impl[fn_name](self, self._bufs[g], *grouped_args))
            if fn_name == 'grad_params':
                return [torch.cat(t, 0) if t[0] is not None else None
                        for t in zip(*res)]
            else:
                return torch.cat(res, 1)


def _view4d(*tensors):
    # view 3d tensor as 4d (conv1d as conv2d)
    output = []
    for t in tensors:
        assert t.dim() == 3
        size = list(t.size())
        size.insert(2, 1)
        output += [t.view(*size)]
    return output


def _view3d(*tensors):
    # view 4d tensor as 3d
    output = []
    for t in tensors:
        if t is None:
            output += [None]
        else:
            assert t.dim() == 4 and t.size(2) == 1
            output += [t.squeeze(2)]
    return output
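

# Note on the THNN argument mapping below: extra THTensor* arguments become
# scratch buffers, 'scale' is fixed at 1.0, and argument names are matched by
# prefix -- 'dil*' selects dilation, 'k*' kernel_size, 'd*' stride, 'p*'
# padding and 'a*' output_padding -- with the trailing letter (T/H/W) picking
# the corresponding dimension.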
def parse_arguments(self, arguments, buffers, kernel_size):
    idx = {'T': -3, 'H': -2, 'W': -1}
    buf_idx = 0
    params = []
    for arg in arguments:
        if arg.type == 'THTensor*':
            params.append(buffers[buf_idx])
            buf_idx += 1
        elif arg.name == 'scale':
            params.append(1.0)
        elif arg.name.startswith('dil'):
            params.append(self.dilation[idx[arg.name[-1]]])
        elif arg.name[0] == 'k':
            params.append(kernel_size[idx[arg.name[-1]]])
        elif arg.name[0] == 'd':
            params.append(self.stride[idx[arg.name[-1]]])
        elif arg.name[0] == 'p':
            params.append(self.padding[idx[arg.name[-1]]])
        elif arg.name[0] == 'a':
            params.append(self.output_padding[idx[arg.name[-1]]])
        else:
            raise RuntimeError('unexpected argument in THNN header: ' + arg.name)
    return params


def make_update_output(fn):
    def call_update_output(self, bufs, input, weight, bias):
        backend = type2backend[type(input)]
        bufs.extend([input.new(), input.new()])
        output = input.new(*self._output_size(input, weight))
        kernel_size = weight.size()[2:]
        # fn.arguments[5:] skips state, input, output, weight and bias,
        # which are passed explicitly below
        args = parse_arguments(self, fn.arguments[5:], bufs, kernel_size)
        getattr(backend, fn.name)(backend.library_state, input, output, weight,
                                  bias, *args)
        return output
    return call_update_output


def make_grad_input(fn):
    def call_grad_input(self, bufs, input, weight, grad_output):
        backend = type2backend[type(input)]
        grad_input = input.new().resize_as_(input)
        kernel_size = weight.size()[2:]
        args = parse_arguments(self, fn.arguments[5:], bufs, kernel_size)
        getattr(backend, fn.name)(backend.library_state, input, grad_output,
                                  grad_input, weight, *args)
        return grad_input
    return call_grad_input


def make_grad_params(fn):
    def call_grad_params(self, bufs, input, weight, bias, grad_output):
        backend = type2backend[type(input)]
        grad_weight = weight.new().resize_as_(weight).zero_()
        grad_bias = None
        if bias is not None and self.needs_input_grad[2]:
            grad_bias = bias.new().resize_as_(bias).zero_()
        kernel_size = weight.size()[2:]
        args = parse_arguments(self, fn.arguments[5:], bufs, kernel_size)
        getattr(backend, fn.name)(backend.library_state, input, grad_output,
                                  grad_weight, grad_bias, *args)
        return grad_weight, grad_bias
    return call_grad_params
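

# _bind_functions fills _thnn_convs with one entry per THNN convolution class,
# mapping 'update_output', 'grad_input' and 'grad_params' to closures over the
# corresponding generated THNN bindings (updateOutput, updateGradInput and
# accGradParameters).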
def _bind_functions():
    classes = [
        'SpatialConvolutionMM',
        'VolumetricConvolution',
        'VolumetricConvolutionMM',
        'SpatialDilatedConvolution',
        'VolumetricDilatedConvolution',
        'SpatialFullConvolution',
        'VolumetricFullConvolution',
    ]
    fns = function_by_name
    for name in classes:
        _thnn_convs[name] = {
            'update_output': make_update_output(fns[name + '_updateOutput']),
            'grad_input': make_grad_input(fns[name + '_updateGradInput']),
            'grad_params': make_grad_params(fns[name + '_accGradParameters']),
        }


_bind_functions()
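
# Illustrative usage sketch (comment only, not executed): in this old-style
# autograd API a ConvNd instance is constructed with its hyperparameters and
# then applied to tensors; the real call sites are the torch.nn.functional
# wrappers, and the shapes below are just an example.
#
#   conv = ConvNd(stride=(1, 1), padding=(0, 0), dilation=(1, 1),
#                 transposed=False, output_padding=(0, 0), groups=1)
#   output = conv(input, weight, bias)   # input: (N, C_in, H, W),
#                                        # weight: (C_out, C_in, kH, kW)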