remove data-dependent shapes from some distributions (#84322)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84322
Approved by: https://github.com/voznesenskym
diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py
index e751c5a..e8f4bca 100644
--- a/torch/distributions/half_cauchy.py
+++ b/torch/distributions/half_cauchy.py
@@ -61,7 +61,7 @@
value = torch.as_tensor(value, dtype=self.base_dist.scale.dtype,
device=self.base_dist.scale.device)
log_prob = self.base_dist.log_prob(value) + math.log(2)
- log_prob[value.expand(log_prob.shape) < 0] = -inf
+ log_prob = torch.where(value >= 0, log_prob, -inf)
return log_prob
def cdf(self, value):
diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py
index 1c3f9e8..d5b1337 100644
--- a/torch/distributions/half_normal.py
+++ b/torch/distributions/half_normal.py
@@ -59,7 +59,7 @@
if self._validate_args:
self._validate_sample(value)
log_prob = self.base_dist.log_prob(value) + math.log(2)
- log_prob[value.expand(log_prob.shape) < 0] = -inf
+ log_prob = torch.where(value >= 0, log_prob, -inf)
return log_prob
def cdf(self, value):