blob: 823ab058466414f1c12f69a4e3fbedb36f8313ab [file] [log] [blame]
import torch
from .Module import Module
from .utils import clear
class VolumetricMaxPooling(Module):
def __init__(self, kT, kW, kH, dT=None, dW=None, dH=None, padT=0, padW=0, padH=0):
super(VolumetricMaxPooling, self).__init__()
self.kT = kT
self.kH = kH
self.kW = kW
self.dT = dT or kT
self.dW = dW or kW
self.dH = dH or kH
self.padT = padT
self.padW = padW
self.padH = padH
self.ceil_mode = False
self.indices = torch.LongTensor()
def ceil(self):
self.ceil_mode = True
return self
def floor(self):
self.ceil_mode = False
return self
def updateOutput(self, input):
dims = input.dim()
self.itime = input.size(dims - 3)
self.iheight = input.size(dims - 2)
self.iwidth = input.size(dims - 1)
if self.indices is None:
self.indices = input.new()
self.indices = self.indices.long()
self._backend.VolumetricMaxPooling_updateOutput(
self._backend.library_state,
input,
self.output,
self.indices,
self.kT, self.kW, self.kH,
self.dT, self.dW, self.dH,
self.padT, self.padW, self.padH,
self.ceil_mode
)
return self.output
def updateGradInput(self, input, gradOutput):
self._backend.VolumetricMaxPooling_updateGradInput(
self._backend.library_state,
input,
gradOutput,
self.gradInput,
self.indices,
self.kT, self.kW, self.kH,
self.dT, self.dW, self.dH,
self.padT, self.padW, self.padH,
self.ceil_mode
)
return self.gradInput
def clearState(self):
clear(self, 'indices')
return super(VolumetricMaxPooling, self).clearState()
def __repr__(self):
s = super(VolumetricMaxPooling, self).__repr__()
s += '({}x{}x{}, {}, {}, {}'.format(self.kT, self.kW, self.kH, self.dT, self.dW, self.dH)
if self.padT != 0 or self.padW != 0 or self.padH != 0:
s += ', {}, {}, {}'.format(self.padT, self.padW, self.padH)
s += ')'
return s