Add offsets-based reduction to segment_reduce (CPU, CUDA)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78907
Approved by: https://github.com/cpuhrsch
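
Minimal usage sketch (illustrative only, not part of the patch; assumes a build that includes this change). `offsets` are segment boundaries, i.e. a leading zero followed by the cumulative sum of `lengths` along the reduction axis, so the two calls below are equivalent; internally the CUDA path converts between the two representations via `cumsum`/`diff` as shown in the kernel changes.

```python
import torch

data = torch.tensor([1., 2., 3., 4., 5., 6.])
lengths = torch.tensor([2, 3, 1])
# offsets has one more entry than lengths: [0] followed by cumsum(lengths)
offsets = torch.cat((lengths.new_zeros(1), lengths)).cumsum(0)  # tensor([0, 2, 5, 6])

out_lengths = torch.segment_reduce(data, "sum", lengths=lengths, axis=0)
out_offsets = torch.segment_reduce(data, "sum", offsets=offsets, axis=0)
# both produce tensor([ 3., 12.,  6.])
```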
diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp
index 1139515..85b19fd 100644
--- a/aten/src/ATen/native/SegmentReduce.cpp
+++ b/aten/src/ATen/native/SegmentReduce.cpp
@@ -8,8 +8,10 @@
namespace at {
namespace native {
-DEFINE_DISPATCH(_segment_reduce_stub);
-DEFINE_DISPATCH(_segment_reduce_backward_stub);
+DEFINE_DISPATCH(_segment_reduce_lengths_stub);
+DEFINE_DISPATCH(_segment_reduce_offsets_stub);
+DEFINE_DISPATCH(_segment_reduce_lengths_backward_stub);
+DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
namespace {
@@ -29,8 +31,8 @@
}
}
-template <typename T>
-void _segment_reduce_cpu_kernel1(
+template <typename T, bool is_offsets_like=false>
+void _segment_reduce_lengths_cpu_kernel1(
SegmentReductionType reduction,
const Tensor& data,
const T* lengths_data,
@@ -46,14 +48,30 @@
outer_offset *= output.size(d);
for (int64_t d = axis + 1; d < output.dim(); d++)
inner_offset *= output.size(d);
+ int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
+ auto data_stride_axis = data.stride(axis);
+ auto data_size_axis = data.size(axis);
+ auto output_stride_axis = output.stride(axis);
+ auto output_size_axis = output.size(axis);
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
auto* output_data = output.data_ptr<scalar_t>();
const auto* values_data = data.data_ptr<scalar_t>();
for (const auto outer_idx : c10::irange(outer_offset)) {
- int64_t lengths_cum_sum = 0;
+ int64_t segment_start, segment_length;
+ int64_t segment_end = is_offsets_like ?
+ lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
+ 0;
for (const auto dim_idx : c10::irange(segment_count)) {
- int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
+ segment_start = segment_end;
+ auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
+ if (is_offsets_like) {
+ segment_end = lengths_data[lengths_idx + 1];
+ segment_length = segment_end - segment_start;
+ } else {
+ segment_length = lengths_data[lengths_idx];
+ segment_end += segment_length;
+ }
for (const auto inner_idx : c10::irange(inner_offset)) {
// ===== step1: initialize starting value
scalar_t initial_value;
@@ -72,9 +90,9 @@
}
// ===== step2: apply reduction
- for (const auto j : c10::irange(segment_length)) {
- int64_t data_index = outer_idx * data.stride(axis) * data.size(axis)
- + (lengths_cum_sum + j) * data.stride(axis) + inner_idx;
+ for (const auto j : c10::irange(segment_start, segment_end)) {
+ int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ + j * data_stride_axis + inner_idx;
const auto val = values_data[data_index];
if (reduction == SegmentReductionType::MAX) {
initial_value = at::_isnan(val)
@@ -104,17 +122,16 @@
segment_length > 0 && !at::_isnan(initial_value)) {
initial_value = initial_value / segment_length;
}
- int64_t output_index = outer_idx * output.stride(axis) * output.size(axis)
- + dim_idx * output.stride(axis) + inner_idx;
+ int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+ + dim_idx * output_stride_axis + inner_idx;
output_data[output_index] = initial_value;
}
- lengths_cum_sum += segment_length;
}
}
});
}
-Tensor _segment_reduce_cpu_kernel(
+Tensor _segment_reduce_lengths_cpu_kernel(
SegmentReductionType reduction,
const Tensor& data,
const Tensor& lengths,
@@ -131,17 +148,43 @@
output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());
- AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_cpu_kernel1", [&]() {
+ AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_lengths_cpu_kernel1", [&]() {
const auto* lengths_data = lengths.data_ptr<index_t>();
- _segment_reduce_cpu_kernel1(
+ _segment_reduce_lengths_cpu_kernel1(
reduction, data, lengths_data, axis, initial, output, segment_count, lengths_stride_axis);
});
return output;
}
-template <typename T>
-void _segment_reduce_cpu_backward_kernel1(
+Tensor _segment_reduce_offsets_cpu_kernel(
+ SegmentReductionType reduction,
+ const Tensor& data,
+ const Tensor& offsets,
+ int64_t axis,
+ const c10::optional<Scalar>& initial) {
+  // data and offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
+ TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+ TORCH_CHECK(offsets.is_contiguous(), "Expected offsets to be contiguous.");
+  // reduction axis should always be the last dimension of offsets
+ axis = offsets.dim() - 1;
+ int64_t segment_count = offsets.size(axis) - 1;
+ int64_t offsets_stride_axis = offsets.stride(axis);
+ auto output_shape = data.sizes().vec();
+ output_shape[axis] = segment_count;
+ auto output = at::empty(output_shape, data.options());
+
+ AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_segment_reduce_offsets_cpu_kernel1", [&]() {
+ const auto* offsets_data = offsets.data_ptr<index_t>();
+ _segment_reduce_lengths_cpu_kernel1<index_t, /*is_offsets_like=*/true>(
+ reduction, data, offsets_data, axis, initial, output, segment_count, offsets_stride_axis);
+ });
+
+ return output;
+}
+
+template <typename T, bool is_offsets_like = false>
+void _segment_reduce_cpu_lengths_backward_kernel1(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
@@ -159,7 +202,12 @@
outer_offset *= output_contig.size(d);
for (int64_t d = axis + 1; d < output_contig.dim(); d++)
inner_offset *= output_contig.size(d);
- // TODO: Swtich to TensorIterator for better maintainablility and
+ int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
+ auto data_stride_axis = data_contig.stride(axis);
+ auto data_size_axis = data_contig.size(axis);
+ auto output_stride_axis = output_contig.stride(axis);
+ auto output_size_axis = output_contig.size(axis);
+  // TODO: Switch to TensorIterator for better maintainability and
// readability
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16,
@@ -182,21 +230,34 @@
}
for (const auto outer_idx : c10::irange(outer_offset)) {
- int64_t lengths_cum_sum = 0;
+ int64_t segment_start, segment_length;
+ int64_t segment_end = is_offsets_like ?
+ lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
+ 0;
for (const auto dim_idx : c10::irange(segment_count)) {
- int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
+ segment_start = segment_end;
+ auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
+ if (is_offsets_like) {
+ segment_end = lengths_data[lengths_idx + 1];
+ segment_length = segment_end - segment_start;
+ } else {
+ segment_length = lengths_data[lengths_idx];
+ segment_end += segment_length;
+ }
if (segment_length == 0) {
continue;
}
for (const auto inner_idx : c10::irange(inner_offset)) {
- int64_t output_index = outer_idx * output_contig.stride(axis) * output_contig.size(axis)
- + dim_idx * output_contig.stride(axis) + inner_idx;
+ int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+ + dim_idx * output_stride_axis + inner_idx;
if (reduction == SegmentReductionType::MAX ||
reduction == SegmentReductionType::MIN) {
int64_t counter = 0;
- for (const auto j : c10::irange(segment_length)) {
- int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
- + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+ for (const auto j : c10::irange(segment_start, segment_end)) {
+ int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ + j * data_stride_axis + inner_idx;
if (at::_isnan(values_data[data_index]) ||
values_data[data_index] == output_data[output_index]) {
grad_input_data[data_index] = grad_data[output_index];
@@ -208,9 +269,9 @@
if (counter < 2) {
continue;
}
- for (const auto j : c10::irange(segment_length)) {
- int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
- + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+ for (const auto j : c10::irange(segment_start, segment_end)) {
+ int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ + j * data_stride_axis + inner_idx;
if (grad_input_data[data_index] > 0) {
grad_input_data[data_index] =
grad_input_data[data_index] / counter;
@@ -218,32 +279,32 @@
}
} else if (reduction == SegmentReductionType::MEAN) {
auto grad_val = grad_data[output_index] / segment_length;
- for (const auto j : c10::irange(segment_length)) {
- int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
- + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+ for (const auto j : c10::irange(segment_start, segment_end)) {
+ int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ + j * data_stride_axis + inner_idx;
grad_input_data[data_index] = grad_val;
}
} else if (reduction == SegmentReductionType::SUM) {
const auto& grad_val = grad_data[output_index];
- for (const auto j : c10::irange(segment_length)) {
- int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
- + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+ for (const auto j : c10::irange(segment_start, segment_end)) {
+ int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ + j * data_stride_axis + inner_idx;
grad_input_data[data_index] = grad_val;
}
} else if (reduction == SegmentReductionType::PROD) {
const auto& grad_val = grad_data[output_index] * output_data[output_index];
- for (const auto j : c10::irange(segment_length)) {
- int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
- + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+ for (const auto j : c10::irange(segment_start, segment_end)) {
+ int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+ + j * data_stride_axis + inner_idx;
if (at::_isnan(values_data[data_index]) ||
values_data[data_index] == 0) {
// explicitly compute exclusive prod
scalar_t exclusive_prod = initial_prod_value;
int64_t idx;
- for (const auto k : c10::irange(segment_length)) {
+ for (const auto k : c10::irange(segment_start, segment_end)) {
if (k != j) {
- idx = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
- + (lengths_cum_sum + k) * data_contig.stride(axis) + inner_idx;
+ idx = outer_idx * data_stride_axis * data_size_axis
+ + k * data_stride_axis + inner_idx;
exclusive_prod *= values_data[idx];
}
}
@@ -254,13 +315,12 @@
}
}
}
- lengths_cum_sum += segment_length;
}
}
});
}
-Tensor _segment_reduce_cpu_backward_kernel(
+Tensor _segment_reduce_cpu_lengths_backward_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
@@ -274,9 +334,9 @@
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
AT_DISPATCH_INDEX_TYPES(
- lengths_contig.scalar_type(), "_segment_reduce_cpu_backward_kernel1", [&] {
+ lengths_contig.scalar_type(), "_segment_reduce_cpu_lengths_backward_kernel1", [&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>();
- _segment_reduce_cpu_backward_kernel1(
+ _segment_reduce_cpu_lengths_backward_kernel1(
grad_contig,
output_contig,
data_contig,
@@ -292,6 +352,39 @@
return grad_input;
}
+
+Tensor _segment_reduce_cpu_offsets_backward_kernel(
+ const Tensor& grad_contig,
+ const Tensor& output_contig,
+ const Tensor& data_contig,
+ SegmentReductionType reduction,
+ const Tensor& offsets_contig,
+ int64_t axis,
+ const c10::optional<Scalar>& initial) {
+ axis = offsets_contig.dim() - 1;
+ int64_t segment_count = offsets_contig.size(axis) - 1;
+ int64_t offsets_stride_axis = offsets_contig.stride(axis);
+ auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
+
+ AT_DISPATCH_INDEX_TYPES(
+ offsets_contig.scalar_type(), "_segment_reduce_cpu_offsets_backward_kernel1", [&] {
+ const auto* offsets_data = offsets_contig.data_ptr<index_t>();
+ _segment_reduce_cpu_lengths_backward_kernel1<index_t, /*is_offsets_like=*/true>(
+ grad_contig,
+ output_contig,
+ data_contig,
+ reduction,
+ offsets_data,
+ axis,
+ initial,
+ grad_input,
+ segment_count,
+ offsets_stride_axis);
+ });
+
+ return grad_input;
+}
+
} // namespace
Tensor segment_reduce_kernel(
@@ -299,49 +392,94 @@
c10::string_view reduce,
const c10::optional<Tensor>& lengths,
const c10::optional<Tensor>& indices,
+ const c10::optional<Tensor>& offsets,
int64_t axis,
bool unsafe,
const c10::optional<Scalar>& initial) {
axis = maybe_wrap_dim(axis, data.ndimension());
TORCH_CHECK(data.numel() > 0);
- // length related checks
+ // check that one of lengths or offsets is defined
+ auto lengths_has_value = lengths.has_value();
+ auto offsets_has_value = offsets.has_value();
TORCH_CHECK(
- lengths.has_value() && !indices.has_value(),
- "Currently only lengths based reduction is supported!")
- const auto& lengths_value = lengths.value();
- TORCH_CHECK(data.get_device() == lengths_value.get_device());
- TORCH_CHECK(data.dim() >= lengths_value.dim());
- TORCH_CHECK(axis == lengths_value.dim() - 1, "Expected axis to be equal to lengths.ndim() - 1 but got ", axis, ".");
-
- if (!unsafe) {
- auto min_length = lengths_value.min().item<int64_t>();
- TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
- TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
- "Expected all rows of lengths to sum to data.size(lengths.dim()-1) when unsafe=False");
- }
+ !indices.has_value(),
+ "segment_reduce(): indices based reduction is not supported yet.");
+ TORCH_CHECK(
+ lengths_has_value || offsets_has_value,
+ "segment_reduce(): Either lengths or offsets must be defined.")
auto reduction = get_reduction_enum(reduce);
const auto data_contig = data.contiguous();
- const auto lengths_contig = lengths_value.contiguous();
- return _segment_reduce_stub(
+ if (offsets_has_value) {
+ const auto& offsets_value = offsets.value();
+
+ // offsets related checks
+ TORCH_CHECK(data.get_device() == offsets_value.get_device());
+ TORCH_CHECK(data.dim() >= offsets_value.dim());
+ TORCH_CHECK(axis == offsets_value.dim() - 1,
+ "segment_reduce(): Expected axis to be the last dimension of offsets but got ", axis, ".");
+
+ // TODO: add checks when !unsafe
+
+ const auto offsets_contig = offsets_value.contiguous();
+
+ return _segment_reduce_offsets_stub(
+ data_contig.device().type(),
+ reduction,
+ data_contig,
+ offsets_contig,
+ axis,
+ initial);
+
+ } else {
+ const auto& lengths_value = lengths.value();
+
+ // length related checks
+ TORCH_CHECK(data.get_device() == lengths_value.get_device());
+ TORCH_CHECK(data.dim() >= lengths_value.dim());
+ TORCH_CHECK(axis == lengths_value.dim() - 1,
+ "segment_reduce(): Expected axis to be the last dimension of lengths but got ", axis, ".");
+
+ if (!unsafe) {
+ auto min_length = lengths_value.min().item<int64_t>();
+ TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
+ TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
+ "segment_reduce(): Expected all rows of lengths along axis ",
+ "to sum to data.size(lengths.dim()-1) when !unsafe.");
+ }
+
+ const auto lengths_contig = lengths_value.contiguous();
+
+ return _segment_reduce_lengths_stub(
data_contig.device().type(),
reduction,
data_contig,
lengths_contig,
axis,
initial);
+ }
}
REGISTER_ARCH_DISPATCH(
- _segment_reduce_stub,
+ _segment_reduce_lengths_stub,
DEFAULT,
- &_segment_reduce_cpu_kernel);
-REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_ZVECTOR_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+ &_segment_reduce_lengths_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+
+// offsets dispatches
+REGISTER_ARCH_DISPATCH(
+ _segment_reduce_offsets_stub,
+ DEFAULT,
+ &_segment_reduce_offsets_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
+REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
// Currently some computation is being duplicated across forward and backward.
// TODO: Cache indices in forward pass to re-use in backward
@@ -351,21 +489,40 @@
const Tensor& data,
c10::string_view reduce,
const c10::optional<Tensor>& lengths,
+ const c10::optional<Tensor>& offsets,
int64_t axis,
const c10::optional<Scalar>& initial) {
axis = maybe_wrap_dim(axis, data.ndimension());
+ // check that one of lengths or offsets is defined
+ // codegen for derivatives.yaml passes an undefined Tensor for None rather than a c10::optional
+ // so checking has_value() doesn't work unlike in the forward pass
+ auto lengths_has_value = lengths.has_value() && lengths.value().defined();
+ auto offsets_has_value = offsets.has_value() && offsets.value().defined();
TORCH_CHECK(
- lengths.has_value(),
- "Currently only lengths based reduction is supported!")
- const auto& lengths_value = lengths.value();
+ lengths_has_value || offsets_has_value,
+ "segment_reduce(): Either lengths or offsets must be defined.");
const auto grad_contig = grad.contiguous();
const auto output_contig = output.contiguous();
const auto data_contig = data.contiguous();
- const auto lengths_contig = lengths_value.contiguous();
-
auto reduction = get_reduction_enum(reduce);
- return _segment_reduce_backward_stub(
+
+ if (offsets_has_value) {
+ const auto& offsets_value = offsets.value();
+ const auto offsets_contig = offsets_value.contiguous();
+ return _segment_reduce_offsets_backward_stub(
+ grad_contig.device().type(),
+ grad_contig,
+ output_contig,
+ data_contig,
+ reduction,
+ offsets_contig,
+ axis,
+ initial);
+ } else {
+ const auto& lengths_value = lengths.value();
+ const auto lengths_contig = lengths_value.contiguous();
+ return _segment_reduce_lengths_backward_stub(
grad_contig.device().type(),
grad_contig,
output_contig,
@@ -374,24 +531,42 @@
lengths_contig,
axis,
initial);
+ }
}
REGISTER_ARCH_DISPATCH(
- _segment_reduce_backward_stub,
+ _segment_reduce_lengths_backward_stub,
DEFAULT,
- &_segment_reduce_cpu_backward_kernel);
+ &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX512_DISPATCH(
- _segment_reduce_backward_stub,
- &_segment_reduce_cpu_backward_kernel);
+ _segment_reduce_lengths_backward_stub,
+ &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_AVX2_DISPATCH(
- _segment_reduce_backward_stub,
- &_segment_reduce_cpu_backward_kernel);
+ _segment_reduce_lengths_backward_stub,
+ &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_VSX_DISPATCH(
- _segment_reduce_backward_stub,
- &_segment_reduce_cpu_backward_kernel);
+ _segment_reduce_lengths_backward_stub,
+ &_segment_reduce_cpu_lengths_backward_kernel);
REGISTER_ZVECTOR_DISPATCH(
- _segment_reduce_backward_stub,
- &_segment_reduce_cpu_backward_kernel);
+ _segment_reduce_lengths_backward_stub,
+ &_segment_reduce_cpu_lengths_backward_kernel);
+
+REGISTER_ARCH_DISPATCH(
+ _segment_reduce_offsets_backward_stub,
+ DEFAULT,
+ &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_AVX512_DISPATCH(
+ _segment_reduce_offsets_backward_stub,
+ &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_AVX2_DISPATCH(
+ _segment_reduce_offsets_backward_stub,
+ &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_VSX_DISPATCH(
+ _segment_reduce_offsets_backward_stub,
+ &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_ZVECTOR_DISPATCH(
+ _segment_reduce_offsets_backward_stub,
+ &_segment_reduce_cpu_offsets_backward_kernel);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/SegmentReduce.h b/aten/src/ATen/native/SegmentReduce.h
index a7cb5f8..7fb1512 100644
--- a/aten/src/ATen/native/SegmentReduce.h
+++ b/aten/src/ATen/native/SegmentReduce.h
@@ -11,15 +11,23 @@
enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
-using segment_reduce_fn = Tensor (*)(
+using segment_reduce_lengths_fn = Tensor (*)(
SegmentReductionType,
const Tensor&,
const Tensor&,
int64_t,
const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
-using segment_reduce_backward_fn = Tensor (*)(
+using segment_reduce_offsets_fn = Tensor (*)(
+ SegmentReductionType,
+ const Tensor&,
+ const Tensor&,
+ int64_t,
+ const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
+
+using segment_reduce_lengths_backward_fn = Tensor (*)(
const Tensor&,
const Tensor&,
const Tensor&,
@@ -27,7 +35,17 @@
const Tensor&,
int64_t,
const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_backward_fn, _segment_reduce_backward_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
+
+using segment_reduce_offsets_backward_fn = Tensor (*)(
+ const Tensor&,
+ const Tensor&,
+ const Tensor&,
+ SegmentReductionType,
+ const Tensor&,
+ int64_t,
+ const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu
index ab8571d..bfaa5ca 100644
--- a/aten/src/ATen/native/cuda/SegmentReduce.cu
+++ b/aten/src/ATen/native/cuda/SegmentReduce.cu
@@ -70,7 +70,7 @@
offsets[0].zero_();
AT_DISPATCH_INDEX_TYPES(
- lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+ lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
auto* lengths_data_ptr = lengths.data_ptr<index_t>();
auto* offsets_data_ptr = offsets.data_ptr<index_t>();
at::cuda::cub::inclusive_sum(
@@ -278,23 +278,33 @@
}
} // namespace
-Tensor _segment_reduce_cuda_backward_kernel(
+Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
const Tensor& grad_contig,
const Tensor& output_contig,
const Tensor& data_contig,
SegmentReductionType reduction,
- const Tensor& lengths_contig,
+ const Tensor& lengths_or_offsets_contig,
int64_t axis,
- const c10::optional<Scalar>& initial) {
- axis = lengths_contig.dim() - 1;
- int64_t segment_count = lengths_contig.size(axis);
- int64_t lengths_stride_axis = lengths_contig.stride(axis);
+ const c10::optional<Scalar>& initial,
+ bool is_offsets_like) {
+ axis = lengths_or_offsets_contig.dim() - 1;
+ int64_t segment_count = is_offsets_like ?
+ lengths_or_offsets_contig.size(axis) - 1 :
+ lengths_or_offsets_contig.size(axis);
+ int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
- auto zeros_shape = lengths_contig.sizes().vec();
- zeros_shape[axis] = 1;
- auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
- offsets.cumsum_(axis);
+ auto offsets = lengths_or_offsets_contig;
+ auto lengths = lengths_or_offsets_contig;
+ if (is_offsets_like) {
+ lengths = lengths.diff();
+ } else {
+ // _get_complete_sum only supports 1D
+ auto zeros_shape = offsets.sizes().vec();
+ zeros_shape[axis] = 1;
+ offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+ offsets.cumsum_(axis);
+ }
// outer_offset is the size of the outer dimensions of output (before axis)
// inner_offset is the size of the inner dimensions of output (after axis)
@@ -318,8 +328,8 @@
auto offsets_stride_axis = offsets.stride(axis);
AT_DISPATCH_INDEX_TYPES(
- lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
- const auto* lengths_data = lengths_contig.data_ptr<index_t>();
+ lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
+ const auto* lengths_data = lengths.data_ptr<index_t>();
auto* offsets_data = offsets.data_ptr<index_t>();
// TODO: Switch to TensorIterator for better maintainablility and
@@ -371,27 +381,59 @@
return grad_input;
}
-Tensor _segment_reduce_cuda_kernel(
- SegmentReductionType reduction,
- const Tensor& data,
- const Tensor& lengths,
- int64_t axis,
- const c10::optional<Scalar>& initial) {
- // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
- TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
- TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
- axis = lengths.dim() - 1;
- int64_t segment_count = lengths.size(axis);
- int64_t lengths_stride_axis = lengths.stride(axis);
+Tensor _segment_reduce_lengths_backward_cuda_kernel(
+ const Tensor& grad_contig,
+ const Tensor& output_contig,
+ const Tensor& data_contig,
+ SegmentReductionType reduction,
+ const Tensor& lengths_contig,
+ int64_t axis,
+ const c10::optional<Scalar>& initial) {
+ return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+ grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_backward_cuda_kernel(
+ const Tensor& grad_contig,
+ const Tensor& output_contig,
+ const Tensor& data_contig,
+ SegmentReductionType reduction,
+ const Tensor& offsets_contig,
+ int64_t axis,
+ const c10::optional<Scalar>& initial) {
+ return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+ grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
+}
+
+Tensor _segment_reduce_lengths_offsets_cuda_kernel(
+ SegmentReductionType reduction,
+ const Tensor& data,
+ const Tensor& lengths_or_offsets,
+ int64_t axis,
+ const c10::optional<Scalar>& initial,
+ bool is_offsets_like) {
+ // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(lengths_or_offsets.is_contiguous(), "Expected lengths_or_offsets to be contiguous.");
+ axis = lengths_or_offsets.dim() - 1;
+ int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
+ int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
auto output_shape = data.sizes().vec();
output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());
- // _get_complete_sum only supports 1D?
- auto zeros_shape = lengths.sizes().vec();
- zeros_shape[axis] = 1;
- auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
- offsets.cumsum_(axis);
+
+ auto offsets = lengths_or_offsets;
+ auto lengths = lengths_or_offsets;
+ if (is_offsets_like) {
+ lengths = lengths.diff();
+ } else {
+ // _get_complete_sum only supports 1D
+ auto zeros_shape = offsets.sizes().vec();
+ zeros_shape[axis] = 1;
+ offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+ offsets.cumsum_(axis);
+ }
// outer_offset is the size of the outer dimensions of output (before axis)
// inner_offset is the size of the inner dimensions of output (after axis)
@@ -416,7 +458,7 @@
auto offsets_stride_axis = offsets.stride(axis);
AT_DISPATCH_INDEX_TYPES(
- lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
+ lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
auto* offsets_data_ptr = offsets.data_ptr<index_t>();
auto* lengths_data_ptr = lengths.data_ptr<index_t>();
AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -549,10 +591,34 @@
return output;
}
-REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+Tensor _segment_reduce_lengths_cuda_kernel(
+ SegmentReductionType reduction,
+ const Tensor& data,
+ const Tensor& lengths,
+ int64_t axis,
+ const c10::optional<Scalar>& initial) {
+ return _segment_reduce_lengths_offsets_cuda_kernel(
+ reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_cuda_kernel(
+ SegmentReductionType reduction,
+ const Tensor& data,
+ const Tensor& offsets,
+ int64_t axis,
+ const c10::optional<Scalar>& initial) {
+ return _segment_reduce_lengths_offsets_cuda_kernel(
+ reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
+}
+
+REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
+REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
REGISTER_DISPATCH(
- _segment_reduce_backward_stub,
- &_segment_reduce_cuda_backward_kernel);
+ _segment_reduce_lengths_backward_stub,
+ &_segment_reduce_lengths_backward_cuda_kernel);
+REGISTER_DISPATCH(
+ _segment_reduce_offsets_backward_stub,
+ &_segment_reduce_offsets_backward_cuda_kernel);
} // namespace native
} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1b18d4d..cc88af0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11987,12 +11987,12 @@
dispatch:
CompositeExplicitAutograd: _test_warn_in_autograd
-- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
variants: function
dispatch:
CPU, CUDA: segment_reduce_kernel
-- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, int axis=0, Scalar? initial=None) -> Tensor
+- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
variants: function
dispatch:
CPU, CUDA: _segment_reduce_backward_kernel
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 6b9bf35..5ab7285 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -143,6 +143,8 @@
("aten::_csr_to_block_csr", datetime.date(2022, 5, 20)),
("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)),
("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)),
+ ("aten::segment_reduce", datetime.date(2022, 6, 30)),
+ ("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)),
# TODO: FIXME: prims shouldn't be checked
("prims::.*", datetime.date(9999, 1, 1)),
]
diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py
index 20f871a..b91a56e 100644
--- a/test/test_segment_reductions.py
+++ b/test/test_segment_reductions.py
@@ -1,6 +1,7 @@
# Owner(s): ["module: scatter & gather ops"]
from itertools import product
+from functools import partial
import numpy as np
import torch
@@ -52,6 +53,11 @@
lengths_dtype=torch.int,
):
lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
+ # generate offsets from lengths
+ zeros_shape = list(lengths.shape)
+ zeros_shape[-1] = 1
+ offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
+
data = torch.tensor(
data_arr,
device=device,
@@ -60,52 +66,56 @@
)
expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
- actual_result = torch.segment_reduce(
- data=data,
- reduce=reduction,
- lengths=lengths,
- axis=axis,
- unsafe=unsafe,
- initial=initial_value,
- )
- self.assertEqual(
- expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
- )
-
- if not check_backward:
- return
-
- # Test backward
- actual_result.sum().backward()
- self.assertEqual(
- expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
- )
-
- # gradcheck does not work well with bfloat16 or fp16 cpu types
- # also there is small numerical difference with fp32
- if dtype not in [torch.half, torch.bfloat16, torch.float]:
- # gradcheck does not like "nan" input, setting to random 10
- d_non_nan = np.nan_to_num(data_arr, nan=10)
- data = torch.tensor(
- # [10 if v == float("nan") else v for v in data],
- d_non_nan,
- device=device,
- dtype=dtype,
- requires_grad=True,
+ for mode in ['lengths', 'offsets']:
+ segment_reduce_kwargs = dict(
+ axis=axis,
+ unsafe=unsafe,
+ initial=initial_value)
+ if (mode == 'lengths'):
+ segment_reduce_kwargs['lengths'] = lengths
+ else:
+ segment_reduce_kwargs['offsets'] = offsets
+ actual_result = torch.segment_reduce(
+ data=data,
+ reduce=reduction,
+ **segment_reduce_kwargs
)
- self.assertTrue(
- gradcheck(
- lambda x: torch.segment_reduce(
- data=x,
- reduce=reduction,
- lengths=lengths,
- axis=axis,
- unsafe=unsafe,
- initial=initial_value,
- ),
- (data,),
+ self.assertEqual(
+ expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
+ )
+
+ if not check_backward:
+ return
+
+ # Test backward
+ actual_result.sum().backward()
+ self.assertEqual(
+ expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
+ )
+ data = data.clone().detach().requires_grad_(True)
+
+ # gradcheck does not work well with bfloat16 or fp16 cpu types
+ # also there is small numerical difference with fp32
+ if dtype not in [torch.half, torch.bfloat16, torch.float]:
+ # gradcheck does not like "nan" input, setting to random 10
+ d_non_nan = np.nan_to_num(data_arr, nan=10)
+ new_data = torch.tensor(
+ # [10 if v == float("nan") else v for v in data],
+ d_non_nan,
+ device=device,
+ dtype=dtype,
+ requires_grad=True,
)
- )
+ self.assertTrue(
+ gradcheck(
+ lambda x: torch.segment_reduce(
+ data=x,
+ reduce=reduction,
+ **segment_reduce_kwargs
+ ),
+ (new_data,),
+ )
+ )
@dtypes(
*product(
@@ -384,8 +394,18 @@
)
self.assertEqual(actual_result, expected)
+ # test offsets
+ actual_result = torch.segment_reduce(
+ data=data,
+ reduce=reduce,
+ offsets=indptr,
+ axis=dim,
+ unsafe=True,
+ )
+ self.assertEqual(actual_result, expected)
+
if val_dtype == torch.float64:
- def fn(x):
+ def fn(x, mode='lengths'):
initial = 1
# supply initial values to prevent gradcheck from failing for 0 length segments
# where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
@@ -393,8 +413,16 @@
initial = 1000
elif reduce == 'max':
initial = -1000
- return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial)
- self.assertTrue(gradcheck(fn, (data.clone().detach().requires_grad_(True))))
+            segment_reduce_args = (x, reduce)
+ segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
+ if mode == 'lengths':
+ segment_reduce_kwargs[mode] = lengths
+ elif mode == 'offsets':
+ segment_reduce_kwargs[mode] = indptr
+ return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
+ self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
+ self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
+
@dtypes(
*product(
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 02a947a..e4b59d2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2745,8 +2745,8 @@
- name: nonzero(Tensor self) -> Tensor
output_differentiability: [False]
-- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
- data: _segment_reduce_backward(grad, result, data, reduce, lengths, axis, initial)
+- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+ data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial)
- name: _pin_memory(Tensor self, Device? device=None) -> Tensor
self: grad
diff --git a/torch/overrides.py b/torch/overrides.py
index 81ebab6..1410a12 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -947,7 +947,7 @@
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
- torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
+ torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5f72e3c..d9826b7 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8381,9 +8381,19 @@
for args, reduce, initial in product(test_cases, reductions, [1, 2]):
inp_shape, dim, lengths, unsafe = args
lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
+ sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
+ if mode == 'lengths':
+ sample_input_kwargs['lengths'] = lengths_t
+ elif mode == 'offsets':
+ zeros_shape = list(lengths_t.shape)
+ zeros_shape[dim] = 1
+ offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
+ sample_input_kwargs['offsets'] = offsets_t
+ else:
+            raise RuntimeError(f"mode must be one of 'offsets' or 'lengths', got '{mode}'.")
yield SampleInput(_tensor(inp_shape),
args=(reduce,),
- kwargs={'lengths': lengths_t, 'axis': dim, 'unsafe': unsafe, 'initial': initial})
+ kwargs=sample_input_kwargs)
def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs):
@@ -19497,6 +19507,25 @@
),
),
),
+ OpInfo(
+ 'segment_reduce',
+ variant_test_name='offsets',
+ dtypes=floating_types_and(torch.float16, torch.bfloat16),
+ supports_out=False,
+ # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
+ supports_gradgrad=False,
+ sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'),
+ skips=(
+ # FIXME: CUDA driver API confirmed a leak in
+ # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
+ DecorateInfo(
+ unittest.skip("Skipped!"),
+ "TestJit",
+ "test_variant_consistency_jit",
+ device_type="cuda",
+ ),
+ ),
+ ),
UnaryUfuncInfo(
'special.bessel_j0',
decorators=(