// blob: b23a18a8376a60f771734c75b6891fcdd03b6f84 [file] [log] [blame]
#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/Exception.h>
#include <c10/util/math_compat.h>
#include <c10/util/Optional.h>
#include <ATen/Utils.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/Distributions.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/cpu/Loops.h>
#include <type_traits>
#include <functional>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <assert.h>
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <float.h>
namespace {
/*
* This section is a counterpart to Distributions.cu
*
*/
// The function `sample_poisson`
// is adapted from Numpy's distributions.c implementation.
// It is MIT licensed, so here is the copyright:
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
/**
 * Draws a single Poisson-distributed integer with rate `lambda`.
 *
 * Fails via TORCH_CHECK unless `lambda >= 0`. Three regimes are used:
 *   - lambda == 0      : the result is always 0;
 *   - 0 < lambda < 10  : uniform draws are multiplied together until the
 *                        running product is no longer above exp(-lambda);
 *   - lambda >= 10     : transformed rejection method (Hoermann, 1993).
 *
 * All uniform variates are pulled from `generator`; no locking is done here.
 */
int64_t sample_poisson(double lambda, at::CPUGeneratorImpl* generator) {
  TORCH_CHECK(lambda >= 0, "invalid Poisson rate, expected rate to be non-negative");
  at::uniform_real_distribution<double> standard_uniform(0.0, 1.0);

  if (lambda == 0) {
    return 0;
  }

  if (lambda < 10) {
    // Count how many uniforms can be multiplied before the product falls to
    // exp(-lambda) or below.
    const double threshold = std::exp(-lambda);
    double running_product = 1.0;
    for (int64_t count = 0;; ++count) {
      running_product *= standard_uniform(generator);
      if (!(running_product > threshold)) {
        return count;
      }
    }
  }

  // transformed rejection method, (Hoermann, 1993)
  const double sqrt_lambda = std::sqrt(lambda);
  const double log_lambda = std::log(lambda);
  const double b = 0.931 + 2.53 * sqrt_lambda;
  const double a = -0.059 + 0.02483 * b;
  const double inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
  const double v_r = 0.9277 - 3.6224 / (b - 2);

  for (;;) {
    // Two uniforms are consumed per attempt, in this order.
    const double u = standard_uniform(generator) - 0.5;
    const double v = standard_uniform(generator);
    const double us = 0.5 - std::fabs(u);
    const auto k = static_cast<int64_t>(std::floor((2 * a / us + b) * u + lambda + 0.43));

    // Cheap acceptance test.
    if (us >= 0.07 && v <= v_r) {
      return k;
    }
    // Cheap rejection test.
    if (k < 0 || (us < 0.013 && v > us)) {
      continue;
    }
    // Full acceptance test in log space.
    const double lhs = std::log(v) + std::log(inv_alpha) - std::log(a / (us * us) + b);
    const double rhs = -lambda + k * log_lambda - std::lgamma(static_cast<double>(k) + 1);
    if (lhs <= rhs) {
      return k;
    }
  }
}
} // namespace
namespace at {
namespace native {
// Definitions of the dispatch stubs used by the distribution entry points in
// this file. Each stub is invoked further down with a device type as its
// first argument, followed by the operation's own arguments.
// NOTE(review): the matching declarations and the per-device kernel
// registrations presumably live outside this file — confirm.
DEFINE_DISPATCH(bernoulli_tensor_stub);
DEFINE_DISPATCH(bernoulli_scalar_stub);
DEFINE_DISPATCH(cauchy_stub);
DEFINE_DISPATCH(exponential_stub);
DEFINE_DISPATCH(multinomial_with_replacement_stub);
DEFINE_DISPATCH(geometric_stub);
DEFINE_DISPATCH(log_normal_stub);
DEFINE_DISPATCH(uniform_stub);
DEFINE_DISPATCH(normal_stub);
DEFINE_DISPATCH(random_stub);
DEFINE_DISPATCH(random_from_to_stub);
DEFINE_DISPATCH(random_full_64_bits_range_stub);
// ==================================================== Bernoulli =====================================================
/**
 * Functor that hands Bernoulli sampling to the dispatch stubs, choosing the
 * implementation by the device type of `self`. The `RNG` template parameter
 * is not used in the body.
 */
template<typename RNG>
struct BernoulliStub {
  // Tensor-valued probabilities: routed to bernoulli_tensor_stub.
  void operator()(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
    const auto device_type = self.device().type();
    bernoulli_tensor_stub(device_type, self, p_, gen);
  }

  // Scalar probability: routed to bernoulli_scalar_stub.
  void operator()(Tensor& self, double p, c10::optional<Generator> gen) {
    const auto device_type = self.device().type();
    bernoulli_scalar_stub(device_type, self, p, gen);
  }
};
// Functional Bernoulli: allocates a tensor shaped like `self` (legacy
// contiguous memory format) and fills it in place via `bernoulli_(self, gen)`,
// i.e. `self` is passed as the per-element probability tensor.
Tensor bernoulli(const Tensor& self, c10::optional<Generator> gen) {
  Tensor sampled = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  sampled.bernoulli_(self, gen);
  return sampled;
}
// Functional Bernoulli with a scalar probability: allocates a tensor shaped
// like `self` (legacy contiguous memory format) and fills it in place via
// `bernoulli_(p, gen)`.
Tensor bernoulli(const Tensor& self, double p, c10::optional<Generator> gen) {
  Tensor sampled = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  sampled.bernoulli_(p, gen);
  return sampled;
}
// Out-variant of Bernoulli sampling. Forwards (result, self, gen) to
// templates::bernoulli_out_impl instantiated with BernoulliStub/Generator and
// returns whatever reference that helper returns.
Tensor& bernoulli_out(const Tensor& self, c10::optional<Generator> gen, Tensor& result) {
  Tensor& written =
      at::native::templates::bernoulli_out_impl<BernoulliStub, Generator>(result, self, gen);
  return written;
}
// In-place Bernoulli with tensor-valued probabilities `p_`. Forwards to
// templates::bernoulli_impl_ instantiated with BernoulliStub/Generator.
Tensor& bernoulli_(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::bernoulli_impl_<BernoulliStub, Generator>(self, p_, gen);
  return filled;
}
// In-place Bernoulli with a scalar probability `p`. Forwards to
// templates::bernoulli_impl_ instantiated with BernoulliStub/Generator.
Tensor& bernoulli_(Tensor& self, double p, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::bernoulli_impl_<BernoulliStub, Generator>(self, p, gen);
  return filled;
}
// ================================================== LogNormal =======================================================
/**
 * Functor that routes log-normal sampling to `log_normal_stub`, selected by
 * the iterator's device type. `RNG` is not used in the body.
 */
template<typename RNG>
struct LogNormalStub {
  void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    log_normal_stub(device_type, iter, mean, std, gen);
  }
};
// In-place log-normal fill. Forwards to templates::log_normal_impl_
// instantiated with LogNormalStub/Generator.
Tensor& log_normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::log_normal_impl_<LogNormalStub, Generator>(self, mean, std, gen);
  return filled;
}
// ==================================================== Cauchy ========================================================
/**
 * Functor that routes Cauchy sampling to `cauchy_stub`, selected by the
 * iterator's device type. `RNG` is not used in the body.
 */
template<typename RNG>
struct CauchyStub {
  void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    cauchy_stub(device_type, iter, median, sigma, gen);
  }
};
// In-place Cauchy fill. Forwards to templates::cauchy_impl_ instantiated
// with CauchyStub/Generator.
Tensor& cauchy_(Tensor& self, double median, double sigma, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::cauchy_impl_<CauchyStub, Generator>(self, median, sigma, gen);
  return filled;
}
// ================================================== Exponential =====================================================
/**
 * Functor that routes exponential sampling to `exponential_stub`, selected by
 * the iterator's device type. `RNG` is not used in the body.
 */
template<typename RNG>
struct ExponentialStub {
  void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    exponential_stub(device_type, iter, lambda, gen);
  }
};
// In-place exponential fill. Forwards to templates::exponential_impl_
// instantiated with ExponentialStub/Generator.
Tensor& exponential_(Tensor& self, double lambda, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::exponential_impl_<ExponentialStub, Generator>(self, lambda, gen);
  return filled;
}
// =================================================== Geometric ======================================================
/**
 * Functor that routes geometric sampling to `geometric_stub`, selected by the
 * iterator's device type. `RNG` is not used in the body.
 */
template<typename RNG>
struct GeometricStub {
  void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    geometric_stub(device_type, iter, p, gen);
  }
};
// In-place geometric fill. Forwards to templates::geometric_impl_
// instantiated with GeometricStub/Generator.
Tensor& geometric_(Tensor& self, double p, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::geometric_impl_<GeometricStub, Generator>(self, p, gen);
  return filled;
}
// ==================================================== Uniform =======================================================
/**
 * Functor that routes uniform sampling to `uniform_stub`, selected by the
 * iterator's device type. `RNG` is not used in the body.
 */
template<typename RNG>
struct UniformStub {
  void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    uniform_stub(device_type, iter, from, to, gen);
  }
};
/**
 * Kernel functor with the same call signature as UniformStub that
 * deliberately does nothing with its arguments.
 */
template<typename RNG>
struct UniformMeta {
  // No-op!
  void operator()(TensorIteratorBase& /*iter*/, double /*from*/, double /*to*/, c10::optional<Generator> /*gen*/) {}
};
// In-place uniform fill. Forwards to templates::uniform_impl_ instantiated
// with UniformStub/Generator.
Tensor& uniform_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::uniform_impl_<UniformStub, Generator>(self, from, to, gen);
  return filled;
}
// Same entry point as uniform_, but instantiated with the no-op UniformMeta
// functor, so only the shared template's own logic runs.
Tensor& uniform_meta_(Tensor& self, double from, double to, c10::optional<Generator> gen) {
  Tensor& untouched =
      at::native::templates::uniform_impl_<UniformMeta, Generator>(self, from, to, gen);
  return untouched;
}
// ==================================================== Normal ========================================================
/**
 * Functor that routes normal sampling to `normal_stub`, selected by the
 * device type of `self`. `RNG` is not used in the body.
 */
template<typename RNG>
struct NormalStub {
  void operator()(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
    const auto device_type = self.device().type();
    normal_stub(device_type, self, mean, std, gen);
  }
};
/**
 * Kernel functor with the same call signature as NormalStub that
 * deliberately does nothing with its arguments.
 */
template<typename RNG>
struct NormalMeta {
  // No-op!
  void operator()(Tensor& /*self*/, double /*mean*/, double /*std*/, c10::optional<Generator> /*gen*/) {}
};
// inplace
// In-place normal fill with scalar mean and std. Forwards to
// templates::normal_impl_ instantiated with NormalStub/Generator.
Tensor& normal_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::normal_impl_<NormalStub, Generator>(self, mean, std, gen);
  return filled;
}
// Same entry point as normal_, but instantiated with the no-op NormalMeta
// functor, so only the shared template's own logic runs.
Tensor& normal_meta_(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
  Tensor& untouched =
      at::native::templates::normal_impl_<NormalMeta, Generator>(self, mean, std, gen);
  return untouched;
}
// out tensor float
// Out-variant: tensor mean, scalar std. Forwards (output, mean, std, gen) to
// templates::normal_out_impl instantiated with NormalStub/Generator.
Tensor& normal_out(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
  Tensor& written =
      at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, gen);
  return written;
}
// Out-variant with the no-op kernel: tensor mean, scalar std. Forwards to
// templates::normal_out_impl instantiated with NormalMeta/Generator.
Tensor& normal_out_meta(const Tensor& mean, double std, c10::optional<Generator> gen, Tensor& output) {
  Tensor& written =
      at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, gen);
  return written;
}
// out float tensor
// Out-variant: scalar mean, tensor std. Forwards (output, mean, std, gen) to
// templates::normal_out_impl instantiated with NormalStub/Generator.
Tensor& normal_out(double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
  Tensor& written =
      at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, gen);
  return written;
}
// Out-variant with the no-op kernel: scalar mean, tensor std. Forwards to
// templates::normal_out_impl instantiated with NormalMeta/Generator.
Tensor& normal_out_meta(double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
  Tensor& written =
      at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, gen);
  return written;
}
// out tensor tensor
// Out-variant: tensor mean, tensor std. Forwards (output, mean, std, gen) to
// templates::normal_out_impl instantiated with NormalStub/Generator.
Tensor& normal_out(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
  Tensor& written =
      at::native::templates::normal_out_impl<NormalStub, Generator>(output, mean, std, gen);
  return written;
}
// Out-variant with the no-op kernel: tensor mean, tensor std. Forwards to
// templates::normal_out_impl instantiated with NormalMeta/Generator.
Tensor& normal_out_meta(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
  Tensor& written =
      at::native::templates::normal_out_impl<NormalMeta, Generator>(output, mean, std, gen);
  return written;
}
// functional tensor float
// Functional variant: tensor mean, scalar std. Forwards to
// templates::normal_impl instantiated with NormalStub/Generator.
Tensor normal(const Tensor& mean, double std, c10::optional<Generator> gen) {
  Tensor sampled =
      at::native::templates::normal_impl<NormalStub, Generator>(mean, std, gen);
  return sampled;
}
// Functional variant with the no-op kernel: tensor mean, scalar std. Forwards
// to templates::normal_impl instantiated with NormalMeta/Generator.
Tensor normal_meta(const Tensor& mean, double std, c10::optional<Generator> gen) {
  Tensor produced =
      at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, gen);
  return produced;
}
// functional float tensor
// Functional variant: scalar mean, tensor std. Forwards to
// templates::normal_impl instantiated with NormalStub/Generator.
Tensor normal(double mean, const Tensor& std, c10::optional<Generator> gen) {
  Tensor sampled =
      at::native::templates::normal_impl<NormalStub, Generator>(mean, std, gen);
  return sampled;
}
// Functional variant with the no-op kernel: scalar mean, tensor std. Forwards
// to templates::normal_impl instantiated with NormalMeta/Generator.
Tensor normal_meta(double mean, const Tensor& std, c10::optional<Generator> gen) {
  Tensor produced =
      at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, gen);
  return produced;
}
// functional tensor tensor
// Functional variant: tensor mean, tensor std. Forwards to
// templates::normal_impl instantiated with NormalStub/Generator.
Tensor normal(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
  Tensor sampled =
      at::native::templates::normal_impl<NormalStub, Generator>(mean, std, gen);
  return sampled;
}
// Functional variant with the no-op kernel: tensor mean, tensor std. Forwards
// to templates::normal_impl instantiated with NormalMeta/Generator.
Tensor normal_meta(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen) {
  Tensor produced =
      at::native::templates::normal_impl<NormalMeta, Generator>(mean, std, gen);
  return produced;
}
// ==================================================== Random ========================================================
/**
 * Functor that routes `random_` sampling to `random_stub`, selected by the
 * iterator's device type. `RNG` is not used in the body.
 */
template<typename RNG>
struct RandomStub {
  void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    random_stub(device_type, iter, gen);
  }
};
// In-place random fill with no explicit bounds. Forwards to
// templates::random_impl instantiated with RandomStub/Generator.
Tensor& random_(Tensor& self, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::random_impl<RandomStub, Generator>(self, gen);
  return filled;
}
/**
 * Functor for bounded `random_` sampling, selected by the iterator's device
 * type. `RNG` is not used in the body.
 */
template<typename RNG>
struct RandomFromToStub {
  // Explicit (range, from) pair: routed to random_from_to_stub.
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t from, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    random_from_to_stub(device_type, iter, range, from, gen);
  }

  // No range arguments: routed to random_full_64_bits_range_stub.
  void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
    const auto device_type = iter.device_type();
    random_full_64_bits_range_stub(device_type, iter, gen);
  }
};
/**
 * Kernel functor with the same two call signatures as RandomFromToStub that
 * deliberately does nothing with its arguments.
 */
template<typename RNG>
struct RandomFromToMeta {
  // No-op!
  void operator()(TensorIteratorBase& /*iter*/, uint64_t /*range*/, int64_t /*from*/, c10::optional<Generator> /*gen*/) {}
  void operator()(TensorIteratorBase& /*iter*/, c10::optional<Generator> /*gen*/) {}
};
// In-place random fill with a lower bound `from` and an optional upper bound
// `to`. Forwards to templates::random_from_to_impl instantiated with
// RandomFromToStub/Generator.
Tensor& random_(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> gen) {
  Tensor& filled =
      at::native::templates::random_from_to_impl<RandomFromToStub, Generator>(self, from, to, gen);
  return filled;
}
// In-place random fill with only an upper bound: delegates to the from/to
// overload with the lower bound fixed at 0.
Tensor& random_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
  const int64_t lower_bound = 0;
  const optional<int64_t> upper_bound = to;
  return random_(self, lower_bound, upper_bound, gen);
}
// Unbounded random_ variant that performs no work: `self` is handed back
// unchanged and `gen` is ignored.
Tensor& random_meta_(Tensor& self, c10::optional<Generator> gen) {
  // No error checking yay
  (void)gen;
  return self;
}
// Bounded random_ variant with the no-op kernel. Forwards to
// templates::random_from_to_impl instantiated with RandomFromToMeta/Generator.
Tensor& random_meta_(Tensor& self, int64_t from, optional<int64_t> to, c10::optional<Generator> gen) {
  Tensor& untouched =
      at::native::templates::random_from_to_impl<RandomFromToMeta, Generator>(self, from, to, gen);
  return untouched;
}
// Upper-bound-only variant with the no-op kernel: delegates to the from/to
// overload of random_meta_ with the lower bound fixed at 0.
Tensor& random_meta_(Tensor& self, int64_t to, c10::optional<Generator> gen) {
  const int64_t lower_bound = 0;
  const optional<int64_t> upper_bound = to;
  return random_meta_(self, lower_bound, upper_bound, gen);
}
// ====================================================================================================================
// Element-wise CPU evaluation of standard_gamma_grad_one over (self, output).
// The result is allocated with the sizes and options of `self`, and the
// kernel runs serially for the floating dtypes covered by the dispatch macro.
Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
  Tensor grad = at::empty(self.sizes(), self.options());

  TensorIteratorConfig config;
  config.add_output(grad);
  config.add_input(self);
  config.add_input(output);
  auto iter = config.build();

  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "_standard_gamma_grad_cpu", [&] {
    cpu_serial_kernel(iter, [](scalar_t alpha_elem, scalar_t sample_elem) -> scalar_t {
      return standard_gamma_grad_one<scalar_t, double>(alpha_elem, sample_elem);
    });
  });
  return grad;
}
// Element-wise CPU evaluation of dirichlet_grad_one over (x, alpha, total).
// The result is allocated with the sizes and options of `x`, and the kernel
// runs serially for the floating dtypes covered by the dispatch macro.
Tensor _dirichlet_grad_cpu(const Tensor& x, const Tensor& alpha, const Tensor& total) {
  Tensor grad = at::empty(x.sizes(), x.options());

  TensorIteratorConfig config;
  config.add_output(grad);
  config.add_input(x);
  config.add_input(alpha);
  config.add_input(total);
  auto iter = config.build();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "_dirichlet_grad_cpu", [&] {
    cpu_serial_kernel(iter, [](scalar_t x_elem, scalar_t alpha_elem, scalar_t total_elem) -> scalar_t {
      return dirichlet_grad_one<scalar_t, double>(x_elem, alpha_elem, total_elem);
    });
  });
  return grad;
}
/*
* This section is a counterpart to Distributions.cu
*/
// CPU binomial sampler: for every (count, prob) element pair, draws one
// sample with sample_binomial and stores it in a tensor that has the sizes
// and options of `count`. Uses `gen` if provided, otherwise the default CPU
// generator; the generator's mutex is held for the whole kernel.
Tensor _s_binomial_cpu(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
  Tensor samples = at::zeros(count.sizes(), count.options());
  auto iter = TensorIteratorConfig()
      .add_output(samples)
      .add_input(count)
      .add_input(prob)
      .build();
  AT_DISPATCH_FLOATING_TYPES(samples.scalar_type(), "binomial_cpu", [&] {
    CPUGeneratorImpl* rng = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> guard(rng->mutex_);
    cpu_serial_kernel(iter, [rng](scalar_t count_elem, scalar_t prob_elem) -> scalar_t {
      // Source of U(0, 1) doubles backed by the locked generator.
      auto draw_uniform = [rng]() {
        at::uniform_real_distribution<double> unit_interval(0.0, 1.0);
        return unit_interval(rng);
      };
      BaseSampler<double, decltype(draw_uniform)> uniform_sampler(draw_uniform);
      auto drawn = sample_binomial<scalar_t, double, decltype(draw_uniform)>(count_elem, prob_elem, uniform_sampler);
      return static_cast<scalar_t>(drawn);
    });
  });
  return samples;
}
// CPU Poisson sampler: for every element of `lambda`, draws one sample with
// sample_poisson (rate widened to double) and stores it, cast back to the
// element type, in a tensor with the sizes and options of `lambda`. Uses
// `gen` if provided, otherwise the default CPU generator; the generator's
// mutex is held for the whole kernel.
Tensor _s_poisson_cpu(const Tensor& lambda, c10::optional<Generator> gen) {
  Tensor samples = at::zeros(lambda.sizes(), lambda.options());
  auto iter = TensorIteratorConfig()
      .add_output(samples)
      .add_input(lambda)
      .build();
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, samples.scalar_type(), "poisson_cpu", [&] {
    CPUGeneratorImpl* rng = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> guard(rng->mutex_);
    cpu_serial_kernel(iter, [rng](scalar_t rate) -> scalar_t {
      const int64_t drawn = sample_poisson(static_cast<double>(rate), rng);
      return static_cast<scalar_t>(drawn);
    });
  });
  return samples;
}
// CPU gamma sampler: for every element of `alpha`, draws one sample with
// sample_gamma and stores it in a tensor with the sizes and options of
// `alpha`. Each stored value is floored at numeric_limits<scalar_t>::min().
// Uses `gen` if provided, otherwise the default CPU generator; the
// generator's mutex is held for the whole kernel.
Tensor _s_gamma_cpu(const Tensor& alpha, c10::optional<Generator> gen) {
  Tensor samples = at::zeros(alpha.sizes(), alpha.options());
  auto iter = TensorIteratorConfig()
      .add_output(samples)
      .add_input(alpha)
      .build();
  AT_DISPATCH_FLOATING_TYPES(samples.scalar_type(), "gamma_cpu", [&] {
    CPUGeneratorImpl* rng = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> guard(rng->mutex_);
    cpu_serial_kernel(iter, [rng](scalar_t alpha_elem) -> scalar_t {
      // Source of U(0, 1) doubles backed by the locked generator.
      auto draw_uniform = [rng]() {
        at::uniform_real_distribution<double> unit_interval(0.0, 1.0);
        return unit_interval(rng);
      };
      BaseSampler<double, decltype(draw_uniform)> uniform_sampler(draw_uniform);
      // Source of N(0, 1) doubles backed by the same generator.
      auto draw_normal = [rng]() {
        at::normal_distribution<double> unit_normal(0.0, 1.0);
        return unit_normal(rng);
      };
      BaseSampler<double, decltype(draw_normal)> normal_sampler(draw_normal);
      auto drawn = sample_gamma<scalar_t, double, decltype(draw_uniform), decltype(draw_normal)>(alpha_elem, uniform_sampler, normal_sampler);
      const scalar_t floor_value = std::numeric_limits<scalar_t>::min();
      return std::max(floor_value, static_cast<scalar_t>(drawn));
    });
  });
  return samples;
}
// CPU Dirichlet sampler, done in two serial passes while holding the
// generator's mutex:
//   1. draw a gamma sample per element of `alpha` into a double tensor;
//   2. divide each gamma sample by the sum over the last dimension and cast
//      back to the element type, clamped between numeric_limits<scalar_t>::min()
//      and the value just below 1.
// Uses `gen` if provided, otherwise the default CPU generator.
Tensor _s_dirichlet_cpu(const Tensor& alpha, c10::optional<Generator> gen) {
  Tensor normalized = at::zeros(alpha.sizes(), alpha.options());
  AT_DISPATCH_FLOATING_TYPES(normalized.scalar_type(), "dirichlet", [&] {
    Tensor gamma_draws = at::zeros(alpha.sizes(), alpha.options().dtype(ScalarType::Double));
    CPUGeneratorImpl* rng = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> guard(rng->mutex_);

    /* Generate gamma sample by casting alpha to double to prevent underflow. */
    auto sampling_iter = TensorIteratorConfig()
        .add_output(gamma_draws)
        .add_input(alpha)
        .check_all_same_dtype(false)
        .build();
    cpu_serial_kernel(sampling_iter, [rng](scalar_t alpha_elem) -> double {
      // Source of U(0, 1) doubles backed by the locked generator.
      auto draw_uniform = [rng]() {
        at::uniform_real_distribution<double> unit_interval(0.0, 1.0);
        return unit_interval(rng);
      };
      BaseSampler<double, decltype(draw_uniform)> uniform_sampler(draw_uniform);
      // Source of N(0, 1) doubles backed by the same generator.
      auto draw_normal = [rng]() {
        at::normal_distribution<double> unit_normal(0.0, 1.0);
        return unit_normal(rng);
      };
      BaseSampler<double, decltype(draw_normal)> normal_sampler(draw_normal);
      const double drawn = sample_gamma<double, double, decltype(draw_uniform), decltype(draw_normal)>
          (alpha_elem, uniform_sampler, normal_sampler);
      return std::max(std::numeric_limits<double>::min(), drawn);
    });

    /* Normalize and cast back to scalar_t. */
    Tensor gamma_totals = gamma_draws.sum(-1, true).expand(alpha.sizes());
    auto normalizing_iter = TensorIteratorConfig()
        .add_output(normalized)
        .add_input(gamma_draws)
        .add_input(gamma_totals)
        .check_all_same_dtype(false)
        .build();
    cpu_serial_kernel(normalizing_iter, [](double gamma_elem, double total_elem) -> scalar_t {
      const double ratio = gamma_elem / total_elem;
      const scalar_t lower = std::numeric_limits<scalar_t>::min();
      const scalar_t upper = std::nexttoward(static_cast<scalar_t>(1.0f), 0.0f);
      return std::min(upper, std::max(lower, static_cast<scalar_t>(ratio)));
    });
  });
  return normalized;
}
/* 2^FLT_MANT_DIG (2^24 for IEEE float32): every integer up to this value is
 * exactly representable in a float. */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = int64_t{1} << FLT_MANT_DIG;
/**
 * Draws `n_sample` category indices per distribution from the 1-D or 2-D
 * probability tensor `self` and writes them into `result`.
 *
 * Validation (each failure raises via TORCH_CHECK):
 *   - `result` and `self` are on the same device;
 *   - `self` has 1 or 2 dimensions and a floating-point dtype;
 *   - `result` has dtype Long;
 *   - `n_sample > 0`, and without replacement `n_sample` does not exceed the
 *     number of categories (size of the last dimension);
 *   - the number of categories does not exceed FLOAT32_MAX_CONSECUTIVE_INT;
 *   - Half on CPU is rejected (checked only after `result` has been resized
 *     and found non-empty).
 *
 * `result` is resized to {n_sample} for 1-D input or {self.size(0), n_sample}
 * for 2-D input. Sampling with replacement goes through
 * multinomial_with_replacement_stub; sampling without replacement divides the
 * probabilities by Exp(1) noise and takes the argmax / top-k.
 */
Tensor& multinomial_out(const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    c10::optional<Generator> gen,
    Tensor& result) {
  TORCH_CHECK(
      result.device() == self.device(),
      "multinomial arguments must have the same device");
  const int64_t n_dims = self.dim();
  TORCH_CHECK(
      n_dims > 0 && n_dims <= 2, "prob_dist must be 1 or 2 dim");
  TORCH_CHECK(
      at::isFloatingType(self.scalar_type()),
      "multinomial only supports floating-point dtypes for input, got: ",
      self.scalar_type());
  TORCH_CHECK(result.scalar_type() == ScalarType::Long,
      "multinomial expects Long tensor out, got: ", result.scalar_type());
  TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");

  const int64_t n_categories = self.size(-1);
  TORCH_CHECK(with_replacement || (n_sample <= n_categories),
      "cannot sample n_sample > prob_dist.size(-1) samples without replacement");
  // Since the index tensor is float, numCategories cannot exceed max
  // float integer precision
  TORCH_CHECK(
      n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
      "number of categories cannot exceed 2^24");

  const bool single_distribution = (n_dims == 1);
  if (single_distribution) {
    result.resize_({n_sample});
  } else {
    const int64_t n_dist = self.size(0);
    result.resize_({n_dist, n_sample});
  }
  // Nothing to sample into.
  if (result.numel() == 0) {
    return result;
  }

  // Half is not supported on CPU.
  TORCH_CHECK(
      !(self.device().is_cpu() && self.scalar_type() == ScalarType::Half),
      "multinomial is not implemented for half on CPU");

  if (with_replacement) {
    multinomial_with_replacement_stub(
        result.device().type(), result, self, n_sample, gen);
    return result;
  }

  // Fast-path for no replacement.
  // Reference:
  // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503

  // Sanity checks on `self`.
  auto all_finite_nonneg = ((self.max() < INFINITY) & (self.min() >= 0)).item();
  TORCH_CHECK(
      all_finite_nonneg.to<bool>(),
      "probability tensor contains either `inf`, `nan` or element < 0");
  // True when the single distribution, or any row of a 2-D input, sums to 0.
  const bool has_zero_mass = single_distribution
      ? (self.sum() == 0).item().to<bool>()
      : (self.sum(1) == 0).sum().item().to<bool>();
  TORCH_CHECK(
      !has_zero_mass,
      "invalid multinomial distribution (sum of probabilities <= 0)");

  // The algorithm is from gumbel softmax.
  // s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
  // Here we can apply exp to the formula which will not affect result of
  // argmax or topk. Then we have
  // s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
  // We can also simplify the formula above by
  // s = argmax( p / q ) where q ~ Exp(1)
  Tensor noise = at::empty_like(self).exponential_(1, gen);
  // In theory the probability to generate 0 from exponential distribution is
  // 0. However, on CUDA side there is a protection to avoid 0s, but on CPU
  // side, there is a very low probability to generate 0 from
  // exponential<double>. The probability is about 2^(-DBL_MANT_DIG). We just
  // ignore it here, but there may be some risk to get invalid output on CPU.
  at::div_out(noise, self, noise);

  if (n_sample == 1) {
    at::argmax_out(result, noise, /*dim=*/-1, /*keepdim=*/true);
  } else {
    // top-k values are required by the API but discarded; only indices matter.
    Tensor topk_values = at::empty(result.sizes(), self.options());
    at::topk_out(topk_values, result, noise, n_sample);
  }
  return result;
}
// Functional multinomial: creates an empty Long tensor with the remaining
// options of `self` and lets multinomial_out resize and fill it.
Tensor multinomial(
    const Tensor& self,
    int64_t n_sample,
    bool with_replacement,
    c10::optional<Generator> gen) {
  const auto index_options = self.options().dtype(kLong);
  Tensor indices = at::empty({0}, index_options);
  native::multinomial_out(self, n_sample, with_replacement, gen, indices);
  return indices;
}
}} // namespace at::native