| #include <ATen/native/BinaryOps.h> |
| |
| #include <type_traits> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/MemoryOverlap.h> |
| #include <ATen/NativeFunctions.h> |
| #include <ATen/native/TensorIterator.h> |
| |
| #include <torch/library.h> |
| |
| namespace at { |
| namespace native { |
| |
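| // Each DEFINE_DISPATCH below instantiates a dispatch stub declared in |
| // BinaryOps.h; the per-device kernels themselves are registered separately |
| // (e.g. via REGISTER_DISPATCH in the CPU and CUDA kernel files). |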
| DEFINE_DISPATCH(add_stub); |
| DEFINE_DISPATCH(add_clamp_stub); |
| DEFINE_DISPATCH(sub_stub); |
| DEFINE_DISPATCH(mul_stub); |
| DEFINE_DISPATCH(div_stub); |
| DEFINE_DISPATCH(remainder_stub); |
| DEFINE_DISPATCH(atan2_stub); |
| DEFINE_DISPATCH(bitwise_and_stub); |
| DEFINE_DISPATCH(bitwise_or_stub); |
| DEFINE_DISPATCH(bitwise_xor_stub); |
| DEFINE_DISPATCH(lshift_stub); |
| DEFINE_DISPATCH(rshift_stub); |
| DEFINE_DISPATCH(logical_and_stub); |
| DEFINE_DISPATCH(logical_or_stub); |
| DEFINE_DISPATCH(logical_xor_stub); |
| DEFINE_DISPATCH(lt_stub); |
| DEFINE_DISPATCH(le_stub); |
| DEFINE_DISPATCH(gt_stub); |
| DEFINE_DISPATCH(ge_stub); |
| DEFINE_DISPATCH(eq_stub); |
| DEFINE_DISPATCH(ne_stub); |
| DEFINE_DISPATCH(sigmoid_backward_stub); |
| DEFINE_DISPATCH(logit_backward_stub); |
| DEFINE_DISPATCH(tanh_backward_stub); |
| DEFINE_DISPATCH(maximum_stub); |
| DEFINE_DISPATCH(minimum_stub); |
| DEFINE_DISPATCH(fmod_stub); |
| DEFINE_DISPATCH(fmod_scalar_stub); |
| DEFINE_DISPATCH(logaddexp_stub); |
| DEFINE_DISPATCH(logaddexp2_stub); |
| DEFINE_DISPATCH(gcd_stub); |
| DEFINE_DISPATCH(lcm_stub); |
| DEFINE_DISPATCH(hypot_stub); |
| DEFINE_DISPATCH(nextafter_stub); |
| DEFINE_DISPATCH(heaviside_stub); |
| |
| Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| alpha_check(iter.dtype(), alpha); |
| add_stub(iter.device_type(), iter, alpha); |
| TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); |
| return result; |
| } |
| |
| Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| alpha_check(iter.dtype(), alpha); |
| add_stub(iter.device_type(), iter, alpha); |
| return iter.output(); |
| } |
| |
| Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { |
| return native::add_out(self, self, other, alpha); |
| } |
| |
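| // Fused add + relu: computes self + alpha * other and clamps the sum to |
| // [0, numeric max of the dtype], which is equivalent to applying relu to |
| // the result of the addition. |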
| Tensor& add_relu_impl( |
| Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| Scalar min_val; |
| Scalar max_val; |
| if (self.dtype() == at::kInt) { |
| min_val = 0; |
| max_val = std::numeric_limits<int32_t>::max(); |
| } else if (self.dtype() == at::kLong) { |
| min_val = 0; |
| max_val = std::numeric_limits<int64_t>::max(); |
| } else if (self.dtype() == at::kShort) { |
| min_val = 0; |
| max_val = std::numeric_limits<int16_t>::max(); |
| } else if (self.dtype() == at::kChar) { |
| min_val = 0; |
| max_val = std::numeric_limits<int8_t>::max(); |
| } else if (self.dtype() == at::kFloat) { |
| min_val = 0.0; |
| max_val = std::numeric_limits<float>::max(); |
| } else if (self.dtype() == at::kDouble) { |
| min_val = 0.0; |
| max_val = std::numeric_limits<double>::max(); |
| } else { |
| TORCH_INTERNAL_ASSERT( |
| false, "Unsupported datatype for add_relu: ", self.dtype().name()); |
| } |
| |
| result = iter.output(); |
| add_clamp_stub(iter.device_type(), iter, alpha, min_val, max_val); |
| return result; |
| } |
| |
| Tensor& add_relu_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { |
| return add_relu_impl(result, self, other, alpha); |
| } |
| |
| Tensor add_relu(const Tensor& self, const Tensor& other, Scalar alpha) { |
| Tensor result; |
| return add_relu_impl(result, self, other, alpha); |
| } |
| |
| Tensor& add_relu_(Tensor& self, const Tensor& other, Scalar alpha) { |
| return add_relu_impl(self, self, other, alpha); |
| } |
| |
| Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| if (isIntegralType(result.scalar_type(), /*includeBool=*/ true)) { |
| TORCH_CHECK(false, |
| "Integer division of tensors using div or / is no longer supported, ", |
| "and in a future release div will perform true division as in Python 3. ", |
| "Use true_divide or floor_divide (// in Python) instead."); |
| } |
| |
| auto iter = TensorIterator::binary_op(result, self, other); |
| div_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor div(const Tensor& self, const Tensor& other) { |
| if (isIntegralType(self.scalar_type(), /*includeBool=*/ true) |
| && isIntegralType(other.scalar_type(), /*includeBool=*/ true)) { |
| TORCH_CHECK(false, |
| "Integer division of tensors using div or / is no longer supported, ", |
| "and in a future release div will perform true division as in Python 3. ", |
| "Use true_divide or floor_divide (// in Python) instead."); |
| } |
| |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| div_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& div_(Tensor& self, const Tensor& other) { |
| return native::div_out(self, self, other); |
| } |
| |
| Tensor& remainder_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| remainder_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor remainder(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| remainder_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& remainder_(Tensor& self, const Tensor& other) { |
| return native::remainder_out(self, self, other); |
| } |
| |
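| // Unlike div, true_divide always computes a floating-point result: |
| // promote_integer_inputs_to_float(true) upgrades an integral (or bool) |
| // common dtype to the default (floating-point) scalar type before |
| // div_stub runs. |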
| Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) { |
| TensorIterator iter = TensorIteratorConfig() |
| .set_check_mem_overlap(true) |
| .add_output(result) |
| .add_input(self) |
| .add_input(divisor) |
| .allow_cpu_scalars(true) |
| .promote_inputs_to_common_dtype(true) |
| .promote_integer_inputs_to_float(true) |
| .cast_common_dtype_to_outputs(true) |
| .enforce_safe_casting_to_output(true) |
| .build(); |
| |
| div_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor true_divide(const Tensor& self, const Tensor& divisor) { |
| // If both inputs have integral (or bool) dtypes, this creates temporary |
| // floating-point copies as new inputs and sets the result's dtype to |
| // the default scalar type |
| if (isIntegralType(self.scalar_type(), /*includeBool=*/ true) |
| && isIntegralType(divisor.scalar_type(), /*includeBool=*/ true)) { |
| const auto scalar_type = typeMetaToScalarType(c10::get_default_dtype()); |
| Tensor result = at::empty({0}, self.options().dtype(scalar_type)); |
| auto iter = TensorIterator::binary_op(result, |
| self.to(scalar_type), |
| divisor.to(scalar_type)); |
| div_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| // Otherwise (at least one input is already non-integral), this |
| // participates in type promotion like other binary ufuncs |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, divisor); |
| div_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& true_divide_(Tensor& self, const Tensor& divisor) { |
| return native::true_divide_out(self, self, divisor); |
| } |
| |
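| // NB: despite the name, floor_divide currently divides and then truncates |
| // toward zero (trunc_), so results differ from true floor division when |
| // the quotient is negative. |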
| Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| div_stub(iter.device_type(), iter); |
| |
| if (result.is_floating_point()) { |
| result.trunc_(); |
| } |
| |
| return result; |
| } |
| |
| Tensor floor_divide(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| |
| div_stub(iter.device_type(), iter); |
| |
| auto out = iter.output(); |
| if (out.is_floating_point()) { |
| out.trunc_(); |
| } |
| |
| return out; |
| } |
| |
| Tensor& floor_divide_(Tensor& self, const Tensor& other) { |
| return native::floor_divide_out(self, self, other); |
| } |
| |
| Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| mul_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor mul(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| mul_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& mul_(Tensor& self, const Tensor& other) { |
| return native::mul_out(self, self, other); |
| } |
| |
| Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { |
| sub_check(self, other); |
| auto iter = TensorIterator::binary_op(result, self, other); |
| alpha_check(iter.dtype(), alpha); |
| sub_stub(iter.device_type(), iter, alpha); |
| TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); |
| return result; |
| } |
| |
| Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) { |
| sub_check(self, other); |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| alpha_check(iter.dtype(), alpha); |
| sub_stub(iter.device_type(), iter, alpha); |
| return iter.output(); |
| } |
| |
| Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) { |
| return native::sub_out(self, self, other, alpha); |
| } |
| |
| Tensor& sigmoid_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) { |
| auto iter = TensorIterator::binary_op(result, grad_output, output); |
| sigmoid_backward_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor sigmoid_backward(const Tensor& grad_output, const Tensor& output) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, grad_output, output); |
| sigmoid_backward_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& logit_backward_out( |
| Tensor& result, |
| const Tensor& grad_output, |
| const Tensor& input, |
| c10::optional<double> eps) { |
| auto iter = TensorIterator::binary_op(result, grad_output, input); |
| logit_backward_stub( |
| iter.device_type(), iter, Scalar(eps ? eps.value() : -1.0)); |
| return result; |
| } |
| |
| Tensor logit_backward( |
| const Tensor& grad_output, |
| const Tensor& input, |
| c10::optional<double> eps) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, grad_output, input); |
| logit_backward_stub( |
| iter.device_type(), iter, Scalar(eps ? eps.value() : -1.0)); |
| return iter.output(); |
| } |
| |
| Tensor& tanh_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) { |
| auto iter = TensorIterator::binary_op(result, grad_output, output); |
| tanh_backward_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor tanh_backward(const Tensor& grad_output, const Tensor& output) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, grad_output, output); |
| tanh_backward_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor rsub(const Tensor& self, const Tensor& other, Scalar alpha) { |
| return native::sub(other, self, alpha); |
| } |
| |
| Tensor& atan2_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| atan2_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor atan2(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| return native::atan2_out(result, self, other); |
| } |
| |
| Tensor& atan2_(Tensor& self, const Tensor& other) { |
| return native::atan2_out(self, self, other); |
| } |
| |
| // These are still needed because we don't have C++ conversions from number |
| // types (int, float, etc.) to Tensor (only to Scalar). They're not exposed |
| // to Python. |
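| // The returned tensor is marked as a wrapped number, so type promotion |
| // treats it like a Python number: e.g. adding the wrapped scalar 2 to a |
| // float tensor keeps the tensor's dtype rather than promoting it. |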
| |
| static Tensor wrapped_scalar_tensor(Scalar scalar) { |
| auto tensor = scalar_to_tensor(scalar); |
| tensor.unsafeGetTensorImpl()->set_wrapped_number(true); |
| return tensor; |
| } |
| |
| static void check_convert(Scalar scalar, ScalarType scalarType) { |
| // Validate that it is possible to convert the scalar to the tensor's dtype without overflow |
| AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{ |
| scalar.to<scalar_t>(); |
| }); |
| } |
| |
| static Tensor wrapped_scalar_tensor_and_check_convert(Scalar scalar, Tensor tensor) { |
| check_convert(scalar, tensor.scalar_type()); |
| return wrapped_scalar_tensor(scalar); |
| } |
| |
| Tensor add(const Tensor& self, Scalar other, Scalar alpha) { |
| return native::add(self, wrapped_scalar_tensor(other), alpha); |
| } |
| |
| Tensor& add_(Tensor& self, Scalar other, Scalar alpha) { |
| return native::add_(self, wrapped_scalar_tensor(other), alpha); |
| } |
| |
| // WARNING: There doesn't appear to be any testing for this function |
| // with sparse self input. |
| Tensor div(const Tensor& self, Scalar other) { |
| return self.div(wrapped_scalar_tensor(other)); // redispatch! |
| } |
| |
| // WARNING: This function, with a sparse self, is currently only |
| // exercised by DistributedDataParallelTest.test_sparse_gradients |
| // (you need to exercise it from C++, because this overload is never |
| // used for Python) |
| Tensor& div_(Tensor& self, Scalar other) { |
| return self.div_(wrapped_scalar_tensor(other)); // redispatch! |
| } |
| |
| Tensor remainder(const Tensor& self, Scalar other) { |
| Tensor other_tensor = wrapped_scalar_tensor(other); |
| // FIXME: 'other' is converted to match the dtype of 'self' to retain |
| // BC with TH, but in the future, we should use normal type promotion, |
| // like in numpy |
| return native::remainder(self, other_tensor.toType(self.scalar_type())); |
| } |
| |
| Tensor& remainder_(Tensor& self, Scalar other) { |
| Tensor other_tensor = wrapped_scalar_tensor(other); |
| // FIXME: 'other' is converted to match the dtype of 'self' to retain |
| // BC with TH, but in the future, we should use normal type promotion, |
| // like in numpy |
| return native::remainder_(self, other_tensor.toType(self.scalar_type())); |
| } |
| |
| Tensor& remainder_out(Tensor& result, const Tensor& self, Scalar other) { |
| Tensor other_tensor = wrapped_scalar_tensor(other); |
| // FIXME: 'other' is converted to match the dtype of 'self' to retain |
| // BC with TH, but in the future, we should use normal type promotion, |
| // like in numpy |
| return native::remainder_out(result, self, other_tensor.toType(self.scalar_type())); |
| } |
| |
| Tensor mul(const Tensor& self, Scalar other) { |
| return native::mul(self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor& mul_(Tensor& self, Scalar other) { |
| return native::mul_(self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor sub(const Tensor& self, Scalar other, Scalar alpha) { |
| return native::sub(self, wrapped_scalar_tensor(other), alpha); |
| } |
| |
| Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) { |
| return native::sub_(self, wrapped_scalar_tensor(other), alpha); |
| } |
| |
| Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) { |
| return native::rsub(self, wrapped_scalar_tensor(other), alpha); |
| } |
| |
| Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| bitwise_and_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor bitwise_and(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| at::bitwise_and_out(result, self, other); |
| return result; |
| } |
| |
| Tensor& bitwise_and_(Tensor& self, const Tensor& other) { |
| return at::bitwise_and_out(self, self, other); |
| } |
| |
| Tensor& bitwise_and_out(Tensor& result, const Tensor& self, Scalar other) { |
| return at::bitwise_and_out(result, self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor bitwise_and(const Tensor& self, Scalar other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::bitwise_and_out(result, self, other); |
| } |
| |
| Tensor& bitwise_and_(Tensor& self, Scalar other) { |
| return at::bitwise_and_out(self, self, other); |
| } |
| |
| // Legacy __and__ interfaces. They are aliased to the bitwise_and* functions |
| Tensor __and__(const Tensor& self, const Tensor& other) { |
| return at::bitwise_and(self, other); |
| } |
| |
| Tensor __and__(const Tensor& self, Scalar other) { |
| return at::bitwise_and(self, other); |
| } |
| |
| Tensor& __iand__(Tensor& self, const Tensor& other) { |
| return self.bitwise_and_(other); |
| } |
| |
| Tensor& __iand__(Tensor& self, Scalar other) { |
| return self.bitwise_and_(other); |
| } |
| |
| Tensor& bitwise_or_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| bitwise_or_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor bitwise_or(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| at::bitwise_or_out(result, self, other); |
| return result; |
| } |
| |
| Tensor& bitwise_or_(Tensor& self, const Tensor& other) { |
| return at::bitwise_or_out(self, self, other); |
| } |
| |
| Tensor& bitwise_or_out(Tensor& result, const Tensor& self, Scalar other) { |
| return at::bitwise_or_out(result, self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor bitwise_or(const Tensor& self, Scalar other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::bitwise_or_out(result, self, other); |
| } |
| |
| Tensor& bitwise_or_(Tensor& self, Scalar other) { |
| return at::bitwise_or_out(self, self, other); |
| } |
| |
| // Legacy __or__ interfaces. They are aliased to the bitwise_or* functions |
| Tensor __or__(const Tensor& self, const Tensor& other) { |
| return at::bitwise_or(self, other); |
| } |
| |
| Tensor __or__(const Tensor& self, Scalar other) { |
| return at::bitwise_or(self, other); |
| } |
| |
| Tensor& __ior__(Tensor& self, const Tensor& other) { |
| return self.bitwise_or_(other); |
| } |
| |
| Tensor& __ior__(Tensor& self, Scalar other) { |
| return self.bitwise_or_(other); |
| } |
| |
| Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| bitwise_xor_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor bitwise_xor(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| at::bitwise_xor_out(result, self, other); |
| return result; |
| } |
| |
| Tensor& bitwise_xor_(Tensor& self, const Tensor& other) { |
| return at::bitwise_xor_out(self, self, other); |
| } |
| |
| Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, Scalar other) { |
| return at::bitwise_xor_out(result, self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor bitwise_xor(const Tensor& self, Scalar other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::bitwise_xor_out(result, self, other); |
| } |
| |
| Tensor& bitwise_xor_(Tensor& self, Scalar other) { |
| return at::bitwise_xor_out(self, self, other); |
| } |
| |
| // Legacy __xor__ interfaces. They are aliased to the bitwise_xor* functions |
| Tensor __xor__(const Tensor& self, const Tensor& other) { |
| return at::bitwise_xor(self, other); |
| } |
| |
| Tensor __xor__(const Tensor& self, Scalar other) { |
| return at::bitwise_xor(self, other); |
| } |
| |
| Tensor& __ixor__(Tensor& self, const Tensor& other) { |
| return self.bitwise_xor_(other); |
| } |
| |
| Tensor& __ixor__(Tensor& self, Scalar other) { |
| return self.bitwise_xor_(other); |
| } |
| |
| Tensor __lshift__(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| lshift_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor __lshift__(const Tensor& self, Scalar other) { |
| Tensor result; |
| auto wrapper = wrapped_scalar_tensor(other).toType(self.scalar_type()); |
| auto iter = TensorIterator::binary_op(result, self, wrapper); |
| lshift_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& __ilshift__(Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(self, self, other); |
| lshift_stub(iter.device_type(), iter); |
| return self; |
| } |
| |
| Tensor& __ilshift__(Tensor& self, Scalar other) { |
| auto wrapper = wrapped_scalar_tensor(other).toType(self.scalar_type()); |
| auto iter = TensorIterator::binary_op(self, self, wrapper); |
| lshift_stub(iter.device_type(), iter); |
| return self; |
| } |
| |
| Tensor __rshift__(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| rshift_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor __rshift__(const Tensor& self, Scalar other) { |
| Tensor result; |
| auto wrapper = wrapped_scalar_tensor(other).toType(self.scalar_type()); |
| auto iter = TensorIterator::binary_op(result, self, wrapper); |
| rshift_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& __irshift__(Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(self, self, other); |
| rshift_stub(iter.device_type(), iter); |
| return self; |
| } |
| |
| Tensor& __irshift__(Tensor& self, Scalar other) { |
| auto wrapper = wrapped_scalar_tensor(other).toType(self.scalar_type()); |
| auto iter = TensorIterator::binary_op(self, self, wrapper); |
| rshift_stub(iter.device_type(), iter); |
| return self; |
| } |
| |
| template <typename Stub> |
| Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& other, Stub& stub) { |
| // Validate that it is possible to convert the zero-dim tensor's value to the other dtype without overflow |
| if (self.scalar_type() != other.scalar_type()) { |
| if (self.dim() != 0 && other.dim() == 0) { |
| check_convert(other.item(), self.scalar_type()); |
| } else if (self.dim() == 0 && other.dim() != 0) { |
| check_convert(self.item(), other.scalar_type()); |
| } |
| } |
| auto iter = TensorIterator::comparison_op(result, self, other); |
| stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| template <typename OutImpl> |
| Tensor comparison_op(const Tensor& self, const Tensor& other, OutImpl& out_impl) { |
| Tensor result = at::empty({0}, self.options().dtype(kBool)); |
| return out_impl(result, self, other); |
| } |
| |
| // To avoid overflow during type promotion, we check that the dtypes of self and other are the same |
| template <typename OutImpl> |
| Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) { |
| TORCH_CHECK(self.dtype() == other.dtype(), |
| "Expected object of scalar type ", self.dtype(), " but got scalar type ", |
| other.dtype(), " for argument 'other'"); |
| return out_impl(self, self, other); |
| } |
| |
| // Validates that it is possible to convert the Scalar other to self's dtype without overflow. |
| // This behavior is unique to comparison ops; arithmetic operations don't do this. |
| // In the future, we should reconsider this inconsistency and decide if we want to add the same check to arithmetic ops. |
| template <typename OutImpl> |
| Tensor& comparison_op_out(Tensor& result, const Tensor& self, Scalar other, OutImpl& out_impl) { |
| return out_impl(result, self, wrapped_scalar_tensor_and_check_convert(other, self)); |
| } |
| |
| template <typename OutImpl> |
| Tensor comparison_op(const Tensor& self, Scalar other, OutImpl& out_impl) { |
| return comparison_op(self, wrapped_scalar_tensor_and_check_convert(other, self), out_impl); |
| } |
| |
| template <typename OutImpl> |
| Tensor& comparison_op_(Tensor& self, Scalar other, OutImpl& out_impl) { |
| return out_impl(self, self, wrapped_scalar_tensor_and_check_convert(other, self)); |
| } |
| |
| // We need an explicit cast to OutFunc because each *_out function is overloaded twice. Without an explicit |
| // cast, merely referring to a *_out function is ambiguous. |
| using OutFunc = std::add_const<Tensor&(&)(Tensor&, const Tensor&, const Tensor&)>::type; |
| |
| Tensor& lt_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, lt_stub); } |
| Tensor lt(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::lt_out)); } |
| Tensor& lt_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::lt_out)); } |
| Tensor& lt_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::lt_out)); } |
| Tensor lt(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::lt_out)); } |
| Tensor& lt_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::lt_out)); } |
| |
| Tensor& le_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, le_stub); } |
| Tensor le(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::le_out)); } |
| Tensor& le_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::le_out)); } |
| Tensor& le_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::le_out)); } |
| Tensor le(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::le_out)); } |
| Tensor& le_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::le_out)); } |
| |
| Tensor& gt_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, gt_stub); } |
| Tensor gt(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::gt_out)); } |
| Tensor& gt_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::gt_out)); } |
| Tensor& gt_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::gt_out)); } |
| Tensor gt(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::gt_out)); } |
| Tensor& gt_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::gt_out)); } |
| |
| Tensor& ge_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, ge_stub); } |
| Tensor ge(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::ge_out)); } |
| Tensor& ge_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::ge_out)); } |
| Tensor& ge_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::ge_out)); } |
| Tensor ge(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::ge_out)); } |
| Tensor& ge_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::ge_out)); } |
| |
| Tensor& eq_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, eq_stub); } |
| Tensor eq(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::eq_out)); } |
| Tensor& eq_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::eq_out)); } |
| Tensor& eq_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::eq_out)); } |
| Tensor eq(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::eq_out)); } |
| Tensor& eq_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::eq_out)); } |
| |
| Tensor& ne_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, ne_stub); } |
| Tensor ne(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::ne_out)); } |
| Tensor& ne_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::ne_out)); } |
| Tensor& ne_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::ne_out)); } |
| Tensor ne(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::ne_out)); } |
| Tensor& ne_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::ne_out)); } |
| |
| Tensor& logical_and_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, logical_and_stub); } |
| Tensor logical_and(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out)); } |
| Tensor& logical_and_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out)); } |
| Tensor& logical_and_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_and_out)); } |
| Tensor logical_and(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out)); } |
| Tensor& logical_and_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out)); } |
| |
| Tensor& logical_or_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, logical_or_stub); } |
| Tensor logical_or(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out)); } |
| Tensor& logical_or_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out)); } |
| Tensor& logical_or_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_or_out)); } |
| Tensor logical_or(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out)); } |
| Tensor& logical_or_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out)); } |
| |
| Tensor& logical_xor_out(Tensor& result, const Tensor& self, const Tensor& other) { return comparison_op_out(result, self, other, logical_xor_stub); } |
| Tensor logical_xor(const Tensor& self, const Tensor& other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out)); } |
| Tensor& logical_xor_(Tensor& self, const Tensor& other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); } |
| Tensor& logical_xor_out(Tensor& result, const Tensor& self, Scalar other) { return comparison_op_out(result, self, other, static_cast<OutFunc>(at::logical_xor_out)); } |
| Tensor logical_xor(const Tensor& self, Scalar other) { return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out)); } |
| Tensor& logical_xor_(Tensor& self, Scalar other) { return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out)); } |
| |
| Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs."); |
| |
| auto iter = TensorIterator::binary_op(result, self, other); |
| maximum_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor maximum(const Tensor& self, const Tensor& other) { |
| TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs."); |
| |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| maximum_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| // binary max, alias for maximum |
| Tensor& max_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| return at::maximum_out(result, self, other); |
| } |
| |
| Tensor max(const Tensor& self, const Tensor& other) { |
| return at::maximum(self, other); |
| } |
| |
| Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs."); |
| |
| auto iter = TensorIterator::binary_op(result, self, other); |
| minimum_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor minimum(const Tensor& self, const Tensor& other) { |
| TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs."); |
| |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| minimum_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| // binary min, alias for minimum |
| Tensor& min_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| return at::minimum_out(result, self, other); |
| } |
| |
| Tensor min(const Tensor& self, const Tensor& other) { |
| return at::minimum(self, other); |
| } |
| |
| Tensor floor_divide(const Tensor& self, Scalar other) { |
| return at::floor_divide(self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor& floor_divide_(Tensor& self, Scalar other) { |
| return at::floor_divide_out(self, self, wrapped_scalar_tensor(other)); |
| } |
| |
| Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU"); |
| fmod_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) { |
| auto iter = TensorIterator::unary_op(result, self); |
| TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU"); |
| fmod_scalar_stub(iter.device_type(), iter, other); |
| return result; |
| } |
| |
| Tensor fmod(const Tensor& self, const Tensor & other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::fmod_out(result, self, other); |
| } |
| |
| Tensor fmod(const Tensor& self, Scalar other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::fmod_out(result, self, other); |
| } |
| |
| Tensor& fmod_(Tensor& self, const Tensor& other) { |
| return at::fmod_out(self, self, other); |
| } |
| |
| Tensor& fmod_(Tensor& self, Scalar other) { |
| return at::fmod_out(self, self, other); |
| } |
| |
| Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| logaddexp_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor logaddexp(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::logaddexp_out(result, self, other); |
| } |
| |
| Tensor& logaddexp2_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| logaddexp2_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor logaddexp2(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::logaddexp2_out(result, self, other); |
| } |
| |
| Tensor& gcd_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| gcd_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor gcd(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::gcd_out(result, self, other); |
| } |
| |
| Tensor& gcd_(Tensor& self, const Tensor& other) { |
| return at::gcd_out(self, self, other); |
| } |
| |
| Tensor& lcm_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| lcm_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor lcm(const Tensor& self, const Tensor& other) { |
| Tensor result = at::empty({0}, self.options()); |
| return at::lcm_out(result, self, other); |
| } |
| |
| Tensor& lcm_(Tensor& self, const Tensor& other) { |
| return at::lcm_out(self, self, other); |
| } |
| |
| Tensor& hypot_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| hypot_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor hypot(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| hypot_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& hypot_(Tensor& self, const Tensor& other) { |
| return at::hypot_out(self, self, other); |
| } |
| |
| Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) { |
| auto iter = TensorIterator::binary_op(result, self, other); |
| nextafter_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor nextafter(const Tensor& self, const Tensor& other) { |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, other); |
| nextafter_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& nextafter_(Tensor& self, const Tensor& other) { |
| return at::nextafter_out(self, self, other); |
| } |
| |
| Tensor true_divide(const Tensor& self, Scalar divisor) { |
| return self.true_divide(wrapped_scalar_tensor(divisor)); // redispatch! |
| } |
| |
| Tensor& true_divide_(Tensor& self, Scalar divisor) { |
| return self.true_divide_(wrapped_scalar_tensor(divisor)); // redispatch! |
| } |
| |
| // Note: this function is only for testing. |
| // It is undocumented and should not be used outside of tests. |
| Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scalar alpha) { |
| return self - (other * alpha); |
| } |
| |
| Tensor& heaviside_out(Tensor& result, const Tensor& self, const Tensor& values) { |
| TORCH_CHECK(!self.is_complex() && !result.is_complex() && !values.is_complex(), |
| "heaviside is not yet implemented for complex tensors."); |
| TORCH_CHECK(self.dtype() == values.dtype() && result.dtype() == self.dtype(), |
| "heaviside is not yet implemented for tensors with different dtypes."); |
| |
| auto iter = TensorIterator::binary_op(result, self, values, /*check_mem_overlap=*/true); |
| heaviside_stub(iter.device_type(), iter); |
| return result; |
| } |
| |
| Tensor heaviside(const Tensor& self, const Tensor& values) { |
| TORCH_CHECK(!self.is_complex() && !values.is_complex(), |
| "heaviside is not yet implemented for complex tensors."); |
| TORCH_CHECK(self.dtype() == values.dtype(), |
| "heaviside is not yet implemented for tensors with different dtypes."); |
| |
| Tensor result; |
| auto iter = TensorIterator::binary_op(result, self, values); |
| heaviside_stub(iter.device_type(), iter); |
| return iter.output(); |
| } |
| |
| Tensor& heaviside_(Tensor& self, const Tensor& values) { |
| return at::heaviside_out(self, self, values); |
| } |
| |
| // TODO: Deduplicate this with the TensorIterator logic. This would |
| // also fix the TODOs below. |
| Tensor binary_op_meta(const Tensor& self, const Tensor& other) { |
| // TODO: Doesn't do type promotion correctly |
| // TODO: Doesn't do strides correctly |
| int64_t dim = std::max(self.dim(), other.dim()); |
| std::vector<int64_t> sizes(dim); |
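| // Standard right-aligned broadcasting: j = -1 - i walks the dimensions |
| // from the back; a missing dimension or a size of 1 broadcasts to the |
| // other tensor's size at that position. |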
| for (int64_t i = 0; i < dim; i++) { |
| int64_t j = -1 - i; |
| if (i >= self.dim() || self.size(j) == 1) { |
| sizes[dim + j] = other.size(j); |
| } else if (i >= other.dim() || other.size(j) == 1) { |
| sizes[dim + j] = self.size(j); |
| } else { |
| TORCH_CHECK( |
| self.size(j) == other.size(j), |
| "Expected self.size(", j, ") == other.size(", j, "), but got ", self.size(j), " != ", other.size(j) |
| ); |
| sizes[dim + j] = self.size(j); |
| } |
| } |
| return at::empty_meta(sizes, self.options()); |
| } |
| |
| Tensor binary_op_with_scalar_meta(const Tensor& self, const Tensor& other, Scalar x) { |
| return binary_op_meta(self, other); |
| } |
| |
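| // Shape-only (meta tensor) registration: the Meta backend computes output |
| // sizes without allocating real storage. |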
| TORCH_LIBRARY_IMPL(aten, Meta, m) { |
| m.impl("add.Tensor", binary_op_with_scalar_meta); |
| } |
| |
| } // namespace native |
| } // namespace at |