adds sample_n function (#3249)
* adds sample_n function
* fixes style issues
* uses more efficient api calls
* fixes a bug where transpose was applied to a 1-dimensional result
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 0d6881c..0f69461 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -27,6 +27,10 @@
def test_bernoulli(self):
p = Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)
+ r = Variable(torch.Tensor([0.3]), requires_grad=True)
+ self.assertEqual(Bernoulli(p).sample_n(8).size(), (8, 3))
+ self.assertEqual(Bernoulli(r).sample_n(8).size(), (8, 1))
+ self.assertEqual(Bernoulli(r).sample().size(), (1,))
self._gradcheck_log_prob(Bernoulli, (p,))
def ref_log_prob(idx, val, log_prob):
@@ -38,17 +42,20 @@
def test_bernoulli_3d(self):
p = Variable(torch.Tensor(2, 3, 5).fill_(0.5), requires_grad=True)
self.assertEqual(Bernoulli(p).sample().size(), (2, 3, 5))
+ self.assertEqual(Bernoulli(p).sample_n(2).size(), (2, 2, 3, 5))
def test_multinomial_1d(self):
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
# TODO: this should return a 0-dim tensor once we have Scalar support
self.assertEqual(Multinomial(p).sample().size(), (1,))
+ self.assertEqual(Multinomial(p).sample_n(1).size(), (1, 1))
self._gradcheck_log_prob(Multinomial, (p,))
def test_multinomial_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
p = Variable(torch.Tensor(probabilities), requires_grad=True)
self.assertEqual(Multinomial(p).sample().size(), (2,))
+ self.assertEqual(Multinomial(p).sample_n(6).size(), (6, 2))
self._gradcheck_log_prob(Multinomial, (p,))
def ref_log_prob(idx, val, log_prob):
@@ -60,7 +67,15 @@
def test_normal(self):
mean = Variable(torch.randn(5, 5), requires_grad=True)
std = Variable(torch.randn(5, 5).abs(), requires_grad=True)
+ mean_1d = Variable(torch.randn(1), requires_grad=True)
+ std_1d = Variable(torch.randn(1), requires_grad=True)
self.assertEqual(Normal(mean, std).sample().size(), (5, 5))
+ self.assertEqual(Normal(mean, std).sample_n(7).size(), (7, 5, 5))
+ self.assertEqual(Normal(mean_1d, std_1d).sample_n(1).size(), (1, 1))
+ self.assertEqual(Normal(mean_1d, std_1d).sample().size(), (1,))
+ self.assertEqual(Normal(0.2, .6).sample_n(1).size(), (1, 1))
+ self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1, 1))
+
self._gradcheck_log_prob(Normal, (mean, std))
self._gradcheck_log_prob(Normal, (mean, 1.0))
self._gradcheck_log_prob(Normal, (0.0, std))
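
Taken together, the new assertions pin down the shape contract for `sample_n`: the result always gains a leading dimension of size n, whether the parameters are a scalar, a 1-d batch, or a multi-dimensional batch. A quick illustrative sketch of that contract (hypothetical usage, not part of the diff; the shapes follow the tests above):

    import torch
    from torch.autograd import Variable
    from torch.distributions import Bernoulli, Normal

    p = Variable(torch.Tensor([0.7, 0.2, 0.4]))  # batch of 3 Bernoulli params
    Bernoulli(p).sample().size()                 # torch.Size([3])    -- one batch
    Bernoulli(p).sample_n(8).size()              # torch.Size([8, 3]) -- leading n dim

    mean = Variable(torch.randn(5, 5))
    std = Variable(torch.randn(5, 5).abs())
    Normal(mean, std).sample_n(7).size()         # torch.Size([7, 5, 5])
    Normal(0.2, 0.6).sample_n(1).size()          # torch.Size([1, 1]) -- scalar params
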
diff --git a/torch/distributions.py b/torch/distributions.py
index 0b59188..e037424 100644
--- a/torch/distributions.py
+++ b/torch/distributions.py
@@ -34,6 +34,13 @@
"""
raise NotImplementedError
+ def sample_n(self, n):
+ """
+ Generates n samples, or n batches of samples if the distribution parameters
+ are batched.
+ """
+ raise NotImplementedError
+
def log_prob(self, value):
"""
Returns the log of the probability density/mass function evaluated at
@@ -62,12 +69,16 @@
Args:
probs (Tensor or Variable): the probability of sampling `1`
"""
+
def __init__(self, probs):
self.probs = probs
def sample(self):
return torch.bernoulli(self.probs)
+ def sample_n(self, n):
+ return torch.bernoulli(self.probs.expand(n, *self.probs.size()))
+
def log_prob(self, value):
# compute the log probabilities for 0 and 1
log_pmf = (torch.stack([1 - self.probs, self.probs])).log()
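
`Bernoulli.sample_n` leans on `expand`, which adds the leading n dimension as a view without copying the underlying storage; `torch.bernoulli` then draws each element independently, so the n rows are distinct samples. A small sketch of the mechanics (illustrative only):

    probs = torch.Tensor([0.7, 0.2, 0.4])      # size (3,)
    expanded = probs.expand(5, *probs.size())  # size (5, 3), no data copy
    samples = torch.bernoulli(expanded)        # size (5, 3), independent draws per element
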
@@ -99,6 +110,7 @@
Args:
probs (Tensor or Variable): event probabilities
"""
+
def __init__(self, probs):
if probs.dim() != 1 and probs.dim() != 2:
# TODO: treat higher dimensions as part of the batch
@@ -108,6 +120,12 @@
def sample(self):
return torch.multinomial(self.probs, 1, True).squeeze(-1)
+ def sample_n(self, n):
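+ # n == 1: reuse sample() so .t() is never applied to a 1-d result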
+ if n == 1:
+ return self.sample().expand(1, 1)
+ else:
+ return torch.multinomial(self.probs, n, True).t()
+
def log_prob(self, value):
p = self.probs / self.probs.sum(-1, keepdim=True)
if value.dim() == 1 and self.probs.dim() == 1:
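
For 2-d probs of size (batch, k), `torch.multinomial(probs, n, True)` samples with replacement and returns size (batch, n); the trailing `.t()` flips that to (n, batch) so the sample count leads. The n == 1 branch sidesteps the transpose entirely, which is the 1-dimension bug called out in the commit message. Illustrative shapes:

    probs = torch.Tensor([[0.1, 0.2, 0.3],
                          [0.5, 0.3, 0.2]])  # size (2, 3): batch of 2
    raw = torch.multinomial(probs, 6, True)  # size (2, 6): n draws per batch row
    raw.t().size()                           # torch.Size([6, 2]): n leads
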
@@ -133,6 +151,7 @@
mean (float or Tensor or Variable): mean of the distribution
std (float or Tensor or Variable): standard deviation of the distribution
"""
+
def __init__(self, mean, std):
self.mean = mean
self.std = std
@@ -140,6 +159,15 @@
def sample(self):
return torch.normal(self.mean, self.std)
+ def sample_n(self, n):
+ # cleanly expand float or Tensor or Variable parameters
+ def expand(v):
+ if isinstance(v, Number):
+ return torch.Tensor([v]).expand(n, 1)
+ else:
+ return v.expand(n, *v.size())
+ return torch.normal(expand(self.mean), expand(self.std))
+
def log_prob(self, value):
# compute the variance
var = (self.std ** 2)
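
`Normal.sample_n` normalizes all three accepted parameter types to an n-leading tensor before a single `torch.normal` call; `Number` here is presumably `numbers.Number`, imported elsewhere in the module (the import sits outside this hunk). A standalone sketch of the same expansion logic, under that assumption (`expand_param` is a hypothetical name for illustration):

    from numbers import Number  # assumed import, not shown in the hunk
    import torch

    def expand_param(v, n):
        # a scalar becomes an (n, 1) tensor; a tensor gains a leading n dimension
        if isinstance(v, Number):
            return torch.Tensor([v]).expand(n, 1)
        return v.expand(n, *v.size())

    expand_param(0.2, 4).size()                # torch.Size([4, 1])
    expand_param(torch.randn(5, 5), 7).size()  # torch.Size([7, 5, 5])
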