[functorch] Finished most of the remaining pad batching rules (also added an existing_bdim_batch_rule template)
diff --git a/functorch/functorch/csrc/BatchRulesHelper.h b/functorch/functorch/csrc/BatchRulesHelper.h
index 1ae69cf..dd6c152 100644
--- a/functorch/functorch/csrc/BatchRulesHelper.h
+++ b/functorch/functorch/csrc/BatchRulesHelper.h
@@ -17,6 +17,8 @@
#include <functorch/csrc/Constants.h>
namespace at { namespace functorch {
+Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
+Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
@@ -44,6 +46,14 @@
return std::make_tuple(Func(self_, std::forward<ExtraArgs>(extra_args)...), self_bdim.has_value() ? optional<int64_t>{0} : nullopt);
}
+template <typename F, F Func, typename... ExtraArgs>
+std::tuple<Tensor,optional<int64_t>> existing_bdim_batch_rule(const Tensor& self, optional<int64_t> self_bdim, ExtraArgs... extra_args) {
+ auto self_ = reshape_dim_into(*self_bdim, 0, self);
+ auto out = Func(self_, std::forward<ExtraArgs>(extra_args)...);
+ return std::make_tuple(reshape_dim_outof(0, self.sizes()[*self_bdim], out), 0);
+}
+
+
#define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
@@ -55,8 +65,5 @@
return self;
}
-Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
-Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
-
}}
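The new existing_bdim_batch_rule template covers operators whose first logical dimension is already batch-like, e.g. the N in reflection_pad1d's (N, C, W) input: it folds the vmap batch dim into that dimension with reshape_dim_into, runs the op once, then splits the batch dim back out with reshape_dim_outof and reports it at dim 0. A rough Python sketch of the same trick using plain tensor ops (the sketch function name is made up for illustration):

```python
import torch
import torch.nn.functional as F

def existing_bdim_batch_rule_sketch(op, x, bdim, *extra_args):
    # Fold the vmap batch dim into the op's existing leading dim
    # (the Python analogue of reshape_dim_into(bdim, 0, x)).
    bdim_size = x.shape[bdim]
    x = x.movedim(bdim, 0)
    x = x.reshape(bdim_size * x.shape[1], *x.shape[2:])
    out = op(x, *extra_args)          # run the op once on the flattened batch
    # Split the batch dim back out at the front
    # (the analogue of reshape_dim_outof(0, bdim_size, out)).
    out = out.reshape(bdim_size, -1, *out.shape[1:])
    return out, 0

# e.g. a (B, N, C, W) input to a batched reflection_pad1d:
x = torch.randn(7, 2, 3, 8)
out, out_bdim = existing_bdim_batch_rule_sketch(
    lambda t, pad: F.pad(t, pad, mode='reflect'), x, 0, (2, 2))
print(out.shape, out_bdim)            # torch.Size([7, 2, 3, 12]) 0
```

This works because the pad kernels never mix values across that leading dimension, and they only accept fixed ranks, so the extra vmap dim has to be folded into the existing leading dim rather than simply moved to the front.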
diff --git a/functorch/functorch/csrc/BatchRulesModules.cpp b/functorch/functorch/csrc/BatchRulesModules.cpp
index 8663b04..3dce057 100644
--- a/functorch/functorch/csrc/BatchRulesModules.cpp
+++ b/functorch/functorch/csrc/BatchRulesModules.cpp
@@ -211,5 +211,10 @@
m.impl("cudnn_convolution", cudnn_convolution_plumbing);
OP_DECOMPOSE(dropout);
VMAP_SUPPORT("constant_pad_nd", SINGLE_ARG(basic_unary_batch_rule<decltype(&at::constant_pad_nd), &at::constant_pad_nd, IntArrayRef, const Scalar&>));
+ VMAP_SUPPORT("reflection_pad1d", SINGLE_ARG(existing_bdim_batch_rule<decltype(&at::reflection_pad1d), &at::reflection_pad1d, IntArrayRef>));
+ VMAP_SUPPORT("reflection_pad2d", SINGLE_ARG(existing_bdim_batch_rule<decltype(&at::reflection_pad2d), &at::reflection_pad2d, IntArrayRef>));
+ VMAP_SUPPORT("replication_pad1d", SINGLE_ARG(existing_bdim_batch_rule<decltype(&at::replication_pad1d), &at::replication_pad1d, IntArrayRef>));
+ VMAP_SUPPORT("replication_pad2d", SINGLE_ARG(existing_bdim_batch_rule<decltype(&at::replication_pad2d), &at::replication_pad2d, IntArrayRef>));
+ VMAP_SUPPORT("replication_pad3d", SINGLE_ARG(existing_bdim_batch_rule<decltype(&at::replication_pad3d), &at::replication_pad3d, IntArrayRef>));
}
}}
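With these registrations, vmap over the reflection/replication pads dispatches to existing_bdim_batch_rule instead of the slow fallback path. At the user level that looks like the following (assuming the functorch vmap frontend; F.pad with mode='reflect' and a 4-element pad on a 3D per-example input lowers to reflection_pad2d):

```python
import torch
import torch.nn.functional as F
from functorch import vmap

# Batch reflection_pad2d over a stack of (C, H, W) images; each per-example
# call sees a 3D tensor, and the batch rule folds the vmap dim into its
# leading (channel) dimension.
imgs = torch.randn(8, 3, 16, 16)
padded = vmap(lambda img: F.pad(img, (1, 1, 2, 2), mode='reflect'))(imgs)
print(padded.shape)   # torch.Size([8, 3, 20, 18])
```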
diff --git a/functorch/test/functorch_additional_op_db.py b/functorch/test/functorch_additional_op_db.py
index 7fbcb37..0f51c0b 100644
--- a/functorch/test/functorch_additional_op_db.py
+++ b/functorch/test/functorch_additional_op_db.py
@@ -349,12 +349,20 @@
])
def sample_inputs_pad(self, device, dtype, requires_grad, mode='constant', value=None):
-    inp = make_tensor((2, 3, 4, 5), device=device, dtype=dtype,
-                      requires_grad=requires_grad, low=-1, high=1)
+    valid_lengths = {
+        'constant': [1, 2, 3, 4, 5],
+        'reflect': [3, 4],
+        'replicate': [3, 4, 5],
+        'circular': [3, 4, 5],
+    }
    sample_inputs = []
-    for pad in [(1, 1, 1, 1), (1, 2, 3, 0)]:
-        args = (pad, mode) if value is None else (pad, mode, value)
-        sample_inputs.append(SampleInput(inp, args=args))
+    for length in valid_lengths[mode]:
+        inp = make_tensor(list(range(2, length + 2)), device=device, dtype=dtype,
+                          requires_grad=requires_grad, low=-1, high=1)
+        num_pad_dims = (length - 2) * 2
+        for pad in [[1] * num_pad_dims, [1, 2] * (num_pad_dims // 2)]:
+            args = (pad, mode) if value is None else (pad, mode, value)
+            sample_inputs.append(SampleInput(inp, args=args))
    return sample_inputs
for mode in ['constant', 'reflect', 'replicate', 'circular']:
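The reworked sample_inputs_pad builds one input per valid rank for each mode instead of a single fixed (2, 3, 4, 5) tensor with hard-coded 4-element pads, since reflect only supports 1d/2d padding while replicate and circular go up to 3d. A standalone sketch of what it yields for the two modes gaining batch rules in this commit, checked against F.pad:

```python
import torch
import torch.nn.functional as F

# Mirrors the generator above for the reflect/replicate modes and checks that
# every (shape, pad) pair it yields is accepted by F.pad.
valid_lengths = {'reflect': [3, 4], 'replicate': [3, 4, 5]}
for mode, lengths in valid_lengths.items():
    for length in lengths:
        shape = list(range(2, length + 2))     # e.g. length 4 -> [2, 3, 4, 5]
        inp = torch.rand(shape)
        num_pad_dims = (length - 2) * 2        # pad only the trailing spatial dims
        for pad in [[1] * num_pad_dims, [1, 2] * (num_pad_dims // 2)]:
            out = F.pad(inp, pad, mode=mode)
            print(mode, shape, pad, tuple(out.shape))
```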