#include <ATen/ATen.h>
#include <ATen/native/Repeat.h>
#include <ATen/Parallel.h>
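
// Expand runs of indices from a prefix-sum layout: for each i, write the value
// i into result_ptr[cumsum_ptr[i] - repeat_ptr[i], cumsum_ptr[i]). The outer
// loop is parallelized over i; the runs never overlap, so the writes need no
// synchronization.
//
// Illustrative walk-through (hypothetical values, not from this file):
//   repeat = [2, 1, 3]  ->  cumsum = [2, 3, 6]
//   i = 0 fills result[0..1] with 0
//   i = 1 fills result[2..2] with 1
//   i = 2 fills result[3..5] with 2
//   result = [0, 0, 1, 2, 2, 2]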
static void compute_cpu(int64_t *repeat_ptr, int64_t *cumsum_ptr, int64_t *result_ptr, int64_t size) {
  at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) {
    for (int64_t i = i_begin; i < i_end; i++) {
      int64_t end = cumsum_ptr[i];
      int64_t len = repeat_ptr[i]; // length of the run for index i (renamed from `size` to avoid shadowing the parameter)
      int64_t start = end - len;
      for (int64_t j = start; j < end; j++) {
        result_ptr[j] = i;
      }
    }
  });
}

namespace at { namespace native {
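
// repeat_interleave_common (from ATen/native/Repeat.h) is expected to validate
// `repeat`, take its cumulative sum to size the output, and run the supplied
// kernel (compute_cpu here) to produce the flat index tensor.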
Tensor repeat_interleave_cpu(const Tensor &repeat) {
  return repeat_interleave_common<compute_cpu>(repeat);
}

Tensor repeat_interleave(const Tensor &self, const Tensor &repeats, c10::optional<int64_t> dim) {
  Tensor input = self;
  // Without an explicit dim, operate on the flattened input.
  if (!dim) {
    input = self.flatten();
    dim = 0;
  }

  Tensor repeats_ = repeats;
  if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) {
    // A scalar repeat count is broadcast to every position along dim.
    repeats_ = repeats.reshape({1}).expand({input.size(dim.value())});
  } else if (repeats.dim() == 1) {
    TORCH_CHECK(
        repeats.size(0) == input.size(dim.value()),
        "repeats must have the same size as input along dim");
  } else {
    AT_ERROR("repeats must be 0-dim or 1-dim tensor");
  }

  // Build the index tensor [0, 0, ..., 1, ...] and gather along dim.
  return input.index_select(dim.value(), at::repeat_interleave(repeats_));
}
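
// Usage sketch (illustrative, not from this file; assumes the at::repeat_interleave
// bindings generated for these overloads):
//   auto t = at::tensor({10, 20, 30});
//   at::repeat_interleave(t, at::tensor({1, 0, 2}), 0);  // -> [10, 30, 30]
//   at::repeat_interleave(t, 2);                         // -> [10, 10, 20, 20, 30, 30]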
Tensor repeat_interleave(const Tensor &self, int64_t repeats, c10::optional<int64_t> dim) {
  // Wrap the scalar count in a 1-element Long tensor and reuse the Tensor overload.
  return at::native::repeat_interleave(self, at::tensor({repeats}, self.options().dtype(kLong)), dim);
}

}} // namespace at::native