// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/Operators.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <utility>
namespace at::functorch {
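// Batch rule for binary pointwise ops (add, mul, atan2, ...): align the two
// (possibly batched) inputs with _binary_pointwise_helper, then call Func on
// the physical tensors. The result always has its batch dim at position 0.
// Illustrative shapes only (for exposition, not taken from this file):
//   vmap(at::add)(Tensor[B, 3], Tensor[3])
//     -> tensor_: [B, 3], other_: [3] -> (at::add(tensor_, other_), /*bdim=*/0)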
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> _binary_pointwise_batch_rule(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim,
ExtraArgs... extra_args) {
auto tensor_other = _binary_pointwise_helper(
tensor, tensor_batch_dim, other, other_batch_dim);
auto tensor_ = std::get<0>(tensor_other);
auto other_ = std::get<1>(tensor_other);
auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
return std::make_tuple(result, 0);
}
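// Adapts _binary_pointwise_batch_rule to an arbitrary ATen signature: the
// typelist peels off the two Tensor parameters and forwards any remaining
// arguments (e.g. a Scalar alpha). Rough sketch of what the
// BINARY_POINTWISE_BATCH_RULE(fn) macro below instantiates (exposition only;
// the real parameter typelist comes from c10::guts::function_traits):
//   BinaryPointwiseBatchRuleHelper<
//       decltype(&fn), &fn,
//       typelist<const Tensor&, const Tensor&, const Scalar&>>::apply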
template <typename A, A a, typename C>
struct BinaryPointwiseBatchRuleHelper;
template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
static std::tuple<Tensor,optional<int64_t>> apply(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim,
T... extra_args) {
return _binary_pointwise_batch_rule<F, Func, T...>(
tensor, tensor_batch_dim, other, other_batch_dim,
std::forward<T>(extra_args)...);
}
};
#define BINARY_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
BinaryPointwiseBatchRuleHelper<\
decltype(&fn),\
&fn,\
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
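// Variant for random binary ops (normal.Tensor_Tensor, binomial), registered
// on the FuncTorchVmapMode key below. It unwraps both inputs at the current
// vmap level, validates the requested RandomnessType, and then:
//   - Different randomness, no batched input: expand the first input by the
//     batch size so every batch element draws independent samples;
//   - Same randomness, no batched input: call Func directly;
//   - otherwise: defer to _binary_pointwise_batch_rule and re-wrap the result.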
template <typename A, A a, typename C>
struct BinaryRandomPointwiseBatchRuleHelper;
template <typename F, F Func, typename T1, typename T2, typename... T>
struct BinaryRandomPointwiseBatchRuleHelper<F, Func, typelist<T1, T2, T...>> {
static Tensor apply(const Tensor& tensor, const Tensor& other, T... extra_args) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode);
auto maybe_layer = maybeCurrentDynamicLayer();
auto cur_level = maybe_layer->layerId();
RandomnessType randomness = maybe_layer->randomness();
auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level);
auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level);
check_randomness(randomness, (tensor_bdim || other_bdim));
if (randomness == RandomnessType::Different && !tensor_bdim && !other_bdim) {
auto shape = tensor_value.sizes();
VmapSymDimVector shapeVec(1, maybe_layer->batchSize());
shapeVec.reserve(shape.size() + 1);
shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
// This case is not handled by the binary batch rule, which assumes at least one input is batched
tensor_value = tensor_value.expand_symint(shapeVec);
tensor_bdim = 0;
} else if (randomness == RandomnessType::Same && !tensor_bdim && !other_bdim) {
// Calling Func directly avoids unnecessary checks and the batch rule, which assumes the output is batched
return Func(tensor_value, other_value, std::forward<T>(extra_args)...);
}
auto res = _binary_pointwise_batch_rule<F, Func, T...>(
tensor_value, tensor_bdim, other_value, other_bdim,
std::forward<T>(extra_args)...);
return makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
}
};
#define BINARY_RANDOM_POINTWISE_BATCH_RULE(fn) SINGLE_ARG(\
BinaryRandomPointwiseBatchRuleHelper<\
decltype(&fn),\
&fn,\
c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
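// In-place counterpart: aligns both operands and calls the member pointer
// Meth (e.g. &Tensor::mul_) on the physical self. Writing a batched `other`
// into an unbatched `self` is not representable, so that case errors out.
// Illustrative (shapes for exposition only):
//   self: [B, 3] (bdim 0), other: [3]     -> ok, self_.mul_(other_)
//   self: [3],             other: [B, 3]  -> vmapIncompatibleInplaceError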
template <typename M, M Meth, typename... ExtraArgs>
void binary_pointwise_inplace_batch_rule(
Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim,
ExtraArgs... extra_args) {
if (!tensor_batch_dim && other_batch_dim) {
vmapIncompatibleInplaceError("inplace arithmetic");
}
// compute max logical rank
auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
auto other_ = moveBatchDimToFront(other, other_batch_dim);
// If the dimensions aren't aligned, we need to line them up.
// Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
// Note that only tensors that have a batch dim need to be modified.
// Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
(tensor_.*Meth)(other_, std::forward<ExtraArgs>(extra_args)...);
}
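// Comparison ops (eq, lt, ...) get a separate rule that performs the same
// rank alignment as above but without _binary_pointwise_helper's dtype
// promotion, which the boolean result does not need.
// Illustrative (shapes for exposition only):
//   vmap(at::eq)(Tensor[B, 3], Tensor[3]) -> (at::eq(tensor_, other_), /*bdim=*/0)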
template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> comparison_pointwise_batch_rule(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim) {
// compute max logical rank
auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
auto other_ = moveBatchDimToFront(other, other_batch_dim);
// If the dimensions aren't aligned, we need to line them up.
// Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
// Note that only tensors that have a batch dim need to be modified.
// Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
auto result = Func(tensor_, other_);
return std::make_tuple(std::move(result), 0);
}
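// where.self takes three tensors, so the binary helper does not apply; the
// rule moves each bdim to the front, pads all operands to a common logical
// rank, and lets at::where broadcast. Illustrative (exposition only):
//   condition: [B, 2, 3] (bdim 0), self: [3], other: 0-dim
//     -> at::where(condition_, self_, other_) with the result's bdim at 0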
static std::tuple<Tensor,optional<int64_t>> where_self_batch_rule(
const Tensor& condition, optional<int64_t> condition_bdim,
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& other, optional<int64_t> other_bdim) {
auto condition_logical_rank = rankWithoutBatchDim(condition, condition_bdim);
auto tensor_logical_rank = rankWithoutBatchDim(self, self_bdim);
auto other_logical_rank = rankWithoutBatchDim(other, other_bdim);
auto max_logical_rank = std::max({tensor_logical_rank, other_logical_rank, condition_logical_rank});
auto condition_ = moveBatchDimToFront(condition, condition_bdim);
auto self_ = moveBatchDimToFront(self, self_bdim);
auto other_ = moveBatchDimToFront(other, other_bdim);
condition_ = maybePadToLogicalRank(condition_, condition_bdim, max_logical_rank);
self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, other_bdim, max_logical_rank);
return std::make_tuple(at::where(condition_, self_, other_), 0);
}
static std::tuple<Tensor, optional<int64_t>> gelu_backward_batch_rule(
const Tensor& grad_out, optional<int64_t> grad_out_bdim, const Tensor& input, optional<int64_t> input_bdim,
c10::string_view approximate) {
// repeat the preprocessing from _binary_pointwise_batch_rule
const auto tensor_other = _binary_pointwise_helper(grad_out, grad_out_bdim, input, input_bdim);
auto grad_out_ = std::get<0>(tensor_other);
auto input_ = std::get<1>(tensor_other);
// gelu_backward doesn't broadcast well, so we need to insist that all inputs have a bdim
const auto batch_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim);
grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), batch_size);
input_ = ensure_has_bdim(input_, input_bdim.has_value(), batch_size);
return std::make_tuple(at::gelu_backward(grad_out_, input_, approximate), 0);
}
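// masked_select: only `self` may be batched; a batched mask would yield a
// differently shaped result per batch element, which vmap cannot represent.
// at::masked_select flattens its output to 1D, so the result is reshaped to
// [batch_size, -1]. Illustrative (shapes for exposition only):
//   self: [B, 4] (bdim 0), mask: [4] with k true entries -> result: [B, k]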
static std::tuple<Tensor,optional<int64_t>> masked_select_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& mask, optional<int64_t> mask_bdim) {
TORCH_CHECK(!mask_bdim.has_value(),
"vmap: Attempted to vmap over `mask` in torch.masked_select(self, mask) ",
"We cannot support this because for each batch this would return a ",
"differently shaped Tensor. "
"Please voice your support in https://github.com/pytorch/functorch/issues/256");
auto self_ = moveBatchDimToFront(self, self_bdim);
const auto batch_size = self_.size(0);
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto max_logical_rank = std::max(self_logical_rank, mask.dim());
self_ = maybePadToLogicalRank(self_, 0, max_logical_rank);
// masked_select returns a 1D tensor, so we have to reshape it into 2D
const auto result = at::masked_select(self_, mask).view({ batch_size, -1 });
return std::make_tuple(result, 0);
}
static std::tuple<Tensor,optional<int64_t>> masked_select_backward_batch_rule(
const Tensor& grad, optional<int64_t> grad_bdim,
const Tensor& self, optional<int64_t> self_bdim,
const Tensor& mask, optional<int64_t> mask_bdim) {
TORCH_CHECK(!mask_bdim.has_value(),
"vmap: Attempted to vmap over `mask` in torch.masked_select_backward(grad, self, mask) ",
"We cannot support this because for each batch this would return a ",
"differently shaped Tensor. "
"Please voice your support in https://github.com/pytorch/functorch/issues/256");
auto self_ = moveBatchDimToFront(self, self_bdim);
auto grad_ = moveBatchDimToFront(grad, grad_bdim);
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto max_logical_rank = std::max(self_logical_rank, mask.dim());
self_ = maybePadToLogicalRank(self_, self_bdim, max_logical_rank);
const auto batch_size = get_bdim_size2(grad, grad_bdim, self, self_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size);
const auto result = at::masked_select_backward(grad_, self_.contiguous(), mask);
return std::make_tuple(result, 0);
}
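// _cdist_backward: apply the same x1/x2 preprocessing as the forward pass,
// plus two workarounds (explained inline below) that force contiguous,
// explicitly batched inputs where the kernel and autograd are picky about
// shapes and strides. The output gets bdim 0 whenever x1_ or x2_ ends up batched.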
static std::tuple<Tensor,optional<int64_t>> cdist_backward_batch_rule(
const Tensor& grad, optional<int64_t> grad_bdim,
const Tensor& x1, optional<int64_t> x1_bdim,
const Tensor& x2, optional<int64_t> x2_bdim,
const double p,
const Tensor& cdist, optional<int64_t> cdist_bdim) {
auto x1_ = x1;
if (cdist_bdim && !x1_bdim) {
// We need to make sure that x1 has a batch dim if cdist has one;
// otherwise, we get:
//   RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
//   but expected shape compatible with [4, 5]
auto bs = cdist.size(*cdist_bdim);
x1_ = ensure_has_bdim(x1, false, bs);
x1_ = x1_.contiguous();
x1_bdim = 0;
}
// We need to apply the same preprocessing on x1 and x2 as in the forward pass
// _binary_pointwise_batch_rule
auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
x1_ = std::get<0>(x12);
auto x2_ = std::get<1>(x12);
auto grad_ = moveBatchDimToFront(grad, grad_bdim);
if ((x1_bdim || x2_bdim) && !grad_bdim) {
// We need to make sure that grad has batch dim if x1 or x2 have one
// There is probably an assumption on the strides here.
// Otherwise the grad input contains garbage values, e.g. -7.0816e+29, 7.0816e+29
auto bs = get_bdim_size2(x1_, 0, x2_, 0);
grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
grad_ = grad_.contiguous();
}
auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);
optional<int64_t> out_bdim = nullopt;
if (x1_bdim || x2_bdim) {
out_bdim = 0;
}
return std::make_tuple(out, out_bdim);
}
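// fill_.Tensor: with an unbatched `other`, plain fill_ works and is faster;
// with a batched `other`, both sides are aligned via _binary_pointwise_helper
// (no type promotion) and written with copy_. Illustrative (exposition only):
//   self: [B, 3] (bdim 0), other: [B] (bdim 0)
//     -> other_ is padded to [B, 1] and copy_ broadcasts it into self_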
static void fill__Tensor_batch_rule(
Tensor& self,
optional<int64_t> self_bdim,
const Tensor& other,
optional<int64_t> other_bdim) {
if (!other_bdim.has_value()) {
// Optimization: fill_ is faster than the other path which does
// reshaping + copy_
self.fill_(other);
return;
}
if (!self_bdim && other_bdim) {
vmapIncompatibleInplaceError("fill_");
}
auto self_and_other = _binary_pointwise_helper(
self, self_bdim, other, other_bdim, /*do_type_promotion*/false);
std::get<0>(self_and_other).copy_(std::get<1>(self_and_other));
}
static std::tuple<Tensor, optional<int64_t>> log_sigmoid_backward_batch_rule(
Tensor& grad, optional<int64_t> grad_bdim,
Tensor& self, optional<int64_t> self_bdim,
Tensor& buffer, optional<int64_t> buffer_bdim) {
// NB: This emulates handle_pointwise_ops, except that we ignore the last
// argument (buffer) when any of the inputs are on CUDA.
// We do this because on CUDA, buffer is a dummy tensor that always has
// logical rank 1, which becomes an issue when the rest of the inputs are scalars.
int64_t out_logical_rank = std::max(rankWithoutBatchDim(grad, grad_bdim), rankWithoutBatchDim(self, self_bdim));
if (!grad.is_cuda() && !self.is_cuda() && !buffer.is_cuda()) {
out_logical_rank = std::max(out_logical_rank, rankWithoutBatchDim(buffer, buffer_bdim));
}
Tensor out_grad = maybePadToLogicalRank(moveBatchDimToFront(grad, grad_bdim), grad_bdim, out_logical_rank);
Tensor out_self = maybePadToLogicalRank(moveBatchDimToFront(self, self_bdim), self_bdim, out_logical_rank);
Tensor out_buffer = maybePadToLogicalRank(moveBatchDimToFront(buffer, buffer_bdim), buffer_bdim, out_logical_rank);
return std::make_tuple(at::log_sigmoid_backward(out_grad, out_self, out_buffer), 0);
}
static Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
return at::binomial(count, prob.contiguous(), std::move(gen)); // Bug in PyTorch, prob shouldn't need to be contiguous
}
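// Random binary ops are registered on the FuncTorchVmapMode key (rather than
// FuncTorchBatched) so the randomness handling above also runs when neither
// input is batched. Rough macro expansion (exposition only):
//   BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor)
//     -> m.impl("normal.Tensor_Tensor",
//               BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(normal, Tensor_Tensor)));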
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
#define BINARY_RANDOM_POINTWISE(op) \
m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
#define BINARY_RANDOM_POINTWISE2(op, overload) \
m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
}
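// Everything below registers batch rules on the FuncTorchBatched key. Rough
// macro expansions (exposition only; VMAP_SUPPORT/VMAP_SUPPORT2 come from the
// helper headers included above):
//   BINARY_POINTWISE(atan2)
//     -> VMAP_SUPPORT(atan2, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(atan2)));
//   BINARY_SCALAR_2(add, Tensor, Scalar)
//     -> BINARY_POINTWISE2(add, Tensor); UNARY_POINTWISE2(add, Scalar);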
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
#define BINARY_POINTWISE2(op, overload) \
VMAP_SUPPORT2(op, overload, BINARY_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
#define BINARY_POINTWISE(op) \
VMAP_SUPPORT(op, BINARY_POINTWISE_BATCH_RULE(ATEN_FN(op)));
#define UNARY_POINTWISE2(op, overload) \
VMAP_SUPPORT2(op, overload, BASIC_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
#define UNARY_POINTWISE(op) \
VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
#define UNARY_SCALAR_POINTWISE2(op, overload) \
VMAP_SUPPORT2(op, overload, SCALAR_UNARY_BATCH_RULE(ATEN_FN2(op, overload)));
#define BINARY_SCALAR_2(op, tensor_tensor, tensor_scalar) \
BINARY_POINTWISE2(op, tensor_tensor);\
UNARY_POINTWISE2(op, tensor_scalar);
// For all 3 combinations of Tensor x Tensor, Tensor x Scalar, Scalar x Tensor
#define BINARY_SCALAR_3(op, tensor_tensor, tensor_scalar, scalar_tensor) \
BINARY_POINTWISE2(op, tensor_tensor);\
UNARY_POINTWISE2(op, tensor_scalar);\
POINTWISE_BOXED(op.scalar_tensor);
#define BINARY_SCALAR_3_Tensor(op, tensor_scalar, scalar_tensor) \
BINARY_POINTWISE(op);\
UNARY_POINTWISE2(op, tensor_scalar);\
POINTWISE_BOXED(op.scalar_tensor);
// Batching rule registrations start
POINTWISE_BOXED(__ilshift__.Tensor);
POINTWISE_BOXED(__ilshift__.Scalar);
POINTWISE_BOXED(__irshift__.Tensor)
POINTWISE_BOXED(__irshift__.Scalar)
BINARY_SCALAR_2(__lshift__, Tensor, Scalar);
BINARY_SCALAR_2(__rshift__, Tensor, Scalar);
BINARY_SCALAR_2(add, Tensor, Scalar);
POINTWISE_BOXED(addcdiv);
POINTWISE_BOXED(addcmul);
BINARY_POINTWISE(atan2);
BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
BINARY_POINTWISE2(bitwise_and_, Tensor);
POINTWISE_BOXED(bitwise_and_.Scalar);
POINTWISE_BOXED(bitwise_and.Scalar_Tensor);
BINARY_SCALAR_2(bitwise_or, Tensor, Scalar);
BINARY_POINTWISE2(bitwise_or_, Tensor);
POINTWISE_BOXED(bitwise_or_.Scalar);
POINTWISE_BOXED(bitwise_or.Scalar_Tensor);
BINARY_SCALAR_2(bitwise_xor, Tensor, Scalar);
BINARY_POINTWISE2(bitwise_xor_, Tensor);
POINTWISE_BOXED(bitwise_xor_.Scalar);
POINTWISE_BOXED(bitwise_xor.Scalar_Tensor);
BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
POINTWISE_BOXED(bitwise_left_shift_.Tensor_Scalar);
POINTWISE_BOXED(bitwise_left_shift_.Tensor);
BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
POINTWISE_BOXED(bitwise_right_shift_.Tensor_Scalar);
POINTWISE_BOXED(bitwise_right_shift_.Tensor);
UNARY_POINTWISE(clamp);
POINTWISE_BOXED(clamp.Tensor);
BINARY_POINTWISE2(clamp_min, Tensor);
UNARY_POINTWISE(clamp_min);
POINTWISE_BOXED(clamp_min_);
BINARY_POINTWISE2(clamp_max, Tensor);
UNARY_POINTWISE(clamp_max);
POINTWISE_BOXED(clamp_max_);
BINARY_POINTWISE(complex);
VARIADIC_BDIMS_BOXED(_euclidean_dist);
// Implementation note: _binary_pointwise_helper performs a dtype promotion if args are scalars,
// but cdist can't work with scalars; it requires at least 2d tensors.
BINARY_POINTWISE(_cdist_forward);
VMAP_SUPPORT(_cdist_backward, cdist_backward_batch_rule);
BINARY_SCALAR_2(copysign, Tensor, Scalar);
POINTWISE_BOXED(copysign_.Tensor);
POINTWISE_BOXED(copysign_.Scalar);
BINARY_SCALAR_2(div, Tensor, Scalar);
BINARY_SCALAR_2(div, Tensor_mode, Scalar_mode);
BINARY_POINTWISE(floor_divide);
UNARY_POINTWISE2(floor_divide, Scalar);
BINARY_POINTWISE(fmax);
BINARY_POINTWISE(fmin);
BINARY_SCALAR_2(fmod, Tensor, Scalar);
POINTWISE_BOXED(frexp.Tensor);
BINARY_POINTWISE(heaviside);
BINARY_POINTWISE(hypot);
BINARY_POINTWISE(gcd);
BINARY_POINTWISE(igamma);
BINARY_POINTWISE(igammac);
BINARY_POINTWISE(logaddexp);
BINARY_POINTWISE(logaddexp2);
POINTWISE_BOXED(lerp.Scalar);
POINTWISE_BOXED(lerp.Tensor);
BINARY_POINTWISE(lcm);
POINTWISE_BOXED(log_sigmoid_forward);
BINARY_POINTWISE(maximum);
BINARY_POINTWISE(minimum);
BINARY_SCALAR_2(mul, Tensor, Scalar);
BINARY_POINTWISE(nextafter);
BINARY_SCALAR_3(pow, Tensor_Tensor, Tensor_Scalar, Scalar);
POINTWISE_BOXED2(pow_, Scalar);
BINARY_POINTWISE(polar);
POINTWISE_BOXED(polygamma);
BINARY_SCALAR_2(sub, Tensor, Scalar);
BINARY_SCALAR_3(remainder, Tensor, Scalar, Scalar_Tensor);
BINARY_POINTWISE(rrelu_with_noise);
BINARY_SCALAR_2(rsub, Tensor, Scalar);
BINARY_SCALAR_3_Tensor(special_xlog1py, other_scalar, self_scalar);
BINARY_SCALAR_3_Tensor(special_zeta, other_scalar, self_scalar);
VMAP_SUPPORT2(where, self, where_self_batch_rule);
BINARY_SCALAR_3(xlogy, Tensor, Scalar_Other, Scalar_Self);
POINTWISE_BOXED(elu_backward);
BINARY_POINTWISE(hardsigmoid_backward);
BINARY_POINTWISE(hardtanh_backward);
BINARY_POINTWISE(hardshrink_backward);
BINARY_POINTWISE(hardswish_backward);
BINARY_POINTWISE(_prelu_kernel);
VARIADIC_BDIMS_BOXED(_prelu_kernel_backward);
BINARY_POINTWISE(leaky_relu_backward);
BINARY_POINTWISE(logit_backward);
VMAP_SUPPORT(log_sigmoid_backward, log_sigmoid_backward_batch_rule);
VMAP_SUPPORT(gelu_backward, gelu_backward_batch_rule);
BINARY_POINTWISE(sigmoid_backward);
POINTWISE_BOXED(softplus_backward);
BINARY_POINTWISE(softshrink_backward);
BINARY_POINTWISE(tanh_backward);
BINARY_POINTWISE(threshold_backward);
BINARY_POINTWISE(silu_backward);
using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
using TensorInplaceModeT = Tensor& (Tensor::*)(const Tensor&, c10::optional<c10::string_view>) const;
using ScalarInplaceT = Tensor& (Tensor::*)(const Scalar&) const;
using CopyT = Tensor& (Tensor::*)(const Tensor&, bool) const;
POINTWISE_BOXED(add_.Tensor); // just testing
POINTWISE_BOXED(atan2_);
POINTWISE_BOXED(gcd_);
POINTWISE_BOXED(lcm_);
VMAP_SUPPORT2(add_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::add_, const Scalar&, const Scalar&>));
VMAP_SUPPORT2(sub_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::sub_, const Scalar&>));
VMAP_SUPPORT2(sub_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarScalarInplaceT, &Tensor::sub_, const Scalar&, const Scalar&>));
VMAP_SUPPORT2(mul_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::mul_>));
VMAP_SUPPORT2(mul_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::mul_, const Scalar&>));
VMAP_SUPPORT2(div_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::div_>));
VMAP_SUPPORT2(div_, Tensor_mode, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceModeT, &Tensor::div_, c10::optional<c10::string_view>>));
VMAP_SUPPORT2(div_, Scalar, SINGLE_ARG(unary_inplace_batch_rule<ScalarInplaceT, &Tensor::div_, const Scalar&>));
VMAP_SUPPORT2(clamp_min_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_min_>));
VMAP_SUPPORT2(clamp_max_, Tensor, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor::clamp_max_>));
VMAP_SUPPORT2(masked_fill_, Scalar, SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorScalarInplaceT, &Tensor::masked_fill_, const Scalar&>));
VMAP_SUPPORT(copy_, SINGLE_ARG(binary_pointwise_inplace_batch_rule<CopyT, &Tensor::copy_, bool>));
#define COMPARISON_POINTWISE(op) \
VMAP_SUPPORT2(op, Tensor, \
SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>)); \
UNARY_POINTWISE2(op, Scalar)
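// Rough expansion (exposition only):
//   COMPARISON_POINTWISE(eq)
//     -> VMAP_SUPPORT2(eq, Tensor,
//          comparison_pointwise_batch_rule<decltype(&ATEN_FN2(eq, Tensor)), &at::eq>);
//        UNARY_POINTWISE2(eq, Scalar);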
COMPARISON_POINTWISE(eq);
COMPARISON_POINTWISE(gt);
COMPARISON_POINTWISE(ge);
COMPARISON_POINTWISE(le);
COMPARISON_POINTWISE(lt);
COMPARISON_POINTWISE(ne);
#undef COMPARISON_POINTWISE
#undef BINARY_POINTWISE2
#undef BINARY_POINTWISE
#undef UNARY_POINTWISE2
#undef UNARY_POINTWISE
#undef UNARY_SCALAR_POINTWISE2
#undef BINARY_SCALAR_3
#define LOGICAL_COMPARISON_POINTWISE(op) \
VMAP_SUPPORT(op, \
SINGLE_ARG(comparison_pointwise_batch_rule<decltype(&ATEN_FN(op)), &ATEN_FN(op)>)); \
VMAP_SUPPORT(op ## _, \
SINGLE_ARG(binary_pointwise_inplace_batch_rule<TensorInplaceT, &Tensor:: op ## _ >));
LOGICAL_COMPARISON_POINTWISE(logical_and);
LOGICAL_COMPARISON_POINTWISE(logical_or);
LOGICAL_COMPARISON_POINTWISE(logical_xor);
#undef SINGLE_ARG
#undef LOGICAL_COMPARISON_POINTWISE
VMAP_SUPPORT(masked_select, masked_select_batch_rule);
VMAP_SUPPORT(masked_select_backward, masked_select_backward_batch_rule);
VMAP_SUPPORT2(fill_, Tensor, fill__Tensor_batch_rule);
}
} // namespace at::functorch