| #include <ATen/ATen.h> |
| #include <ATen/CPUFunctions.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/NamedTensorUtils.h> |
| #include <ATen/ScalarOps.h> |
| |
| namespace at { |
| namespace meta { |
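// addmv(self, mat, vec, beta, alpha) computes beta * self + alpha * (mat @ vec),
// where mat is a 2-D matrix, vec is a 1-D vector, and self broadcasts to
// {mat.size(0)}. This meta function validates the shapes and sets up the 1-D output.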
| TORCH_META_FUNC(addmv)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta, const Scalar& alpha) { |
| TORCH_CHECK((mat.dim() == 2 && vec.dim() == 1 && self.dim() <= 1), |
| "vector + matrix @ vector expected, got ", self.dim(), ", ", mat.dim(), ", ", vec.dim()); |
| |
TORCH_CHECK(mat.size(1) == vec.size(0) && (mat.size(0) == self.numel() || self.numel() == 1),
"size mismatch, got input (", self.numel(), "), mat (", mat.size(0), "x", mat.size(1), "), vec (", vec.size(0), ")");
| auto names = at::namedinference::propagate_names_for_addmv(mat, vec, self); |
| set_output(0, IntArrayRef(mat.sizes().data(), 1), {}, mat.options(), names); |
| auto result = maybe_get_output(0); |
// This check can only fail for the in-place variant; for all other overloads,
// set_output above guarantees that result already has the correct size.
| TORCH_CHECK(result.dim() == 1 && result.sizes()[0] == mat.sizes()[0], "output of addmv operation should be 1D with ", |
| "size equal to mat.size(0), yet got output size ", result.sizes(), " and mat.size(0) ", mat.size(0)); |
| } |
| } |
| |
| namespace native { |
| |
| template<typename scalar_t> |
| void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t lda, scalar_t *x, int64_t incx, scalar_t beta, scalar_t *y, int64_t incy); |
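// gemv follows the BLAS convention: it computes y := alpha * op(A) @ x + beta * y,
// where A is an m-by-n column-major matrix with leading dimension lda and trans
// selects op(A) = A ('n') or A^T ('t'). Specializations are defined in a separate
// translation unit, typically dispatching to a BLAS library when one is available.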
| |
| template<typename scalar_t> |
| scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy); |
| |
| template<typename scalar_t> |
| scalar_t vdot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy); |
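// dot_impl computes sum_i x[i * incx] * y[i * incy]; vdot_impl additionally
// conjugates x first, which only matters for complex dtypes. Both are likewise
// defined in a separate translation unit.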
| |
| constexpr inline bool lda_cond(int64_t m, int64_t n, int64_t lda) { |
| return n == 1 || lda >= std::max<int64_t>(1L, m); |
| } |
| |
| TORCH_IMPL_FUNC(addmv_out_cpu)(const Tensor &self, const Tensor &mat, const Tensor &vec, const Scalar& beta_, const Scalar& alpha_, const Tensor& result) { |
| c10::MaybeOwned<Tensor> self_ = expand_size(self, {mat.size(0)}); |
| auto betaval = beta_.toComplexDouble(); |
| if (mat.numel() == 0) { |
// Shortcut for an empty matrix: only the beta * self term contributes.
// By definition, when beta == 0 the values in self must be ignored: NaNs and
// Infs must not propagate into the result.
| if (betaval == 0.0) { |
| result.zero_(); |
| } else { |
| at::cpu::mul_out( |
| const_cast<Tensor&>(result), |
| self, |
| at::native::scalar_tensor( |
| beta_, self.scalar_type(), c10::nullopt /* layout */, at::kCPU, c10::nullopt /* pin_memory */)); |
| } |
| } else { |
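// gemv computes y := beta * y + alpha * A @ x in place on y, so result must
// already hold self's values before the call whenever beta != 0.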
if (!result.is_same(*self_) && betaval != 0.0) { // if beta is 0, the contents of result are ignored
| at::native::copy_(const_cast<Tensor&>(result), *self_); |
| } |
| if (result.numel() != 0) { |
| auto r_stride = result.stride(0); |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, mat.scalar_type(), "addmv_impl_cpu", [&] { |
| auto beta = beta_.to<scalar_t>(); |
| auto alpha = alpha_.to<scalar_t>(); |
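// BLAS gemv expects a column-major matrix. If mat is column-major
// (stride(0) == 1), use it directly with trans = 'n'. If mat is row-major
// (stride(1) == 1), an m-by-n row-major matrix occupies the same memory as its
// n-by-m column-major transpose, so call gemv with trans = 't' and the
// dimensions swapped. Otherwise fall back to a contiguous (row-major) copy and
// take the transposed path.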
| if (mat.stride(0) == 1 && lda_cond(mat.size(0), mat.size(1), mat.stride(1))) { |
| gemv<scalar_t>('n', mat.size(0), mat.size(1), alpha, mat.data_ptr<scalar_t>(), mat.stride(1), |
| vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride); |
| } |
| else if (mat.stride(1) == 1 && lda_cond(mat.size(1), mat.size(0), mat.stride(0))) { |
| gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, mat.data_ptr<scalar_t>(), mat.stride(0), |
| vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride); |
| } |
| else { |
| Tensor cmat = mat.contiguous(); |
| gemv<scalar_t>('t', mat.size(1), mat.size(0), alpha, cmat.data_ptr<scalar_t>(), cmat.stride(0), |
| vec.data_ptr<scalar_t>(), vec.stride(0), beta, result.data_ptr<scalar_t>(), r_stride); |
| } |
| }); |
| } |
| } |
| } |
| |
| Tensor &mv_out(const Tensor &self, const Tensor &vec, Tensor& result) { |
// The self argument sent to addmv_out cannot be resized.
// Here we pass result as the self argument of addmv, and result is
// user-supplied, so it may have the wrong size. That is not a hard error,
// because we are allowed to resize result, but it would become a hard error
// inside addmv, which expects self to already satisfy the size conditions.
// To avoid this, supply a correctly sized self; its contents don't matter
// because beta is 0.
if (result.dim() > 1 || result.numel() != self.size(0)) {
| Tensor self_addmv = at::empty({self.size(0)}, self.options()); |
| return at::addmv_out(result, self_addmv, self, vec, 0, 1); |
| } |
| return at::addmv_out(result, result, self, vec, 0, 1); |
| } |
| |
| Tensor mv(const Tensor &self, const Tensor &vec) { |
| Tensor result = at::empty({self.size(0)}, self.options()); |
// the in-place version is more efficient when we can use it
| return at::addmv_(result, self, vec, 0, 1); |
| } |
| |
| inline void dot_check(const Tensor& self, const Tensor& other) { |
| TORCH_CHECK( |
| self.dim() == 1 && other.dim() == 1, |
| "1D tensors expected, but got ", |
| self.dim(), |
| "D and ", |
| other.dim(), |
| "D tensors"); |
| |
| TORCH_CHECK( |
| self.scalar_type() == other.scalar_type(), |
| "dot : expected both vectors to have same dtype, but found ", |
| self.scalar_type(), |
| " and ", |
| other.scalar_type()); |
| |
| TORCH_CHECK( |
| self.numel() == other.numel(), |
| "inconsistent tensor size, expected tensor [", |
| self.numel(), |
| "] and src [", |
| other.numel(), |
| "] to have the same number of elements, but got ", |
| self.numel(), |
| " and ", |
| other.numel(), |
| " elements respectively"); |
| } |
| |
Tensor dot(const Tensor &self, const Tensor &other) {
| at::NoNamesGuard guard; |
| |
| dot_check(self, other); |
| |
| return AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, self.scalar_type(), "dot", [&] { |
| Tensor result = at::empty({}, self.options()); |
| result.fill_(dot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0))); |
| return result; |
| }); |
| } |
| |
Tensor vdot(const Tensor &self, const Tensor &other) {
| at::NoNamesGuard guard; |
| |
| // Dispatch to `dot` for real dtypes. |
| if (!self.is_complex()){ |
| return at::dot(self, other); |
| } |
| |
| // For complex dtypes. |
| dot_check(self, other); |
| return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] { |
| Tensor result = at::empty({}, self.options()); |
| result.fill_(vdot_impl<scalar_t>(self.numel(), self.data_ptr<scalar_t>(), self.stride(0), other.data_ptr<scalar_t>(), other.stride(0))); |
| return result; |
| }); |
| |
| } |
| |
| }} // namespace at::native |