from torch.autograd import Function


class ConstantPad2d(Function):
    """Pad a 4D tensor with a constant value.

    ``pad`` is a 4-tuple ``(pad_l, pad_r, pad_t, pad_b)``. Positive entries
    add padding filled with ``value`` on that side; negative entries crop
    the input instead.
    """

    def __init__(self, pad, value=0):
        super(ConstantPad2d, self).__init__()
        self.pad = pad
        self.value = value
    def forward(self, input):
        assert input.dim() == 4, 'only 4D supported for padding'
        pad_l, pad_r, pad_t, pad_b = self.pad
        h = input.size(2) + pad_t + pad_b
        w = input.size(3) + pad_l + pad_r
        assert w > 0 and h > 0, 'input is too small'

        self.input_size = input.size()

        output = input.new(input.size(0), input.size(1), h, w).fill_(self.value)

        # crop input if necessary (a negative pad trims that side)
        c_input = input
        if pad_t < 0:
            c_input = c_input.narrow(2, -pad_t, c_input.size(2) + pad_t)
        if pad_b < 0:
            c_input = c_input.narrow(2, 0, c_input.size(2) + pad_b)
        if pad_l < 0:
            c_input = c_input.narrow(3, -pad_l, c_input.size(3) + pad_l)
        if pad_r < 0:
            c_input = c_input.narrow(3, 0, c_input.size(3) + pad_r)

        # crop output if necessary (a positive pad offsets the copy region
        # past the padded border)
        c_output = output
        if pad_t > 0:
            c_output = c_output.narrow(2, pad_t, c_output.size(2) - pad_t)
        if pad_b > 0:
            c_output = c_output.narrow(2, 0, c_output.size(2) - pad_b)
        if pad_l > 0:
            c_output = c_output.narrow(3, pad_l, c_output.size(3) - pad_l)
        if pad_r > 0:
            c_output = c_output.narrow(3, 0, c_output.size(3) - pad_r)

        # copy the (possibly cropped) input into the interior of the output
        c_output.copy_(c_input)
        return output
    def backward(self, grad_output):
        pad_l, pad_r, pad_t, pad_b = self.pad
        grad_input = grad_output.new(self.input_size).zero_()

        # crop grad_input if necessary (mirrors the input crop in forward)
        cg_input = grad_input
        if pad_t < 0:
            cg_input = cg_input.narrow(2, -pad_t, cg_input.size(2) + pad_t)
        if pad_b < 0:
            cg_input = cg_input.narrow(2, 0, cg_input.size(2) + pad_b)
        if pad_l < 0:
            cg_input = cg_input.narrow(3, -pad_l, cg_input.size(3) + pad_l)
        if pad_r < 0:
            cg_input = cg_input.narrow(3, 0, cg_input.size(3) + pad_r)

        # crop grad_output if necessary (the constant border receives no
        # gradient)
        cg_output = grad_output
        if pad_t > 0:
            cg_output = cg_output.narrow(2, pad_t, cg_output.size(2) - pad_t)
        if pad_b > 0:
            cg_output = cg_output.narrow(2, 0, cg_output.size(2) - pad_b)
        if pad_l > 0:
            cg_output = cg_output.narrow(3, pad_l, cg_output.size(3) - pad_l)
        if pad_r > 0:
            cg_output = cg_output.narrow(3, 0, cg_output.size(3) - pad_r)

        # the gradient of a copy is a copy: route grad_output back into the
        # region of the input that was copied forward
        cg_input.copy_(cg_output)
        return grad_input
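

# Usage sketch (illustrative, not part of the original module). It assumes
# the legacy pre-0.2 autograd API, in which a Function subclass is
# constructed with its parameters and applied directly to Variables; the
# variable names below are hypothetical.
if __name__ == '__main__':
    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(1, 1, 3, 3), requires_grad=True)

    # pad = (pad_l, pad_r, pad_t, pad_b): grow every side by one
    # row/column, filling the new border with -1.
    y = ConstantPad2d((1, 1, 1, 1), value=-1)(x)
    assert y.size() == (1, 1, 5, 5)

    # A negative pad crops instead of padding: drop the last column.
    z = ConstantPad2d((0, -1, 0, 0))(x)
    assert z.size() == (1, 1, 3, 2)

    # Gradients flow only through the region copied from the input; here
    # the whole input was copied, so grad_input matches the input's shape.
    y.sum().backward()
    assert x.grad.size() == x.size()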