Adding mean, variance, stddev to distributions (#4923)
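This patch gives every distribution a mean, a variance, and a derived stddev property. A minimal usage sketch, assuming the Variable-based API the tests below use:

    import torch
    from torch.autograd import Variable
    from torch.distributions import Normal

    d = Normal(Variable(torch.zeros(3)), Variable(2 * torch.ones(3)))
    d.mean      # == loc:      [0, 0, 0]
    d.variance  # == scale**2: [4, 4, 4]
    d.stddev    # == scale:    [2, 2, 2]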
diff --git a/test/test_distributions.py b/test/test_distributions.py
index a0efb73..5425630 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -71,6 +71,13 @@
return Dist(*params1), Dist(*params2)
+def is_all_nan(tensor):
+ """
+ Checks if all entries of a tensor are NaN.
+ """
+ return (tensor != tensor).all()
+
+
# Register all distributions for generic tests.
Example = namedtuple('Example', ['Dist', 'params'])
EXAMPLES = [
@@ -516,6 +523,8 @@
def test_categorical_1d(self):
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
+ self.assertTrue(is_all_nan(Categorical(p).mean))
+ self.assertTrue(is_all_nan(Categorical(p).variance))
self.assertEqual(Categorical(p).sample().size(), SCALAR_SHAPE)
self.assertTrue(isinstance(Categorical(p).sample().data, torch.LongTensor))
self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
@@ -528,6 +537,10 @@
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
p = Variable(torch.Tensor(probabilities), requires_grad=True)
s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
+ self.assertEqual(Categorical(p).mean.size(), (2,))
+ self.assertEqual(Categorical(p).variance.size(), (2,))
+ self.assertTrue(is_all_nan(Categorical(p).mean))
+ self.assertTrue(is_all_nan(Categorical(p).variance))
self.assertEqual(Categorical(p).sample().size(), (2,))
self.assertEqual(Categorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2))
self.assertEqual(Categorical(p).sample_n(6).size(), (6, 2))
@@ -667,6 +680,8 @@
scale = Variable(torch.ones(5, 5), requires_grad=True)
loc_1d = Variable(torch.zeros(1), requires_grad=True)
scale_1d = Variable(torch.ones(1), requires_grad=True)
+ self.assertTrue(is_all_nan(Cauchy(loc_1d, scale_1d).mean))
+ self.assertEqual(Cauchy(loc_1d, scale_1d).variance, float('inf'), allow_inf=True)
self.assertEqual(Cauchy(loc, scale).sample().size(), (5, 5))
self.assertEqual(Cauchy(loc, scale).sample_n(7).size(), (7, 5, 5))
self.assertEqual(Cauchy(loc_1d, scale_1d).sample().size(), (1,))
@@ -906,6 +921,8 @@
alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
scale_1d = torch.randn(1).abs()
alpha_1d = torch.randn(1).abs()
+ self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).mean, float('inf'), allow_inf=True)
+ self.assertEqual(Pareto(scale_1d, torch.Tensor([0.5])).variance, float('inf'), allow_inf=True)
self.assertEqual(Pareto(scale, alpha).sample().size(), (2, 3))
self.assertEqual(Pareto(scale, alpha).sample_n(5).size(), (5, 2, 3))
self.assertEqual(Pareto(scale_1d, alpha_1d).sample_n(1).size(), (1, 1))
@@ -964,6 +981,8 @@
df2 = Variable(torch.randn(2, 3).abs(), requires_grad=True)
df1_1d = torch.randn(1).abs()
df2_1d = torch.randn(1).abs()
+ self.assertTrue(is_all_nan(FisherSnedecor(1, 2).mean))
+ self.assertTrue(is_all_nan(FisherSnedecor(1, 4).variance))
self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
self.assertEqual(FisherSnedecor(df1, df2).sample_n(5).size(), (5, 2, 3))
self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
@@ -1015,9 +1034,12 @@
'Chi2(df={})'.format(df))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
- def test_studentT_shape(self):
+ def test_studentT(self):
df = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
df_1d = Variable(torch.exp(torch.randn(1)), requires_grad=True)
+ self.assertTrue(is_all_nan(StudentT(1).mean))
+ self.assertTrue(is_all_nan(StudentT(1).variance))
+ self.assertEqual(StudentT(2).variance, float('inf'), allow_inf=True)
self.assertEqual(StudentT(df).sample().size(), (2, 3))
self.assertEqual(StudentT(df).sample_n(5).size(), (5, 2, 3))
self.assertEqual(StudentT(df_1d).sample_n(1).size(), (1, 1))
@@ -2170,6 +2192,104 @@
self.assertFalse('logits' in vars(dist), msg=message)
+@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+class TestAgainstScipy(TestCase):
+ def setUp(self):
+ positive_var = Variable(torch.Tensor(20,).normal_()).exp()
+ positive_var2 = Variable(torch.Tensor(20,).normal_()).exp()
+ random_var = Variable(torch.Tensor(20,).normal_())
+ random_tensor = torch.Tensor(20,).normal_()
+ simplex_tensor = random_tensor.exp() / random_tensor.exp().sum()
+ self.distribution_pairs = [
+ (
+ Bernoulli(simplex_tensor),
+ scipy.stats.bernoulli(simplex_tensor)
+ ),
+ (
+ Beta(positive_var, positive_var2),
+ scipy.stats.beta(positive_var, positive_var2)
+ ),
+ (
+ Binomial(10, simplex_tensor),
+ scipy.stats.binom(10 * np.ones(simplex_tensor.shape), simplex_tensor)
+ ),
+ (
+ Dirichlet(positive_var),
+ scipy.stats.dirichlet(positive_var)
+ ),
+ (
+ Exponential(positive_var),
+ scipy.stats.expon(scale=1. / positive_var)
+ ),
+ (
+ FisherSnedecor(positive_var, 4 + positive_var2), # variance is undefined for df2 <= 4
+ scipy.stats.f(positive_var, 4 + positive_var2)
+ ),
+ (
+ Gamma(positive_var, positive_var2),
+ scipy.stats.gamma(positive_var, scale=1 / positive_var2)
+ ),
+ (
+ Geometric(simplex_tensor),
+ scipy.stats.geom(simplex_tensor, loc=-1)
+ ),
+ (
+ Gumbel(random_var, positive_var2),
+ scipy.stats.gumbel_r(random_var, positive_var2)
+ ),
+ (
+ Laplace(random_var, positive_var2),
+ scipy.stats.laplace(random_var, positive_var2)
+ ),
+ (
+ # Tests fail the default 1e-5 precision threshold if scale > 3
+ LogNormal(random_var, positive_var.clamp(max=3)),
+ scipy.stats.lognorm(s=positive_var.clamp(max=3), scale=random_var.exp())
+ ),
+ (
+ Multinomial(10, simplex_tensor),
+ scipy.stats.multinomial(10, simplex_tensor)
+ ),
+ (
+ Normal(random_var, positive_var2),
+ scipy.stats.norm(random_var, positive_var2)
+ ),
+ (
+ OneHotCategorical(simplex_tensor),
+ scipy.stats.multinomial(1, simplex_tensor)
+ ),
+ (
+ Pareto(positive_var, 2 + positive_var2),
+ scipy.stats.pareto(2 + positive_var2, scale=positive_var)
+ ),
+ (
+ Poisson(positive_var),
+ scipy.stats.poisson(positive_var)
+ ),
+ (
+ StudentT(2 + positive_var, random_var, positive_var2),
+ scipy.stats.t(2 + positive_var, random_var, positive_var2)
+ ),
+ (
+ Uniform(random_var, random_var + positive_var),
+ scipy.stats.uniform(random_var, positive_var)
+ )
+ ]
+
+ def test_mean(self):
+ for pytorch_dist, scipy_dist in self.distribution_pairs:
+ self.assertEqual(pytorch_dist.mean, scipy_dist.mean(), allow_inf=True, message=pytorch_dist)
+
+ def test_variance_stddev(self):
+ for pytorch_dist, scipy_dist in self.distribution_pairs:
+ if isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
+ self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
+ self.assertEqual(pytorch_dist.stddev, np.diag(scipy_dist.cov()) ** 0.5, message=pytorch_dist)
+ else:
+ self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
+ self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist)
+
+
class TestTransforms(TestCase):
def setUp(self):
self.transforms = []
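The comparisons this class automates can be spot-checked by hand; a sketch for a single pair (the values here are illustrative, and .data.numpy() assumes the Variable API of this era):

    import numpy as np
    import scipy.stats
    import torch
    from torch.autograd import Variable
    from torch.distributions import Laplace

    loc = Variable(torch.zeros(5))
    scale = Variable(2 * torch.ones(5))
    pytorch_var = Laplace(loc, scale).variance.data.numpy()  # 2 * scale**2 = 8
    scipy_var = scipy.stats.laplace(np.zeros(5), 2 * np.ones(5)).var()
    np.testing.assert_allclose(pytorch_var, scipy_var, rtol=1e-5)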
diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py
index 4af83be..c6e0cea 100644
--- a/torch/distributions/bernoulli.py
+++ b/torch/distributions/bernoulli.py
@@ -49,6 +49,14 @@
def _new(self, *args, **kwargs):
return self._param.new(*args, **kwargs)
+ @property
+ def mean(self):
+ return self.probs
+
+ @property
+ def variance(self):
+ return self.probs * (1 - self.probs)
+
@lazy_property
def logits(self):
return probs_to_logits(self.probs, is_binary=True)
diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py
index 3b191b7..c7907e1 100644
--- a/torch/distributions/beta.py
+++ b/torch/distributions/beta.py
@@ -38,6 +38,16 @@
self._dirichlet = Dirichlet(concentration1_concentration0)
super(Beta, self).__init__(self._dirichlet._batch_shape)
+ @property
+ def mean(self):
+ return self.concentration1 / (self.concentration1 + self.concentration0)
+
+ @property
+ def variance(self):
+ total = self.concentration1 + self.concentration0
+ return (self.concentration1 * self.concentration0 /
+ (total.pow(2) * (total + 1)))
+
def rsample(self, sample_shape=()):
value = self._dirichlet.rsample(sample_shape).select(-1, 0)
if isinstance(value, Number):
diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py
index 6b88e34..ea91d77 100644
--- a/torch/distributions/binomial.py
+++ b/torch/distributions/binomial.py
@@ -4,7 +4,6 @@
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all, probs_to_logits, lazy_property, logits_to_probs
-from torch.distributions.utils import clamp_probs
from torch.autograd import Variable
@@ -61,6 +60,14 @@
def support(self):
return constraints.integer_interval(0, self.total_count)
+ @property
+ def mean(self):
+ return self.total_count * self.probs
+
+ @property
+ def variance(self):
+ return self.total_count * self.probs * (1 - self.probs)
+
@lazy_property
def logits(self):
return probs_to_logits(self.probs, is_binary=True)
diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py
index 5799e01..06dfd6a 100644
--- a/torch/distributions/categorical.py
+++ b/torch/distributions/categorical.py
@@ -68,6 +68,14 @@
def param_shape(self):
return self._param.size()
+ @property
+ def mean(self):
+ return self.probs.new([float('nan')]).expand(self._extended_shape())
+
+ @property
+ def variance(self):
+ return self.probs.new([float('nan')]).expand(self._extended_shape())
+
def sample(self, sample_shape=torch.Size()):
sample_shape = self._extended_shape(sample_shape)
param_shape = sample_shape + torch.Size((self._num_events,))
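Since the categories are unordered, Categorical reports NaN moments rather than raising, but still respects the batch shape; for example:

    import torch
    from torch.distributions import Categorical

    p = torch.Tensor([[0.2, 0.8], [0.5, 0.5]])
    Categorical(p).mean      # shape (2,), all NaN
    Categorical(p).variance  # shape (2,), all NaN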
diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py
index 26f9068..6a3600b 100644
--- a/torch/distributions/cauchy.py
+++ b/torch/distributions/cauchy.py
@@ -36,6 +36,14 @@
batch_shape = self.loc.size()
super(Cauchy, self).__init__(batch_shape)
+ @property
+ def mean(self):
+ return self.loc.new([float('nan')]).expand(self._extended_shape())
+
+ @property
+ def variance(self):
+ return self.loc.new([float('inf')]).expand(self._extended_shape())
+
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
eps = self.loc.new(shape).cauchy_()
diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py
index 763b50a..5a7009d 100644
--- a/torch/distributions/dirichlet.py
+++ b/torch/distributions/dirichlet.py
@@ -74,6 +74,15 @@
torch.lgamma(self.concentration.sum(-1)) -
torch.lgamma(self.concentration).sum(-1))
+ @property
+ def mean(self):
+ return self.concentration / self.concentration.sum(-1, True)
+
+ @property
+ def variance(self):
+ con0 = self.concentration.sum(-1, True)
+ return self.concentration * (con0 - self.concentration) / (con0.pow(2) * (con0 + 1))
+
def entropy(self):
k = self.concentration.size(-1)
a0 = self.concentration.sum(-1)
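A numeric sanity check of the closed forms above, writing alpha0 for concentration.sum(-1):

    import torch
    from torch.distributions import Dirichlet

    d = Dirichlet(torch.Tensor([1.0, 2.0, 3.0]))  # alpha0 = 6
    d.mean      # alpha / alpha0 = [1/6, 1/3, 1/2], which always sums to 1
    d.variance  # alpha * (alpha0 - alpha) / (alpha0**2 * (alpha0 + 1))
                # = [5/252, 8/252, 9/252]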
diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py
index 19341c9..4201def 100644
--- a/torch/distributions/distribution.py
+++ b/torch/distributions/distribution.py
@@ -46,6 +46,27 @@
"""
raise NotImplementedError
+ @property
+ def mean(self):
+ """
+ Returns the mean of the distribution.
+ """
+ raise NotImplementedError
+
+ @property
+ def variance(self):
+ """
+ Returns the variance of the distribution.
+ """
+ raise NotImplementedError
+
+ @property
+ def stddev(self):
+ """
+ Returns the standard deviation of the distribution.
+ """
+ return self.variance.sqrt()
+
def sample(self, sample_shape=torch.Size()):
"""
Generates a sample_shape shaped sample or sample_shape shaped batch of
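stddev defaults to variance.sqrt(), so a subclass only overrides it when a direct form is cheaper (Exponential, Gumbel, Laplace, Normal, and Uniform below do). A sketch with Gamma, which relies on the default:

    import torch
    from torch.distributions import Gamma

    g = Gamma(torch.Tensor([2.0]), torch.Tensor([3.0]))
    g.mean      # concentration / rate = 2/3
    g.variance  # concentration / rate**2 = 2/9
    g.stddev    # base-class default variance.sqrt() = sqrt(2)/3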
diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py
index fd258e0..d22cfc6 100644
--- a/torch/distributions/exponential.py
+++ b/torch/distributions/exponential.py
@@ -24,6 +24,18 @@
support = constraints.positive
has_rsample = True
+ @property
+ def mean(self):
+ return self.rate.reciprocal()
+
+ @property
+ def stddev(self):
+ return self.rate.reciprocal()
+
+ @property
+ def variance(self):
+ return self.rate.pow(-2)
+
def __init__(self, rate):
self.rate, = broadcast_all(rate)
batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
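For the exponential distribution the mean and the standard deviation coincide at 1/rate, which is why stddev is overridden here rather than left to the variance.sqrt() default; a quick check:

    import torch
    from torch.distributions import Exponential

    d = Exponential(torch.Tensor([4.0]))
    d.mean      # 1/rate = 0.25
    d.stddev    # also 0.25
    d.variance  # rate**-2 = 0.0625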
diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py
index 2001576..e844151 100644
--- a/torch/distributions/fishersnedecor.py
+++ b/torch/distributions/fishersnedecor.py
@@ -37,6 +37,18 @@
batch_shape = self.df1.size()
super(FisherSnedecor, self).__init__(batch_shape)
+ @property
+ def mean(self):
+ df2 = self.df2.clone()
+ df2[df2 <= 2] = float('nan')
+ return df2 / (df2 - 2)
+
+ @property
+ def variance(self):
+ df2 = self.df2.clone()
+ df2[df2 <= 4] = float('nan')
+ return 2 * df2.pow(2) * (self.df1 + df2 - 2) / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
+
def rsample(self, sample_shape=torch.Size(())):
shape = self._extended_shape(sample_shape)
# X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
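Undefined moments surface as NaN rather than raising, matching the assertions in the test hunk above; for example:

    import torch
    from torch.distributions import FisherSnedecor

    FisherSnedecor(1, 2).mean      # NaN: the mean is undefined for df2 <= 2
    FisherSnedecor(1, 4).variance  # NaN: the variance is undefined for df2 <= 4
    FisherSnedecor(1, 6).mean      # df2 / (df2 - 2) = 1.5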
diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py
index 23e75e1..38d4939 100644
--- a/torch/distributions/gamma.py
+++ b/torch/distributions/gamma.py
@@ -35,6 +35,14 @@
support = constraints.positive
has_rsample = True
+ @property
+ def mean(self):
+ return self.concentration / self.rate
+
+ @property
+ def variance(self):
+ return self.concentration / self.rate.pow(2)
+
def __init__(self, concentration, rate):
self.concentration, self.rate = broadcast_all(concentration, rate)
if isinstance(concentration, Number) and isinstance(rate, Number):
diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py
index 4d5bba1..df09821 100644
--- a/torch/distributions/geometric.py
+++ b/torch/distributions/geometric.py
@@ -49,6 +49,10 @@
def mean(self):
return 1. / self.probs - 1.
+ @property
+ def variance(self):
+ return (1. / self.probs - 1.) / self.probs
+
@lazy_property
def logits(self):
return probs_to_logits(self.probs, is_binary=True)
diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py
index 3aafd93..c4f8fcc 100644
--- a/torch/distributions/gumbel.py
+++ b/torch/distributions/gumbel.py
@@ -1,10 +1,12 @@
from numbers import Number
-
+import math
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _finfo, broadcast_all
+euler_constant = 0.57721566490153286060 # Euler-Mascheroni constant
+
class Gumbel(Distribution):
r"""
@@ -45,5 +47,17 @@
z = (value - self.loc) / self.scale
return -(self.scale.log() + z + torch.exp(-z))
+ @property
+ def mean(self):
+ return self.loc + self.scale * euler_constant
+
+ @property
+ def stddev(self):
+ return (math.pi / math.sqrt(6)) * self.scale
+
+ @property
+ def variance(self):
+ return self.stddev.pow(2)
+
def entropy(self):
- return self.scale.log() + 1.57721566490153286060 # 1 + Euler Mascheroni Constant
+ return self.scale.log() + (1 + euler_constant)
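A numeric check of the Gumbel moments, with gamma ~ 0.5772 the Euler-Mascheroni constant defined above:

    import torch
    from torch.distributions import Gumbel

    d = Gumbel(torch.Tensor([0.0]), torch.Tensor([2.0]))
    d.mean      # loc + scale * gamma  ~ 1.1544
    d.stddev    # pi * scale / sqrt(6) ~ 2.5651
    d.variance  # pi**2 * scale**2 / 6 ~ 6.5797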
diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py
index 4a6fb41..ea96dbb 100644
--- a/torch/distributions/laplace.py
+++ b/torch/distributions/laplace.py
@@ -1,5 +1,4 @@
from numbers import Number
-
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
@@ -25,6 +24,18 @@
support = constraints.real
has_rsample = True
+ @property
+ def mean(self):
+ return self.loc
+
+ @property
+ def variance(self):
+ return 2 * self.scale.pow(2)
+
+ @property
+ def stddev(self):
+ return (2 ** 0.5) * self.scale
+
def __init__(self, loc, scale):
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, Number) and isinstance(scale, Number):
diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py
index bd2c697..1432c18 100644
--- a/torch/distributions/log_normal.py
+++ b/torch/distributions/log_normal.py
@@ -38,5 +38,13 @@
def scale(self):
return self.base_dist.scale
+ @property
+ def mean(self):
+ return (self.loc + self.scale.pow(2) / 2).exp()
+
+ @property
+ def variance(self):
+ return (self.scale.pow(2).exp() - 1) * (2 * self.loc + self.scale.pow(2)).exp()
+
def entropy(self):
return self.base_dist.entropy() + self.loc
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py
index 2929fc2..b58ced9 100644
--- a/torch/distributions/multinomial.py
+++ b/torch/distributions/multinomial.py
@@ -42,6 +42,14 @@
"""
params = {'logits': constraints.real} # Let logits be the canonical parameterization.
+ @property
+ def mean(self):
+ return self.probs * self.total_count
+
+ @property
+ def variance(self):
+ return self.total_count * self.probs * (1 - self.probs)
+
def __init__(self, total_count=1, probs=None, logits=None):
if not isinstance(total_count, Number):
raise NotImplementedError('inhomogeneous total_count is not supported')
diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py
index 44ef346..1ddb10e 100644
--- a/torch/distributions/normal.py
+++ b/torch/distributions/normal.py
@@ -28,6 +28,18 @@
support = constraints.real
has_rsample = True
+ @property
+ def mean(self):
+ return self.loc
+
+ @property
+ def stddev(self):
+ return self.scale
+
+ @property
+ def variance(self):
+ return self.stddev.pow(2)
+
def __init__(self, loc, scale):
self.loc, self.scale = broadcast_all(loc, scale)
if isinstance(loc, Number) and isinstance(scale, Number):
diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py
index 52798c4..015ffdf 100644
--- a/torch/distributions/one_hot_categorical.py
+++ b/torch/distributions/one_hot_categorical.py
@@ -40,6 +40,22 @@
return self._categorical._new(*args, **kwargs)
@property
+ def probs(self):
+ return self._categorical.probs
+
+ @property
+ def logits(self):
+ return self._categorical.logits
+
+ @property
+ def mean(self):
+ return self._categorical.probs
+
+ @property
+ def variance(self):
+ return self._categorical.probs * (1 - self._categorical.probs)
+
+ @property
def param_shape(self):
return self._categorical.param_shape
diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py
index d03bc09..15854b9 100644
--- a/torch/distributions/pareto.py
+++ b/torch/distributions/pareto.py
@@ -34,6 +34,18 @@
batch_shape = self.scale.size()
super(Pareto, self).__init__(batch_shape)
+ @property
+ def mean(self):
+ # mean is inf for alpha <= 1
+ a = self.alpha.clone().clamp(min=1)
+ return a * self.scale / (a - 1)
+
+ @property
+ def variance(self):
+ # variance is inf for alpha <= 2
+ a = self.alpha.clone().clamp(min=2)
+ return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2))
+
@constraints.dependent_property
def support(self):
return constraints.greater_than(self.scale)
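The clamp above makes a heavy-tailed Pareto report an infinite moment instead of a negative or NaN one; for example:

    import torch
    from torch.distributions import Pareto

    Pareto(torch.Tensor([1.0]), torch.Tensor([0.5])).mean      # inf (alpha <= 1)
    Pareto(torch.Tensor([1.0]), torch.Tensor([0.5])).variance  # inf (alpha <= 2)
    Pareto(torch.Tensor([1.0]), torch.Tensor([3.0])).mean      # alpha * scale / (alpha - 1) = 1.5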
diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py
index 21238b1..98e6644 100644
--- a/torch/distributions/poisson.py
+++ b/torch/distributions/poisson.py
@@ -33,6 +33,14 @@
params = {'rate': constraints.positive}
support = constraints.nonnegative_integer
+ @property
+ def mean(self):
+ return self.rate
+
+ @property
+ def variance(self):
+ return self.rate
+
def __init__(self, rate):
self.rate, = broadcast_all(rate)
if isinstance(rate, Number):
diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py
index 1ce538d..f3e19bf 100644
--- a/torch/distributions/studentT.py
+++ b/torch/distributions/studentT.py
@@ -25,6 +25,20 @@
support = constraints.real
has_rsample = True
+ @property
+ def mean(self):
+ m = self.loc.clone()
+ m[self.df <= 1] = float('nan')
+ return m
+
+ @property
+ def variance(self):
+ m = self.df.clone()
+ m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)
+ m[(self.df <= 2) & (self.df > 1)] = float('inf')
+ m[self.df <= 1] = float('nan')
+ return m
+
def __init__(self, df, loc=0., scale=1.):
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
self._chi2 = Chi2(df)
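The Student t variance is genuinely three-valued in df, which the masked assignments above encode; for example:

    import torch
    from torch.distributions import StudentT

    df = torch.Tensor([0.5, 1.5, 3.0])
    StudentT(df).mean      # [nan, 0, 0]:   undefined for df <= 1
    StudentT(df).variance  # [nan, inf, 3]: df / (df - 2) = 3 at df = 3, scale = 1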
diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py
index ae278f2..1d233a3 100644
--- a/torch/distributions/uniform.py
+++ b/torch/distributions/uniform.py
@@ -28,6 +28,18 @@
params = {'low': constraints.dependent, 'high': constraints.dependent}
has_rsample = True
+ @property
+ def mean(self):
+ return (self.high + self.low) / 2
+
+ @property
+ def stddev(self):
+ return (self.high - self.low) / 12**0.5
+
+ @property
+ def variance(self):
+ return (self.high - self.low).pow(2) / 12
+
def __init__(self, low, high):
self.low, self.high = broadcast_all(low, high)
if isinstance(low, Number) and isinstance(high, Number):
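A final numeric check of the uniform moments:

    import torch
    from torch.distributions import Uniform

    d = Uniform(torch.Tensor([0.0]), torch.Tensor([6.0]))
    d.mean      # (low + high) / 2 = 3
    d.variance  # (high - low)**2 / 12 = 3
    d.stddev    # (high - low) / sqrt(12) ~ 1.7321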