blob: b8ed87492cf3ff430a98330659c632668a322a32 [file] [log] [blame]
import torch
from .Module import Module
from .utils import clear
class SpatialAdaptiveMaxPooling(Module):
    """Module wrapping the backend ``SpatialAdaptiveMaxPooling_*`` routines.

    The forward pass hands ``self.output`` and ``self.indices`` to the
    backend together with the configured ``w`` and ``h``; the backward pass
    hands ``self.gradInput`` and the same ``self.indices`` back to it.

    NOTE(review): ``w`` and ``h`` are presumably the target output width and
    height, and ``indices`` presumably holds the positions of the selected
    maxima -- confirm against the backend implementation.
    """

    def __init__(self, w, h):
        """Store the two pooling size arguments and start with no indices.

        Args:
            w: first size argument, forwarded unchanged to the backend.
            h: second size argument, forwarded unchanged to the backend.
        """
        super(SpatialAdaptiveMaxPooling, self).__init__()
        self.indices = None  # created lazily on the first forward pass
        self.w = w
        self.h = h

    def updateOutput(self, input):
        """Run the backend forward routine and return ``self.output``.

        Args:
            input: tensor passed straight through to the backend.

        Returns:
            ``self.output`` after the backend call.
        """
        indices = self.indices
        if indices is None:
            # No buffer yet: derive an empty one from the input tensor.
            indices = input.new()
        # The buffer is always cast to a long tensor before the backend
        # sees it, whether it was just created or already existed.
        self.indices = indices.long()

        backend = self._backend
        backend.SpatialAdaptiveMaxPooling_updateOutput(
            backend.library_state,
            input,
            self.output,
            self.indices,
            self.w,
            self.h
        )
        return self.output

    def updateGradInput(self, input, gradOutput):
        """Run the backend backward routine and return ``self.gradInput``.

        Args:
            input: the tensor given to the matching forward pass.
            gradOutput: gradient with respect to the module's output.

        Returns:
            ``self.gradInput`` after the backend call.
        """
        backend = self._backend
        backend.SpatialAdaptiveMaxPooling_updateGradInput(
            backend.library_state,
            input,
            gradOutput,
            self.gradInput,
            self.indices
        )
        return self.gradInput

    def clearState(self):
        """Clear the ``indices`` attribute, then defer to the parent class.

        Returns:
            Whatever the parent ``clearState`` returns.
        """
        clear(self, 'indices')
        return super(SpatialAdaptiveMaxPooling, self).clearState()