blob: b4981f3fd577f4c699f6c16680ac974560a9f743 [file] [log] [blame]
import random
import math
import torch
from .Module import Module
class SpatialFullConvolutionMap(Module):
def __init__(self, conMatrix, kW, kH, dW=1, dH=1):
super(SpatialFullConvolutionMap, self).__init__()
self.kW = kW
self.kH = kH
self.dW = dW
self.dH = dH
self.connTable = conMatrix
self.nInputPlane = int(self.connTable.select(1, 0).max()) + 1
self.nOutputPlane = int(self.connTable.select(1, 1).max()) + 1
self.weight = torch.Tensor(self.connTable.size(0), kH, kW)
self.gradWeight = torch.Tensor(self.connTable.size(0), kH, kW)
self.bias = torch.Tensor(self.nOutputPlane)
self.gradBias = torch.Tensor(self.nOutputPlane)
self.reset()
def reset(self, stdv=None):
if stdv is not None:
stdv = stdv * math.sqrt(3)
self.weight.uniform_(-stdv, stdv)
self.bias.uniform_(-stdv, stdv)
else:
ninp = torch.Tensor(self.nOutputPlane).zero_()
for i in range(self.connTable.size(0)):
idx = int(self.connTable[i][1])
ninp[idx] += 1
for k in range(self.connTable.size(0)):
idx = int(self.connTable[k][1])
stdv = 1. / math.sqrt(self.kW * self.kH * ninp[idx])
self.weight[k].uniform_(-stdv, stdv)
for k in range(self.bias.size(0)):
stdv = 1. / math.sqrt(self.kW * self.kH * ninp[k])
# TODO: torch.uniform
self.bias[k] = random.uniform(-stdv, stdv)
def updateOutput(self, input):
self._backend.SpatialFullConvolutionMap_updateOutput(
self._backend.library_state,
input,
self.output,
self.weight,
self.bias,
self.connTable,
self.nInputPlane,
self.nOutputPlane,
self.dW, self.dH
)
return self.output
def updateGradInput(self, input, gradOutput):
self._backend.SpatialFullConvolutionMap_updateGradInput(
self._backend.library_state,
input,
gradOutput,
self.gradInput,
self.weight,
self.bias,
self.connTable,
self.nInputPlane,
self.nOutputPlane,
self.dW, self.dH
)
return self.gradInput
def accGradParameters(self, input, gradOutput, scale=1):
self._backend.SpatialFullConvolutionMap_accGradParameters(
self._backend.library_state,
input,
gradOutput,
self.gradWeight,
self.gradBias,
self.connTable,
self.nInputPlane,
self.nOutputPlane,
self.dW, self.dH,
scale
)