import torch
from .Module import Module


class Padding(Module):
    # Padding pads [pad] entries of [value] along dimension [dim], starting at
    # index [index] in that dimension. If pad < 0, index counts from the left;
    # if pad > 0, index counts from the right. index = 1 pads before index 1;
    # index = 2 pads starting before index 2 and after index 1 in dimension [dim].
    # When nInputDim is provided, inputs with more dimensions than that value are
    # treated as batches, and the actual dim to be padded is dim + 1.
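    #
    # A minimal usage sketch (illustrative values; assumes the legacy nn Module
    # API, where forward() dispatches to updateOutput):
    #   m = Padding(0, 2, value=-1)
    #   m.forward(torch.ones(3))    # -> [ 1,  1,  1, -1, -1]  (pad > 0: pad at the end)
    #   m = Padding(0, -2, value=-1)
    #   m.forward(torch.ones(3))    # -> [-1, -1,  1,  1,  1]  (pad < 0: pad at the front)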

    def __init__(self, dim, pad, value=0, index=0, nInputDim=0):
        self.value = value
        self.index = index
        self.dim = dim
        self.pad = pad
        self.nInputDim = nInputDim
        self.outputSize = torch.Size()
        super(Padding, self).__init__()

    def updateOutput(self, input):
        dim = self.dim
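        # If the input has an extra (batch) dimension, shift the padded dim by one.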
        if hasattr(self, "nInputDim") and self.nInputDim > 0 and input.dim() != self.nInputDim:
            dim = dim + 1

        outputSize = list(input.size())
        outputSize[dim] += abs(self.pad)
        self.outputSize = torch.Size(outputSize)

        self.output.resize_(self.outputSize)
        self.output.fill_(self.value)
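
        # Convert a right-counted index (pad > 0) to a left-counted one; for
        # pad < 0, keep the left-counted index and work with |pad|.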
        index = self.index
        pad = self.pad
        if pad > 0:
            index = input.size(dim) - index
        else:
            pad = -pad
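
        # Copy the input around the padded gap: the pad sits entirely before
        # the input, entirely after it, or splits it in two.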
        if index == 0:
            self.output.narrow(dim, pad, input.size(dim)).copy_(input)
        elif index == input.size(dim):
            self.output.narrow(dim, 0, input.size(dim)).copy_(input)
        else:
            self.output.narrow(dim, 0, index).copy_(input.narrow(dim, 0, index))
            self.output.narrow(dim, index + pad, input.size(dim) - index).copy_(
                input.narrow(dim, index, input.size(dim) - index))

        return self.output

    def updateGradInput(self, input, gradOutput):
        self.gradInput.resize_as_(input)
        dim = self.dim
        if hasattr(self, "nInputDim") and self.nInputDim > 0 and input.dim() != self.nInputDim:
            dim = dim + 1
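
        # Mirror the index/pad normalization from updateOutput.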
        index = self.index
        pad = self.pad
        if pad > 0:
            index = input.size(dim) - index
        else:
            pad = -pad
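
        # Invert the forward copy: slice the padding out of gradOutput and
        # copy the surviving pieces back into gradInput.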
        if index == 0:
            self.gradInput.copy_(gradOutput.narrow(dim, pad, input.size(dim)))
        elif index == input.size(dim):
            self.gradInput.copy_(gradOutput.narrow(dim, 0, input.size(dim)))
        else:
            self.gradInput.narrow(dim, 0, index).copy_(gradOutput.narrow(dim, 0, index))
            self.gradInput.narrow(dim, index, input.size(dim) - index).copy_(
                gradOutput.narrow(dim, index + pad, input.size(dim) - index))

        return self.gradInput
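

# A quick sanity sketch (hypothetical values): with m = Padding(0, 2),
# m.updateGradInput(torch.ones(3), torch.arange(5.)) gives [0., 1., 2.];
# the gradient entries that landed on the padding are simply dropped.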