import torch
from .Module import Module


class SpatialReplicationPadding(Module):
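    """Pads a 4D input (batch x plane x height x width) by replicating the
    values at its borders. pad_l, pad_r, pad_t and pad_b give the padding
    on the left, right, top and bottom edges; any side left unspecified
    defaults to pad_l.
    """
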
    def __init__(self, pad_l, pad_r=None, pad_t=None, pad_b=None):
        super(SpatialReplicationPadding, self).__init__()
        # Unspecified sides default to the left padding amount.
        self.pad_l = pad_l
        self.pad_r = pad_r if pad_r is not None else pad_l
        self.pad_t = pad_t if pad_t is not None else pad_l
        self.pad_b = pad_b if pad_b is not None else pad_l

    def updateOutput(self, input):
        # Only batched 4D input is supported by this module.
        assert input.dim() == 4
        self._backend.SpatialReplicationPadding_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self.pad_l, self.pad_r, self.pad_t, self.pad_b
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        # gradOutput must match input in the batch and plane dimensions and
        # have the padded spatial size produced by updateOutput.
        assert input.dim() == 4 and gradOutput.dim() == 4
        assert input.size(0) == gradOutput.size(0) and \
            input.size(1) == gradOutput.size(1) and \
            input.size(2) + self.pad_t + self.pad_b == gradOutput.size(2) and \
            input.size(3) + self.pad_l + self.pad_r == gradOutput.size(3)
        self._backend.SpatialReplicationPadding_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.pad_l, self.pad_r, self.pad_t, self.pad_b
        )
        return self.gradInput

    def __repr__(self):
        s = super(SpatialReplicationPadding, self).__repr__()
        s += '({}, {}, {}, {})'.format(self.pad_l, self.pad_r, self.pad_t, self.pad_b)
        return s
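

# A minimal usage sketch (assuming this module is reachable as
# torch.legacy.nn.SpatialReplicationPadding, as in early PyTorch releases);
# it shows how the pad amounts grow the spatial dimensions of the output:
#
#     import torch
#     from torch.legacy import nn
#
#     pad = nn.SpatialReplicationPadding(1, 2, 1, 2)
#     x = torch.randn(1, 3, 4, 4)
#     y = pad.forward(x)  # border values replicated; y is 1 x 3 x 7 x 7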