#include <ATen/native/SegmentReduce.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <c10/util/irange.h>
namespace at {
namespace native {
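// Stubs for the forward and backward kernels. The CPU implementations are
// registered at the bottom of this file via the REGISTER_*_DISPATCH macros;
// other backends register their kernels elsewhere.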
DEFINE_DISPATCH(_segment_reduce_stub);
DEFINE_DISPATCH(_segment_reduce_backward_stub);
namespace {
SegmentReductionType get_reduction_enum(const c10::string_view& reduce) {
if (reduce == "max") {
return SegmentReductionType::MAX;
} else if (reduce == "mean") {
return SegmentReductionType::MEAN;
} else if (reduce == "min") {
return SegmentReductionType::MIN;
} else if (reduce == "sum") {
return SegmentReductionType::SUM;
} else {
TORCH_CHECK(false, "unsupported reduction given! ", reduce);
}
}
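// Forward CPU kernel. For each of the `segment_count` segments (whose sizes
// are given by `lengths_data`) and each of the
// `stride_count = data.numel() / data.size(axis)` inner positions, it
//   1. picks a starting value (the user-provided `initial` if present,
//      otherwise -inf for max, +inf for min, 0 for sum/mean),
//   2. folds every element of the segment into it (NaNs propagate for
//      min/max), and
//   3. finalizes the result (mean divides by the segment length; an empty
//      segment with no `initial` produces NaN for mean).
// Worked example (sum): data = [1, 2, 3, 4], lengths = [2, 2]
//   -> output = [1 + 2, 3 + 4] = [3, 7].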
template <typename T>
void _segment_reduce_cpu_kernel1(
SegmentReductionType reduction,
const Tensor& data,
const T* lengths_data,
int64_t axis,
const c10::optional<Scalar>& initial,
Tensor& output,
int64_t segment_count) {
int64_t stride_count = data.numel() / data.size(axis);
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
auto* output_data = output.data_ptr<scalar_t>();
const auto* values_data = data.data_ptr<scalar_t>();
int64_t lengths_cum_sum = 0;
for (const auto i : c10::irange(segment_count)) {
for (const auto l : c10::irange(stride_count)) {
// ===== step1: initialize starting value
scalar_t initial_value;
if (initial.has_value()) {
initial_value = initial.value().to<scalar_t>();
} else if (reduction == SegmentReductionType::MAX) {
initial_value = -std::numeric_limits<scalar_t>::infinity();
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = 0;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = std::numeric_limits<scalar_t>::infinity();
}
// ===== step2: apply reduction
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
const auto val = values_data[starting_index];
// TODO: There is no need to branch with every element
if (reduction == SegmentReductionType::MAX) {
initial_value = at::_isnan(val)
? val
: std::max<scalar_t>(initial_value, val);
} else if (
reduction == SegmentReductionType::MEAN ||
reduction == SegmentReductionType::SUM) {
initial_value = initial_value + val;
} else if (reduction == SegmentReductionType::MIN) {
initial_value = at::_isnan(val)
? val
: std::min<scalar_t>(initial_value, val);
}
}
// ===== step3: finalize reduction
TORCH_CHECK(lengths_data[i] >= 0);
if (lengths_data[i] == 0 && !initial.has_value() &&
reduction == SegmentReductionType::MEAN) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN &&
lengths_data[i] > 0 && !at::_isnan(initial_value)) {
initial_value = initial_value / lengths_data[i];
}
int64_t output_index = (i * stride_count) + l;
output_data[output_index] = initial_value;
}
lengths_cum_sum += lengths_data[i];
}
});
}
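// Forward entry point registered with _segment_reduce_stub below. The output
// keeps the shape of `data` except along `axis`, where it has one slot per
// segment (output_shape[axis] == lengths.numel()). Dispatches on the index
// type of `lengths`, then on the floating-point type of `data`.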
Tensor _segment_reduce_cpu_kernel(
SegmentReductionType reduction,
const Tensor& data,
const Tensor& lengths,
int64_t axis,
const c10::optional<Scalar>& initial) {
int64_t segment_count = lengths.numel();
auto output_shape = data.sizes().vec();
output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());
AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_cpu_kernel1", [&]() {
const auto* lengths_data = lengths.data_ptr<index_t>();
_segment_reduce_cpu_kernel1(
reduction, data, lengths_data, axis, initial, output, segment_count);
});
return output;
}
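// Backward CPU kernel. It mirrors the forward loop structure:
// - max/min: the upstream gradient is routed to every element that is NaN or
//   compares equal to the reduced output; when several elements tie, a second
//   pass splits the gradient evenly among them.
// - mean: every element of the segment receives grad / segment_length.
// - sum: every element of the segment receives the gradient unchanged.
// Segments of length 0 are skipped and keep their zero-initialized gradient.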
template <typename T>
void _segment_reduce_cpu_backward_kernel1(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
const T* lengths_data,
int64_t axis,
Tensor& grad_input,
int64_t segment_count) {
int64_t stride_count = data_contig.numel() / data_contig.size(axis);
// TODO: Switch to TensorIterator for better maintainability and
// readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
kHalf,
data_contig.scalar_type(),
"_segment_reduce_cpu",
[&]() {
auto* output_data = output_contig.data_ptr<scalar_t>();
auto* grad_data = grad_contig.data_ptr<scalar_t>();
auto* grad_input_data = grad_input.data_ptr<scalar_t>();
const auto* values_data = data_contig.data_ptr<scalar_t>();
int64_t lengths_cum_sum = 0;
for (const auto i : c10::irange(segment_count)) {
if (lengths_data[i] == 0) {
continue;
}
for (const auto l : c10::irange(stride_count)) {
int64_t output_index = (i * stride_count) + l;
if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) {
int64_t counter = 0;
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
if (at::_isnan(values_data[starting_index]) ||
values_data[starting_index] == output_data[output_index]) {
grad_input_data[starting_index] = grad_data[output_index];
counter++;
}
}
// Average the gradient over the number of elements that attain
// the max (or min) of the segment, so ties share it evenly.
if (counter < 2) {
continue;
}
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
// Average only the entries written in the first pass, i.e. elements
// that are NaN or compare equal to the reduced output.
if (at::_isnan(values_data[starting_index]) ||
values_data[starting_index] == output_data[output_index]) {
grad_input_data[starting_index] =
grad_input_data[starting_index] / counter;
}
}
} else if (reduction == SegmentReductionType::MEAN) {
auto grad_val = grad_data[output_index] / lengths_data[i];
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
grad_input_data[starting_index] = grad_val;
}
} else if (reduction == SegmentReductionType::SUM) {
const auto& grad_val = grad_data[output_index];
for (int64_t j = 0; j < lengths_data[i]; ++j) {
int64_t starting_index =
((lengths_cum_sum + j) * stride_count) + l;
grad_input_data[starting_index] = grad_val;
}
}
}
lengths_cum_sum += lengths_data[i];
}
});
}
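// Backward entry point registered with _segment_reduce_backward_stub below.
// grad_input is zero-initialized so that elements of empty segments (which
// the kernel above skips) receive a zero gradient.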
Tensor _segment_reduce_cpu_backward_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
const Tensor& lengths_contig,
int64_t axis) {
int64_t segment_count = lengths_contig.numel();
auto output_shape = data_contig.sizes().vec();
output_shape[axis] = segment_count;
auto grad_input = at::zeros(data_contig.sizes(), grad_contig.options());
AT_DISPATCH_INDEX_TYPES(
lengths_contig.scalar_type(), "_segment_reduce_cpu_backward_kernel1", [&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>();
_segment_reduce_cpu_backward_kernel1(
grad_contig,
output_contig,
data_contig,
reduction,
lengths_data,
axis,
grad_input,
segment_count);
});
return grad_input;
}
} // namespace
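// Public entry point for the segment_reduce operator. It validates the
// arguments (currently only axis == 0, a 1-D `lengths` tensor, and no
// `indices` are supported), optionally checks in non-unsafe mode that the
// lengths are non-negative and sum to data.size(axis), makes the inputs
// contiguous, and forwards to the device-specific stub.
// Example semantics (sum along axis 0):
//   data = [[1, 2], [3, 4], [5, 6]], lengths = [2, 1]
//   -> output = [[1 + 3, 2 + 4], [5, 6]] = [[4, 6], [5, 6]]
// A rough usage sketch from Python (the exact binding is generated from
// native_functions.yaml, not defined in this file) would look like
//   torch.segment_reduce(data, "sum", lengths=lengths, axis=0).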
Tensor segment_reduce_kernel(
const Tensor& data,
c10::string_view reduce,
const c10::optional<Tensor>& lengths,
const c10::optional<Tensor>& indices,
int64_t axis,
bool unsafe,
const c10::optional<Scalar>& initial) {
axis = maybe_wrap_dim(axis, data.ndimension());
TORCH_CHECK(axis == 0, "Currently only dim=0 is supported! ", axis);
TORCH_CHECK(data.numel() > 0);
// length related checks
TORCH_CHECK(
lengths.has_value() && !indices.has_value(),
"Currently only lengths based reduction is supported!")
const auto& lengths_value = lengths.value();
TORCH_CHECK(lengths_value.dim() == 1);
TORCH_CHECK(data.get_device() == lengths_value.get_device());
TORCH_CHECK(data.dim() >= lengths_value.dim());
if (!unsafe) {
auto min_length = lengths_value.min().item<int64_t>();
TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.size(axis));
}
auto reduction = get_reduction_enum(reduce);
const auto data_contig = data.contiguous();
const auto lengths_contig = lengths_value.contiguous();
return _segment_reduce_stub(
data_contig.device().type(),
reduction,
data_contig,
lengths_contig,
axis,
initial);
}
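// The same scalar CPU kernel is registered for every CPU capability below;
// there are no vectorized (AVX2/AVX512/VSX) specializations for this op.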
REGISTER_ARCH_DISPATCH(
_segment_reduce_stub,
DEFAULT,
&_segment_reduce_cpu_kernel);
REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
// Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to re-use in backward
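// Backward of segment_reduce. Like the forward, it currently supports only
// axis == 0 and lengths-based segmentation; the inputs are made contiguous
// before being handed to the device-specific backward stub.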
Tensor _segment_reduce_backward_kernel(
const Tensor& grad,
const Tensor& output,
const Tensor& data,
c10::string_view reduce,
const c10::optional<Tensor>& lengths,
int64_t axis) {
axis = maybe_wrap_dim(axis, data.ndimension());
TORCH_CHECK(axis == 0, "Currently only dim=0 is supported! ", axis);
TORCH_CHECK(
lengths.has_value(),
"Currently only lengths based reduction is supported!")
const auto& lengths_value = lengths.value();
const auto grad_contig = grad.contiguous();
const auto output_contig = output.contiguous();
const auto data_contig = data.contiguous();
const auto lengths_contig = lengths_value.contiguous();
auto reduction = get_reduction_enum(reduce);
return _segment_reduce_backward_stub(
grad_contig.device().type(),
grad_contig,
output_contig,
data_contig,
reduction,
lengths_contig,
axis);
}
REGISTER_ARCH_DISPATCH(
_segment_reduce_backward_stub,
DEFAULT,
&_segment_reduce_cpu_backward_kernel);
REGISTER_AVX512_DISPATCH(
_segment_reduce_backward_stub,
&_segment_reduce_cpu_backward_kernel);
REGISTER_AVX2_DISPATCH(
_segment_reduce_backward_stub,
&_segment_reduce_cpu_backward_kernel);
REGISTER_VSX_DISPATCH(
_segment_reduce_backward_stub,
&_segment_reduce_cpu_backward_kernel);
} // namespace native
} // namespace at