blob: caf6bc93263f6c8cfa2b8e4cd22ff42a7b0e55ad [file] [log] [blame]
#pragma once
#include <ATen/ATen.h>
#include <stdexcept>
#include <string>
#include <vector>
namespace at {
namespace native {
/*
[NativeFunction]
name: split
arg: Tensor self
arg: int64_t split_size
arg: int64_t dim
return: TensorList
variants: method, function
type_method_definition_level: base
type_method_definition_dispatch: at::native::split
[/NativeFunction]
*/
// Splits `self` along dimension `dim` into consecutive pieces of `split_size`
// elements, each produced by `self.narrow(dim, offset, length)`. The final
// piece holds the remaining elements, so it is shorter than `split_size` when
// the dimension size is not a multiple of it.
//
// Returns an empty vector when the size of `dim` is 0.
// Throws std::runtime_error if split_size < 0, or if split_size == 0 while the
// dimension is non-empty (previously an integer division by zero).
static inline std::vector<Tensor> split(const Tensor &self, int64_t split_size, int64_t dim) {
  if (split_size < 0) {
    // A negative size used to yield a negative piece count, which was then
    // handed to std::vector's size constructor.
    throw std::runtime_error(
        "split expects split_size to be non-negative, but got split_size=" +
        std::to_string(split_size));
  }
  int64_t dim_size = self.size(dim);
  if (split_size == 0) {
    // Zero-width pieces can only tile an empty dimension.
    if (dim_size != 0) {
      throw std::runtime_error(
          "split_size can only be 0 if dimension size is 0, "
          "but got dimension size of " + std::to_string(dim_size));
    }
    // Same result as a positive split_size on an empty dimension: no pieces.
    return std::vector<Tensor>();
  }
  // Ceiling division written so it cannot overflow; the former
  // `dim_size + split_size - 1` overflowed for very large split_size.
  int64_t num_splits = dim_size / split_size + (dim_size % split_size != 0 ? 1 : 0);
  std::vector<Tensor> splits(num_splits);
  // Elements left for the final piece (equals split_size on an even split).
  int64_t last_split_size = dim_size - split_size * (num_splits - 1);
  for (int64_t i = 0; i < num_splits; ++i) {
    int64_t length = i < num_splits - 1 ? split_size : last_split_size;
    splits[i] = self.narrow(dim, i * split_size, length);
  }
  return splits;
}
/*
[NativeFunction]
name: chunk
arg: Tensor self
arg: int64_t chunks
arg: int64_t dim
return: TensorList
variants: method, function
type_method_definition_level: base
type_method_definition_dispatch: at::native::chunk
[/NativeFunction]
*/
// Splits `self` along dimension `dim` into at most `chunks` pieces of
// ceil(dim_size / chunks) elements each, delegating to at::native::split.
// The last piece may be smaller. Fewer than `chunks` pieces are returned when
// the dimension is too small to fill them (e.g. size 5 with chunks=4 gives 3).
//
// Throws std::runtime_error if chunks <= 0 (previously chunks == 0 was an
// integer division by zero).
static inline std::vector<Tensor> chunk(const Tensor &self, int64_t chunks, int64_t dim) {
  if (chunks <= 0) {
    throw std::runtime_error(
        "chunk expects chunks to be greater than 0, but got chunks=" +
        std::to_string(chunks));
  }
  int64_t dim_size = self.size(dim);
  // Ceiling division that cannot overflow, unlike `dim_size + chunks - 1`.
  int64_t split_size = dim_size / chunks + (dim_size % chunks != 0 ? 1 : 0);
  // An empty dimension gives split_size == 0; forward a positive size instead
  // so split() receives a valid argument and yields zero pieces.
  if (split_size == 0) {
    split_size = 1;
  }
  return at::native::split(self, split_size, dim);
}
}
}