blob: de0eeadbc8578bc2c559a48dfa96d95ce1bf0dce [file] [log] [blame]
import math
import torch
from .Module import Module
class SpatialFractionalMaxPooling(Module):
# Usage:
# nn.SpatialFractionalMaxPooling(poolSizeW, poolSizeH, outW, outH)
# the output should be the exact size (outH x outW)
# nn.SpatialFractionalMaxPooling(poolSizeW, poolSizeH, ratioW, ratioH)
# the output should be the size (floor(inH x ratioH) x floor(inW x ratioW))
# ratios are numbers between (0, 1) exclusive
def __init__(self, poolSizeW, poolSizeH, arg1, arg2):
super(SpatialFractionalMaxPooling, self).__init__()
assert poolSizeW >= 2
assert poolSizeH >= 2
# Pool size (how wide the pooling for each output unit is)
self.poolSizeW = poolSizeW
self.poolSizeH = poolSizeH
# Random samples are drawn for all
# batch * plane * (height, width; i.e., 2) points. This determines
# the 2d "pseudorandom" overlapping pooling regions for each
# (batch element x input plane). A new set of random samples is
# drawn every updateOutput call, unless we disable it via
# .fixPoolingRegions().
self.randomSamples = None
# Flag to disable re-generation of random samples for producing
# a new pooling. For testing purposes
self.newRandomPool = False
self.indices = None
if arg1 >= 1 and arg2 >= 1:
# Desired output size: the input tensor will determine the reduction
# ratio
self.outW = arg1
self.outH = arg2
self.ratioW = self.ratioH = None
else:
# Reduction ratio specified per each input
# This is the reduction ratio that we use
self.ratioW = arg1
self.ratioH = arg2
self.outW = self.outH = None
# The reduction ratio must be between 0 and 1
assert self.ratioW > 0 and self.ratioW < 1
assert self.ratioH > 0 and self.ratioH < 1
def _getBufferSize(self, input):
assert input.ndimension() == 4
batchSize = input.size(0)
planeSize = input.size(1)
return torch.Size([batchSize, planeSize, 2])
def _initSampleBuffer(self, input):
sampleBufferSize = self._getBufferSize(input)
if self.randomSamples is None:
self.randomSamples = input.new().resize_(sampleBufferSize).uniform_()
elif self.randomSamples.size(0) != sampleBufferSize[0] or self.randomSamples.size(1) != sampleBufferSize[1]:
self.randomSamples.resize_(sampleBufferSize).uniform_()
elif not self.newRandomPool:
# Create new pooling windows, since this is a subsequent call
self.randomSamples.uniform_()
def _getOutputSizes(self, input):
outW = self.outW
outH = self.outH
if self.ratioW is not None and self.ratioH is not None:
assert input.ndimension() == 4
outW = int(math.floor(input.size(3) * self.ratioW))
outH = int(math.floor(input.size(2) * self.ratioH))
# Neither can be smaller than 1
assert outW > 0
assert outH > 0
else:
assert outW is not None and outH is not None
return outW, outH
# Call this to turn off regeneration of random pooling regions each
# updateOutput call.
def fixPoolingRegions(self, val=True):
self.newRandomPool = val
return self
def updateOutput(self, input):
if self.indices is None:
self.indices = input.new()
self.indices = self.indices.long()
self._initSampleBuffer(input)
outW, outH = self._getOutputSizes(input)
self._backend.SpatialFractionalMaxPooling_updateOutput(
self._backend.library_state,
input,
self.output,
outW, outH, self.poolSizeW, self.poolSizeH,
self.indices, self.randomSamples)
return self.output
def updateGradInput(self, input, gradOutput):
assert self.randomSamples is not None
outW, outH = self._getOutputSizes(input)
self._backend.SpatialFractionalMaxPooling_updateGradInput(
self._backend.library_state,
input,
gradOutput,
self.gradInput,
outW, outH, self.poolSizeW, self.poolSizeH,
self.indices)
return self.gradInput
# backward compat
def empty(self):
self.clearState()
def clearState(self):
self.indices = None
self.randomSamples = None
return super(SpatialFractionalMaxPooling, self).clearState()
def __repr__(self):
return super(SpatialFractionalMaxPooling, self).__repr__() + \
'({}x{}, {}, {})'.format(self.outW or self.ratioW,
self.outH or self.ratioH,
self.poolSizeW, self.poolSizeH)