blob: acf4c640830ed5a145344f083c8e6c91c9522d38 [file] [log] [blame]
import torch
from .Module import Module
class SpatialAveragePooling(Module):
def __init__(self, kW, kH, dW=1, dH=1, padW=0, padH=0):
super(SpatialAveragePooling, self).__init__()
self.kW = kW
self.kH = kH
self.dW = dW
self.dH = dH
self.padW = padW
self.padH = padH
self.ceil_mode = False
self.count_include_pad = True
self.divide = True
def ceil(self):
self.ceil_mode = True
return self
def floor(self):
self.ceil_mode = False
return self
def setCountIncludePad(self):
self.count_include_pad = True
return self
def setCountExcludePad(self):
self.count_include_pad = False
return self
def updateOutput(self, input):
self._backend.SpatialAveragePooling_updateOutput(
self._backend.library_state,
input,
self.output,
self.kW, self.kH,
self.dW, self.dH,
self.padW, self.padH,
self.ceil_mode,
self.count_include_pad
)
# for backward compatibility with saved models
# which are not supposed to have "divide" field
if not self.divide:
self.output.mul_(self.kW * self.kH)
return self.output
def updateGradInput(self, input, gradOutput):
if self.gradInput is not None:
self._backend.SpatialAveragePooling_updateGradInput(
self._backend.library_state,
input,
gradOutput,
self.gradInput,
self.kW, self.kH,
self.dW, self.dH,
self.padW, self.padH,
self.ceil_mode,
self.count_include_pad
)
# for backward compatibility
if not self.divide:
self.gradInput.mul_(self.kW * self.kH)
return self.gradInput
def __repr__(self):
s = super(SpatialAveragePooling, self).__repr__()
s += '({}x{}, {}, {}'.format(self.kW, self.kH, self.dW, self.dH)
if (self.padW or self.padH) and (self.padW != 0 or self.padH != 0):
s += ', {}, {}'.format(self.padW, self.padH)
s += ')'
return s