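// CPU paths for addmv / mv: result = beta * self + alpha * (mat @ vec).
// The per-dtype gemv wrapper declared below performs the actual matrix-vector product.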
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>

namespace at { namespace native {
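
// Per-dtype gemv wrapper, y = alpha * op(a) * x + beta * y. It is only declared here;
// the definition is assumed to live with the CPU BLAS bindings. The return value
// reports whether the fast BLAS path handled the call (true) or the wrapper's naive
// fallback was used (false).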
template<typename scalar_t>
bool gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda, scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy);
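
// Whether `lda` is a usable BLAS leading dimension for an m-by-n operand:
// BLAS expects lda >= max(1, m), and a single-column matrix (n == 1) is accepted
// unconditionally because the column stride is never dereferenced.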
constexpr inline bool lda_cond(int64_t m, int64_t n, int64_t lda) {
  return n == 1 || lda >= std::max<int64_t>(1L, m);
}
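
// CPU kernel for addmv: result = beta * result + alpha * (mat @ vec).
// The caller (addmv_out below) has already copied `self` into `result`, so `self` is
// not read here. The matrix is passed to gemv in whichever layout its strides already
// satisfy ('n' for column-major, 't' for row-major), with a contiguous copy as a last
// resort.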
Tensor &addmv_impl_cpu(Tensor& result, const Tensor &self, const Tensor &mat, const Tensor &vec, Scalar beta_, Scalar alpha_) {
  auto r_stride = result.stride(0);
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, mat.scalar_type(), "addmv_impl_cpu", [&] {
    auto beta = beta_.to<scalar_t>();
    auto alpha = alpha_.to<scalar_t>();
    bool is_fast = false;
    if (mat.stride(0) == 1 && lda_cond(mat.size(0), mat.size(1), mat.stride(1))) {
      // Column-major mat: hand it to gemv as-is ('n').
      is_fast = gemv<scalar_t>('n', mat.size(0), mat.size(1), alpha, mat.data_ptr<scalar_t>(), mat.stride(1),
          vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
    }
    else if (mat.stride(1) == 1 && lda_cond(mat.size(1), mat.size(0), mat.stride(0))) {
      // Row-major mat: pass it as its transpose ('t').
      is_fast = gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, mat.data_ptr<scalar_t>(), mat.stride(0),
          vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
    }
    else {
      // Neither stride pattern is BLAS-compatible: fall back to a contiguous copy.
      Tensor cmat = mat.contiguous();
      is_fast = gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, cmat.data_ptr<scalar_t>(), cmat.stride(0),
          vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride);
    }
    // The fast gemv path does not apply beta when vec is empty (a matrix of shape
    // (m, 0) times a length-0 vector), whereas gemm does handle beta in the analogous
    // (m, 0) @ (0, n) case. The naive fallback inside gemv does apply beta, so only
    // compensate here when the fast path was used.
    if (is_fast && vec.size(0) == 0 && mat.size(0) != 0) {
      if (beta == scalar_t(0)) {
        result.zero_();
      } else if (beta != scalar_t(1)) {
        result.mul_(beta);
      }
    }
  });
  return result;
}
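
// Out variant: resizes `result` to mat.size(0), broadcasts a 0-dim or length-1 `self`,
// checks shapes, copies `self` into `result`, and dispatches to the device-specific
// _addmv_impl_. Named-dimension propagation happens outside the NoNamesGuard scope.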
Tensor &addmv_out(Tensor& result, const Tensor &self, const Tensor &mat, const Tensor &vec, Scalar beta, Scalar alpha) {
  { // scope of NoNamesGuard
    at::NoNamesGuard guard;
    result.resize_({mat.size(0)});
    Tensor self_ = self;
    if (self.dim() == 0 || self.size(0) == 1) {
      self_ = self.expand({mat.size(0)});
    }
    TORCH_CHECK((mat.dim() == 2 && vec.dim() == 1 && self_.dim() == 1),
      "vector + matrix @ vector expected, got ", self_.dim(), ", ", mat.dim(), ", ", vec.dim());
    TORCH_CHECK((mat.size(1) == vec.size(0) && mat.size(0) == self_.size(0)),
      "size mismatch, got ", self_.size(0), ", ", mat.size(0), "x", mat.size(1), ", ", vec.size(0));
    if (!result.is_same(self_)) {
      at::native::copy_(result, self_);
    }
    if (result.numel() != 0) {
      at::_addmv_impl_(result, self_, mat, vec, beta, alpha);
    }
  } // scope of NoNamesGuard
  at::namedinference::propagate_names_for_addmv(result, mat, vec, self);
  return result;
}
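
// Functional variant: allocates the output and forwards to addmv_out.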
Tensor addmv(const Tensor &self, const Tensor &mat, const Tensor &vec, Scalar beta, Scalar alpha) {
  Tensor result = at::empty({mat.size(0)}, mat.options());
  return native::addmv_out(result, self, mat, vec, beta, alpha);
}
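
// In-place variant: `self` is both the accumulated input and the output.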
Tensor &addmv_(Tensor &self, const Tensor &mat, const Tensor &vec, Scalar beta, Scalar alpha) {
  return native::addmv_out(self, self, mat, vec, beta, alpha);
}
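
// mv is addmv with beta == 0 and alpha == 1; `result` is passed as `self` only to
// satisfy addmv_out's signature, and with beta == 0 its prior contents do not matter.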
Tensor &mv_out(Tensor& result, const Tensor &self, const Tensor &vec) {
  return native::addmv_out(result, result, self, vec, 0, 1);
}
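
// Functional variant of mv: allocates the output and forwards to mv_out.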
Tensor mv(const Tensor &self, const Tensor &vec) {
  Tensor result = at::empty({self.size(0)}, self.options());
  return native::mv_out(result, self, vec);
}

}} // namespace at::native