import math
import torch
from torch.nn.functional import _Reduction
from .MSECriterion import MSECriterion
"""
This file implements a criterion for multi-class classification.
It learns an embedding per class, where each class' embedding
is a point on an (N-1)-dimensional simplex, where N is
the number of classes.
For example usage of this class, look at doc/criterion.md
Reference: http://arxiv.org/abs/1506.08230
"""
class ClassSimplexCriterion(MSECriterion):

    def __init__(self, nClasses):
        super(ClassSimplexCriterion, self).__init__()
        self.nClasses = nClasses

        # embedding the simplex in a space of dimension strictly greater than
        # the minimum possible (nClasses-1) is critical for effective training.
        simp = self._regsplex(nClasses - 1)
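        # _regsplex(nClasses - 1) is an (nClasses, nClasses - 1) matrix; padding it
        # with a zero column embeds the vertices in R^nClasses, so self.simplex is
        # (nClasses, nClasses) with one unit-norm vertex per row.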
        self.simplex = torch.cat((simp, torch.zeros(simp.size(0), nClasses - simp.size(1))), 1)
        self._target = torch.Tensor(nClasses)

        self.output_tensor = None

    def _regsplex(self, n):
        """
        regsplex returns the coordinates of the vertices of a
        regular simplex centered at the origin.
        The Euclidean norms of the vectors specifying the vertices are
        all equal to 1. The input n is the dimension of the vectors;
        the simplex has n+1 vertices.

        input:
        n # dimension of the vectors specifying the vertices of the simplex

        output:
        a # tensor dimensioned (n+1, n) whose rows are
          vectors specifying the vertices

        reference:
        http://en.wikipedia.org/wiki/Simplex#Cartesian_coordinates_for_regular_n-dimensional_simplex_in_Rn
        """
        a = torch.zeros(n + 1, n)

        for k in range(n):
            # determine the last nonzero entry in the vector for the k-th vertex
            if k == 0:
                a[k][k] = 1
            else:
                a[k][k] = math.sqrt(1 - a[k:k + 1, 0:k + 1].norm() ** 2)

            # fill_ the k-th coordinates for the vectors of the remaining vertices
            c = (a[k][k] ** 2 - 1 - 1 / n) / a[k][k]
            a[k + 1:n + 2, k:k + 1].fill_(c)

        return a

    # handle target being both 1D tensor, and
    # target being 2D tensor (2D tensor means don't do anything)
    def _transformTarget(self, target):
        assert target.dim() == 1
        nSamples = target.size(0)
        self._target.resize_(nSamples, self.nClasses)
        for i in range(nSamples):
            self._target[i].copy_(self.simplex[int(target[i])])
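
    # The loss is the (size-averaged) squared error between the input embeddings
    # and the simplex vertices of the target classes, computed through the
    # inherited MSECriterion backend.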
    def updateOutput(self, input, target):
        self._transformTarget(target)

        assert input.nelement() == self._target.nelement()
        if self.output_tensor is None:
            self.output_tensor = input.new(1)
        self._backend.MSECriterion_updateOutput(
            self._backend.library_state,
            input,
            self._target,
            self.output_tensor,
            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
        )
        self.output = self.output_tensor[0].item()
        return self.output

    def updateGradInput(self, input, target):
        assert input.nelement() == self._target.nelement()
        implicit_gradOutput = torch.Tensor([1]).type(input.type())
        self._backend.MSECriterion_updateGradInput(
            self._backend.library_state,
            input,
            self._target,
            implicit_gradOutput,
            self.gradInput,
            _Reduction.legacy_get_enum(self.sizeAverage, True, emit_warning=False),
        )
        return self.gradInput
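
    # Class scores are inner products between the input embeddings and the class
    # vertices; since every vertex has unit norm, the largest score corresponds to
    # the nearest vertex.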
    def getPredictions(self, input):
        return torch.mm(input, self.simplex.t())

    def getTopPrediction(self, input):
        prod = self.getPredictions(input)
        _, maxs = prod.max(prod.ndimension() - 1)
        return maxs.view(-1)