# blob: 73056c89666c8ad963a8c47db81cd40b8f42e700 [file] [log] [blame]
import torch
from .SpatialConvolution import SpatialConvolution
class SpatialDilatedConvolution(SpatialConvolution):
    """Spatial convolution with a dilation factor per spatial axis.

    Extends ``SpatialConvolution`` (which receives the plane counts, kernel
    size, stride and padding unchanged) and routes the forward and backward
    passes through the backend's ``SpatialDilatedConvolution_*`` functions.

    Extra constructor args:
        dilationH: dilation along the height axis.
        dilationW: dilation along the width axis; defaults to ``dilationH``
            when omitted.

    Note: unlike every other (W, H) pair in the signature, the constructor
    takes the dilation pair height-first. That order is kept for backward
    compatibility, so prefer passing them by keyword.
    """

    def __init__(self, nInputPlane, nOutputPlane, kW, kH, dW=1, dH=1, padW=0, padH=None, dilationH=1, dilationW=None):
        super(SpatialDilatedConvolution, self).__init__(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
        self.dilationH = dilationH
        self.dilationW = dilationW if dilationW is not None else dilationH

    def _ensureBuffers(self):
        """Lazily create the ``finput``/``fgradInput`` tensors handed to the backend."""
        # weight.new() gives an empty tensor of the same type as the weights.
        if self.finput is None:
            self.finput = self.weight.new()
        if self.fgradInput is None:
            self.fgradInput = self.weight.new()

    def updateOutput(self, input):
        """Run the forward pass into ``self.output`` and return it."""
        self._ensureBuffers()
        input = self._makeContiguous(input)
        self._backend.SpatialDilatedConvolution_updateOutput(
            self._backend.library_state,
            input,
            self.output,
            self.weight,
            self.bias,
            self.finput,
            self.fgradInput,
            self.kW, self.kH,
            self.dW, self.dH,
            self.padW, self.padH,
            # Width-first, like kW/kH, dW/dH and padW/padH above; passing
            # (dilationH, dilationW) swaps the axes when they differ.
            self.dilationW, self.dilationH
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Compute the gradient w.r.t. ``input`` into ``self.gradInput``.

        Returns ``self.gradInput``, or ``None`` (doing no work) when
        ``self.gradInput`` is ``None``.
        """
        if self.gradInput is None:
            return
        input, gradOutput = self._makeContiguous(input, gradOutput)
        # self.finput is passed below; make sure it exists even if
        # updateOutput has not been called yet.
        self._ensureBuffers()
        self._backend.SpatialDilatedConvolution_updateGradInput(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.weight,
            self.finput,
            self.kW, self.kH,
            self.dW, self.dH,
            self.padW, self.padH,
            # Width-first, matching the other (W, H) pairs.
            self.dilationW, self.dilationH
        )
        return self.gradInput

    def accGradParameters(self, input, gradOutput, scale=1):
        """Accumulate parameter gradients into ``gradWeight``/``gradBias``.

        ``scale`` is forwarded unchanged to the backend (conventionally a
        multiplier on the accumulated gradient).
        """
        input, gradOutput = self._makeContiguous(input, gradOutput)
        # Both finput and fgradInput are passed below.
        self._ensureBuffers()
        self._backend.SpatialDilatedConvolution_accGradParameters(
            self._backend.library_state,
            input,
            gradOutput,
            self.gradWeight,
            self.gradBias,
            self.finput,
            self.fgradInput,
            self.kW, self.kH,
            self.dW, self.dH,
            self.padW, self.padH,
            # Width-first, matching the other (W, H) pairs.
            self.dilationW, self.dilationH,
            scale
        )

    def __repr__(self):
        # Deliberately skips SpatialConvolution.__repr__: the prefix comes
        # from the grandparent class, and the rest is built here.
        s = super(SpatialConvolution, self).__repr__()
        s += '({} -> {}, {}x{}'.format(self.nInputPlane, self.nOutputPlane, self.kW, self.kH)
        # The stride is printed whenever padding is, so padding always lands
        # in the same position.
        if self.dW != 1 or self.dH != 1 or self.padW != 0 or self.padH != 0:
            s += ', {}, {}'.format(self.dW, self.dH)
        if self.padW != 0 or self.padH != 0:
            s += ', {}, {}'.format(self.padW, self.padH)
        s += ', {}, {}'.format(self.dilationW, self.dilationH)
        s += ')'
        if self.bias is None:
            s += ' without bias'
        return s