blob: 6e075abc189ef01af0a7bc0def6a99d134761eef [file] [log] [blame]
#include <ATen/ATen.h>
namespace at { namespace native {

// Pads the trailing dimensions of `self` with the constant `value`.
//
// `pad` lists (before, after) amounts starting from the LAST dimension and
// moving towards the front:
//   pad = {before_{n-1}, after_{n-1}, before_{n-2}, after_{n-2}, ...}
// Its length must be even and at most 2 * self.dim(); dimensions not covered
// by `pad` are left unchanged. Negative amounts crop the input (via narrow())
// instead of padding it.
//
// If no entry of `pad` is positive, the result is a copy of the (possibly
// cropped) input. Otherwise a new tensor filled with `value` is allocated and
// the cropped input is copied into its interior. In both cases the result
// never aliases `self`.
Tensor constant_pad_nd(const Tensor& self, IntArrayRef pad, Scalar value) {
  TORCH_CHECK(pad.size() % 2 == 0, "Length of pad must be even but instead it equals ",
              pad.size());

  const auto input_sizes = self.sizes();
  const int64_t l_inp = self.dim();
  const int64_t l_pad = static_cast<int64_t>(pad.size() / 2);
  TORCH_CHECK(l_inp >= l_pad, "Length of pad should be no more than twice the number of "
              "dimensions of the input. Pad length is ", pad.size(), " while the input has ",
              l_inp, " dimensions.");
  // Number of leading dimensions that `pad` does not touch.
  const int64_t l_diff = l_inp - l_pad;

  // Index into `pad` of the "before" amount for dimension `dim`; the matching
  // "after" amount is at the next index. Pads are listed from the last dim.
  const auto pad_index = [l_inp](int64_t dim) {
    return static_cast<size_t>(2 * (l_inp - dim - 1));
  };

  // Apply all negative pads as crops first, remembering whether any real
  // (positive) padding is requested at all.
  bool all_pads_non_positive = true;
  auto c_input = self;
  for (int64_t dim = l_diff; dim < l_inp; ++dim) {
    const size_t pad_idx = pad_index(dim);
    const int64_t pad_before = pad[pad_idx];
    const int64_t pad_after = pad[pad_idx + 1];
    if (pad_before < 0) {
      c_input = c_input.narrow(dim, -pad_before, c_input.size(dim) + pad_before);
    } else if (pad_before != 0) {
      all_pads_non_positive = false;
    }
    if (pad_after < 0) {
      c_input = c_input.narrow(dim, 0, c_input.size(dim) + pad_after);
    } else if (pad_after != 0) {
      all_pads_non_positive = false;
    }
  }

  // Nothing to pad: the cropped input is already the answer. Clone it so the
  // result never aliases `self`, which matches the allocating path below (and
  // avoids returning `self` itself when every pad is zero).
  if (all_pads_non_positive) {
    return c_input.clone();
  }

  // Output shape: untouched leading dims, then each padded dim grown or shrunk
  // by its (before + after) amounts.
  std::vector<int64_t> new_shape(input_sizes.begin(), input_sizes.begin() + l_diff);
  new_shape.reserve(static_cast<size_t>(l_inp));
  for (int64_t dim = l_diff; dim < l_inp; ++dim) {
    const size_t pad_idx = pad_index(dim);
    const int64_t pad_before = pad[pad_idx];
    const int64_t pad_after = pad[pad_idx + 1];
    const int64_t new_dim = input_sizes[dim] + pad_before + pad_after;
    // Note: a zero-sized result dimension is rejected here as well.
    TORCH_CHECK(new_dim > 0, "The input size ", input_sizes[dim], ", plus negative padding ",
                pad_before, " and ", pad_after, " resulted in a negative output size, "
                "which is invalid. Check dimension ", dim, " of your input.");
    new_shape.emplace_back(new_dim);
  }

  auto output = at::empty(new_shape, self.options());
  output.fill_(value);

  // Narrow a view of `output` down to the interior region (skipping positive
  // padding on both sides) and copy the cropped input into it.
  auto c_output = output;
  for (int64_t dim = l_diff; dim < l_inp; ++dim) {
    const size_t pad_idx = pad_index(dim);
    const int64_t pad_before = pad[pad_idx];
    const int64_t pad_after = pad[pad_idx + 1];
    if (pad_before > 0) {
      c_output = c_output.narrow(dim, pad_before, c_output.size(dim) - pad_before);
    }
    if (pad_after > 0) {
      c_output = c_output.narrow(dim, 0, c_output.size(dim) - pad_after);
    }
  }
  c_output.copy_(c_input);
  return output;
}

}} // namespace at::native