Add offsets-based reduction to segment_reduce (CPU, CUDA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78907

Approved by: https://github.com/cpuhrsch
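
This adds an optional `offsets` argument to `torch.segment_reduce` (and to `_segment_reduce_backward`) as an alternative to `lengths`: instead of per-segment lengths, callers can pass the cumulative segment boundaries along the reduction axis. The dispatch stubs are split into lengths- and offsets-based variants for both forward and backward. On CPU the existing lengths kernels are templated on `is_offsets_like` so both paths share one implementation; on CUDA a common `lengths_or_offsets` kernel derives whichever representation was not supplied (cumsum for lengths, diff for offsets). Validation of offsets when `unsafe=False` is left as a TODO.

A minimal usage sketch of the new argument, assuming a build that includes this change:

```python
import torch

data = torch.tensor([1., 2., 3., 4., 5.])
lengths = torch.tensor([2, 3])
# offsets are the segment boundaries: a leading zero followed by the cumsum of
# lengths, so they have one more entry than lengths along the reduction axis
offsets = torch.tensor([0, 2, 5])

out_lengths = torch.segment_reduce(data, "sum", lengths=lengths)
out_offsets = torch.segment_reduce(data, "sum", offsets=offsets)

# both calls reduce the same segments and should produce tensor([3., 12.])
assert torch.equal(out_lengths, out_offsets)
```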
diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp
index 1139515..85b19fd 100644
--- a/aten/src/ATen/native/SegmentReduce.cpp
+++ b/aten/src/ATen/native/SegmentReduce.cpp
@@ -8,8 +8,10 @@
 namespace at {
 namespace native {
 
-DEFINE_DISPATCH(_segment_reduce_stub);
-DEFINE_DISPATCH(_segment_reduce_backward_stub);
+DEFINE_DISPATCH(_segment_reduce_lengths_stub);
+DEFINE_DISPATCH(_segment_reduce_offsets_stub);
+DEFINE_DISPATCH(_segment_reduce_lengths_backward_stub);
+DEFINE_DISPATCH(_segment_reduce_offsets_backward_stub);
 
 namespace {
 
@@ -29,8 +31,8 @@
   }
 }
 
-template <typename T>
-void _segment_reduce_cpu_kernel1(
+template <typename T, bool is_offsets_like=false>
+void _segment_reduce_lengths_cpu_kernel1(
     SegmentReductionType reduction,
     const Tensor& data,
     const T* lengths_data,
@@ -46,14 +48,30 @@
       outer_offset *= output.size(d);
   for (int64_t d = axis + 1; d < output.dim(); d++)
       inner_offset *= output.size(d);
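+  // When is_offsets_like is true, lengths_data holds segment boundaries (offsets),
+  // i.e. segment_count + 1 entries per row along the reduction axis, rather than lengths.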
+  int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
+  auto data_stride_axis = data.stride(axis);
+  auto data_size_axis = data.size(axis);
+  auto output_stride_axis = output.stride(axis);
+  auto output_size_axis = output.size(axis);
   AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, data.scalar_type(), "_segment_reduce_cpu", [&]() {
         auto* output_data = output.data_ptr<scalar_t>();
         const auto* values_data = data.data_ptr<scalar_t>();
         for (const auto outer_idx : c10::irange(outer_offset)) {
-          int64_t lengths_cum_sum = 0;
+          int64_t segment_start, segment_length;
+          int64_t segment_end = is_offsets_like ?
+                                lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
+                                0;
           for (const auto dim_idx : c10::irange(segment_count)) {
-            int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
+            segment_start = segment_end;
+            auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
+            if (is_offsets_like) {
+              segment_end = lengths_data[lengths_idx + 1];
+              segment_length = segment_end - segment_start;
+            } else {
+              segment_length = lengths_data[lengths_idx];
+              segment_end += segment_length;
+            }
             for (const auto inner_idx : c10::irange(inner_offset)) {
               // ===== step1: initialize starting value
               scalar_t initial_value;
@@ -72,9 +90,9 @@
               }
 
               // ===== step2: apply reduction
-              for (const auto j : c10::irange(segment_length)) {
-                int64_t data_index = outer_idx * data.stride(axis) * data.size(axis)
-                                     + (lengths_cum_sum + j) * data.stride(axis) + inner_idx;
+              for (const auto j : c10::irange(segment_start, segment_end)) {
+                int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                                     + j * data_stride_axis + inner_idx;
                 const auto val = values_data[data_index];
                 if (reduction == SegmentReductionType::MAX) {
                   initial_value = at::_isnan(val)
@@ -104,17 +122,16 @@
                   segment_length > 0 && !at::_isnan(initial_value)) {
                 initial_value = initial_value / segment_length;
               }
-              int64_t output_index = outer_idx * output.stride(axis) * output.size(axis)
-                                     + dim_idx * output.stride(axis) + inner_idx;
+              int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+                                     + dim_idx * output_stride_axis + inner_idx;
               output_data[output_index] = initial_value;
             }
-            lengths_cum_sum += segment_length;
           }
         }
       });
 }
 
-Tensor _segment_reduce_cpu_kernel(
+Tensor _segment_reduce_lengths_cpu_kernel(
     SegmentReductionType reduction,
     const Tensor& data,
     const Tensor& lengths,
@@ -131,17 +148,43 @@
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());
 
-  AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_cpu_kernel1", [&]() {
+  AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_lengths_cpu_kernel1", [&]() {
     const auto* lengths_data = lengths.data_ptr<index_t>();
-    _segment_reduce_cpu_kernel1(
+    _segment_reduce_lengths_cpu_kernel1(
         reduction, data, lengths_data, axis, initial, output, segment_count, lengths_stride_axis);
   });
 
   return output;
 }
 
-template <typename T>
-void _segment_reduce_cpu_backward_kernel1(
+Tensor _segment_reduce_offsets_cpu_kernel(
+    SegmentReductionType reduction,
+    const Tensor& data,
+    const Tensor& offsets,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  // data and offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(offsets.is_contiguous(), "Expected offsets to be contiguous.");
+  // reduction axis should always be the last dimension of offsets
+  axis = offsets.dim() - 1;
+  int64_t segment_count = offsets.size(axis) - 1;
+  int64_t offsets_stride_axis = offsets.stride(axis);
+  auto output_shape = data.sizes().vec();
+  output_shape[axis] = segment_count;
+  auto output = at::empty(output_shape, data.options());
+
+  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_segment_reduce_offsets_cpu_kernel1", [&]() {
+    const auto* offsets_data = offsets.data_ptr<index_t>();
+    _segment_reduce_lengths_cpu_kernel1<index_t, /*is_offsets_like=*/true>(
+        reduction, data, offsets_data, axis, initial, output, segment_count, offsets_stride_axis);
+  });
+
+  return output;
+}
+
+template <typename T, bool is_offsets_like = false>
+void _segment_reduce_cpu_lengths_backward_kernel1(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
@@ -159,7 +202,12 @@
       outer_offset *= output_contig.size(d);
   for (int64_t d = axis + 1; d < output_contig.dim(); d++)
       inner_offset *= output_contig.size(d);
-  // TODO: Swtich to TensorIterator for better maintainablility and
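+  // As in the forward kernel, is_offsets_like means lengths_data holds segment
+  // boundaries (segment_count + 1 entries per row) rather than per-segment lengths.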
+  int64_t lengths_size_axis = is_offsets_like ? segment_count + 1 : segment_count;
+  auto data_stride_axis = data_contig.stride(axis);
+  auto data_size_axis = data_contig.size(axis);
+  auto output_stride_axis = output_contig.stride(axis);
+  auto output_size_axis = output_contig.size(axis);
+  // TODO: Switch to TensorIterator for better maintainability and
   // readability
   AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16,
@@ -182,21 +230,34 @@
         }
 
         for (const auto outer_idx : c10::irange(outer_offset)) {
-          int64_t lengths_cum_sum = 0;
+          int64_t segment_start, segment_length;
+          int64_t segment_end = is_offsets_like ?
+                                lengths_data[outer_idx * lengths_stride_axis * lengths_size_axis] :
+                                0;
           for (const auto dim_idx : c10::irange(segment_count)) {
-            int64_t segment_length = lengths_data[outer_idx * lengths_stride_axis * segment_count + dim_idx];
+            segment_start = segment_end;
+            auto lengths_idx = outer_idx * lengths_stride_axis * lengths_size_axis + dim_idx;
+            if (is_offsets_like) {
+              segment_end = lengths_data[lengths_idx + 1];
+              segment_length = segment_end - segment_start;
+            } else {
+              segment_length = lengths_data[lengths_idx];
+              segment_end += segment_length;
+            }
             if (segment_length == 0) {
               continue;
             }
             for (const auto inner_idx : c10::irange(inner_offset)) {
-              int64_t output_index = outer_idx * output_contig.stride(axis) * output_contig.size(axis)
-                                     + dim_idx * output_contig.stride(axis) + inner_idx;
+              int64_t output_index = outer_idx * output_stride_axis * output_size_axis
+                                     + dim_idx * output_stride_axis + inner_idx;
               if (reduction == SegmentReductionType::MAX ||
                   reduction == SegmentReductionType::MIN) {
                 int64_t counter = 0;
-                for (const auto j : c10::irange(segment_length)) {
-                  int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
-                                       + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+                for (const auto j : c10::irange(segment_start, segment_end)) {
+                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                                       + j * data_stride_axis + inner_idx;
                   if (at::_isnan(values_data[data_index]) ||
                       values_data[data_index] == output_data[output_index]) {
                     grad_input_data[data_index] = grad_data[output_index];
@@ -208,9 +269,9 @@
                 if (counter < 2) {
                   continue;
                 }
-                for (const auto j : c10::irange(segment_length)) {
-                  int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
-                                       + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+                for (const auto j : c10::irange(segment_start, segment_end)) {
+                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                                       + j * data_stride_axis + inner_idx;
                   if (grad_input_data[data_index] > 0) {
                     grad_input_data[data_index] =
                         grad_input_data[data_index] / counter;
@@ -218,32 +279,32 @@
                 }
               } else if (reduction == SegmentReductionType::MEAN) {
                 auto grad_val = grad_data[output_index] / segment_length;
-                for (const auto j : c10::irange(segment_length)) {
-                  int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
-                                       + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+                for (const auto j : c10::irange(segment_start, segment_end)) {
+                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                                       + j * data_stride_axis + inner_idx;
                   grad_input_data[data_index] = grad_val;
                 }
               } else if (reduction == SegmentReductionType::SUM) {
                 const auto& grad_val = grad_data[output_index];
-                for (const auto j : c10::irange(segment_length)) {
-                  int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
-                                       + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+                for (const auto j : c10::irange(segment_start, segment_end)) {
+                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                                       + j * data_stride_axis + inner_idx;
                   grad_input_data[data_index] = grad_val;
                 }
               } else if (reduction == SegmentReductionType::PROD) {
                 const auto& grad_val = grad_data[output_index] * output_data[output_index];
-                for (const auto j : c10::irange(segment_length)) {
-                  int64_t data_index = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
-                                       + (lengths_cum_sum + j) * data_contig.stride(axis) + inner_idx;
+                for (const auto j : c10::irange(segment_start, segment_end)) {
+                  int64_t data_index = outer_idx * data_stride_axis * data_size_axis
+                                       + j * data_stride_axis + inner_idx;
                   if (at::_isnan(values_data[data_index]) ||
                       values_data[data_index] == 0) {
                     // explicitly compute exclusive prod
                     scalar_t exclusive_prod = initial_prod_value;
                     int64_t idx;
-                    for (const auto k : c10::irange(segment_length)) {
+                    for (const auto k : c10::irange(segment_start, segment_end)) {
                       if (k != j) {
-                        idx = outer_idx * data_contig.stride(axis) * data_contig.size(axis)
-                              + (lengths_cum_sum + k) * data_contig.stride(axis) + inner_idx;
+                        idx = outer_idx * data_stride_axis * data_size_axis
+                              + k * data_stride_axis + inner_idx;
                         exclusive_prod *= values_data[idx];
                       }
                     }
@@ -254,13 +315,12 @@
                 }
               }
             }
-            lengths_cum_sum += segment_length;
           }
         }
       });
 }
 
-Tensor _segment_reduce_cpu_backward_kernel(
+Tensor _segment_reduce_cpu_lengths_backward_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
@@ -274,9 +334,9 @@
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.scalar_type(), "_segment_reduce_cpu_backward_kernel1", [&] {
+      lengths_contig.scalar_type(), "_segment_reduce_cpu_lengths_backward_kernel1", [&] {
         const auto* lengths_data = lengths_contig.data_ptr<index_t>();
-        _segment_reduce_cpu_backward_kernel1(
+        _segment_reduce_cpu_lengths_backward_kernel1(
             grad_contig,
             output_contig,
             data_contig,
@@ -292,6 +352,39 @@
   return grad_input;
 }
 
+
+Tensor _segment_reduce_cpu_offsets_backward_kernel(
+    const Tensor& grad_contig,
+    const Tensor& output_contig,
+    const Tensor& data_contig,
+    SegmentReductionType reduction,
+    const Tensor& offsets_contig,
+    int64_t axis,
+    const c10::optional<Scalar>& initial) {
+  axis = offsets_contig.dim() - 1;
+  int64_t segment_count = offsets_contig.size(axis) - 1;
+  int64_t offsets_stride_axis = offsets_contig.stride(axis);
+  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
+
+  AT_DISPATCH_INDEX_TYPES(
+      offsets_contig.scalar_type(), "_segment_reduce_cpu_offsets_backward_kernel1", [&] {
+        const auto* offsets_data = offsets_contig.data_ptr<index_t>();
+        _segment_reduce_cpu_lengths_backward_kernel1<index_t, /*is_offsets_like=*/true>(
+            grad_contig,
+            output_contig,
+            data_contig,
+            reduction,
+            offsets_data,
+            axis,
+            initial,
+            grad_input,
+            segment_count,
+            offsets_stride_axis);
+      });
+
+  return grad_input;
+}
+
 } // namespace
 
 Tensor segment_reduce_kernel(
@@ -299,49 +392,94 @@
     c10::string_view reduce,
     const c10::optional<Tensor>& lengths,
     const c10::optional<Tensor>& indices,
+    const c10::optional<Tensor>& offsets,
     int64_t axis,
     bool unsafe,
     const c10::optional<Scalar>& initial) {
   axis = maybe_wrap_dim(axis, data.ndimension());
   TORCH_CHECK(data.numel() > 0);
 
-  // length related checks
+  // check that one of lengths or offsets is defined
+  auto lengths_has_value = lengths.has_value();
+  auto offsets_has_value = offsets.has_value();
   TORCH_CHECK(
-      lengths.has_value() && !indices.has_value(),
-      "Currently only lengths based reduction is supported!")
-  const auto& lengths_value = lengths.value();
-  TORCH_CHECK(data.get_device() == lengths_value.get_device());
-  TORCH_CHECK(data.dim() >= lengths_value.dim());
-  TORCH_CHECK(axis == lengths_value.dim() - 1, "Expected axis to be equal to lengths.ndim() - 1 but got ", axis, ".");
-
-  if (!unsafe) {
-    auto min_length = lengths_value.min().item<int64_t>();
-    TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
-    TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
-                "Expected all rows of lengths to sum to data.size(lengths.dim()-1) when unsafe=False");
-  }
+      !indices.has_value(),
+      "segment_reduce(): indices based reduction is not supported yet.");
+  TORCH_CHECK(
+      lengths_has_value || offsets_has_value,
+      "segment_reduce(): Either lengths or offsets must be defined.");
 
   auto reduction = get_reduction_enum(reduce);
   const auto data_contig = data.contiguous();
-  const auto lengths_contig = lengths_value.contiguous();
 
-  return _segment_reduce_stub(
+  if (offsets_has_value) {
+    const auto& offsets_value = offsets.value();
+
+    // offsets related checks
+    TORCH_CHECK(data.get_device() == offsets_value.get_device());
+    TORCH_CHECK(data.dim() >= offsets_value.dim());
+    TORCH_CHECK(axis == offsets_value.dim() - 1,
+                "segment_reduce(): Expected axis to be the last dimension of offsets but got ", axis, ".");
+
+    // TODO: add checks when !unsafe
+
+    const auto offsets_contig = offsets_value.contiguous();
+
+    return _segment_reduce_offsets_stub(
+      data_contig.device().type(),
+      reduction,
+      data_contig,
+      offsets_contig,
+      axis,
+      initial);
+
+  } else {
+    const auto& lengths_value = lengths.value();
+
+    // length related checks
+    TORCH_CHECK(data.get_device() == lengths_value.get_device());
+    TORCH_CHECK(data.dim() >= lengths_value.dim());
+    TORCH_CHECK(axis == lengths_value.dim() - 1,
+                "segment_reduce(): Expected axis to be the last dimension of lengths but got ", axis, ".");
+
+    if (!unsafe) {
+      auto min_length = lengths_value.min().item<int64_t>();
+      TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
+      TORCH_CHECK(all(lengths_value.sum({-1}) == data.size(axis)).item<bool>(),
+                  "segment_reduce(): Expected all rows of lengths along axis ",
+                  "to sum to data.size(lengths.dim()-1) when !unsafe.");
+    }
+
+    const auto lengths_contig = lengths_value.contiguous();
+
+    return _segment_reduce_lengths_stub(
       data_contig.device().type(),
       reduction,
       data_contig,
       lengths_contig,
       axis,
       initial);
+  }
 }
 
 REGISTER_ARCH_DISPATCH(
-    _segment_reduce_stub,
+    _segment_reduce_lengths_stub,
     DEFAULT,
-    &_segment_reduce_cpu_kernel);
-REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
-REGISTER_ZVECTOR_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel);
+    &_segment_reduce_lengths_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_AVX512_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cpu_kernel);
+
+// offsets dispatches
+REGISTER_ARCH_DISPATCH(
+    _segment_reduce_offsets_stub,
+    DEFAULT,
+    &_segment_reduce_offsets_cpu_kernel);
+REGISTER_AVX2_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
+REGISTER_AVX512_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
+REGISTER_VSX_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
+REGISTER_ZVECTOR_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cpu_kernel);
 
 // Currently some computation is being duplicated across forward and backward.
 // TODO: Cache indices in forward pass to re-use in backward
@@ -351,21 +489,40 @@
     const Tensor& data,
     c10::string_view reduce,
     const c10::optional<Tensor>& lengths,
+    const c10::optional<Tensor>& offsets,
     int64_t axis,
     const c10::optional<Scalar>& initial) {
   axis = maybe_wrap_dim(axis, data.ndimension());
+  // check that one of lengths or offsets is defined
+  // the codegen for derivatives.yaml passes an undefined Tensor for None rather than a
+  // nullopt c10::optional, so checking has_value() alone does not work, unlike in the forward pass
+  auto lengths_has_value = lengths.has_value() && lengths.value().defined();
+  auto offsets_has_value = offsets.has_value() && offsets.value().defined();
   TORCH_CHECK(
-      lengths.has_value(),
-      "Currently only lengths based reduction is supported!")
-  const auto& lengths_value = lengths.value();
+      lengths_has_value || offsets_has_value,
+      "segment_reduce(): Either lengths or offsets must be defined.");
 
   const auto grad_contig = grad.contiguous();
   const auto output_contig = output.contiguous();
   const auto data_contig = data.contiguous();
-  const auto lengths_contig = lengths_value.contiguous();
-
   auto reduction = get_reduction_enum(reduce);
-  return _segment_reduce_backward_stub(
+
+  if (offsets_has_value) {
+    const auto& offsets_value = offsets.value();
+    const auto offsets_contig = offsets_value.contiguous();
+    return _segment_reduce_offsets_backward_stub(
+      grad_contig.device().type(),
+      grad_contig,
+      output_contig,
+      data_contig,
+      reduction,
+      offsets_contig,
+      axis,
+      initial);
+  } else {
+    const auto& lengths_value = lengths.value();
+    const auto lengths_contig = lengths_value.contiguous();
+    return _segment_reduce_lengths_backward_stub(
       grad_contig.device().type(),
       grad_contig,
       output_contig,
@@ -374,24 +531,42 @@
       lengths_contig,
       axis,
       initial);
+  }
 }
 
 REGISTER_ARCH_DISPATCH(
-    _segment_reduce_backward_stub,
+    _segment_reduce_lengths_backward_stub,
     DEFAULT,
-    &_segment_reduce_cpu_backward_kernel);
+    &_segment_reduce_cpu_lengths_backward_kernel);
 REGISTER_AVX512_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cpu_backward_kernel);
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_cpu_lengths_backward_kernel);
 REGISTER_AVX2_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cpu_backward_kernel);
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_cpu_lengths_backward_kernel);
 REGISTER_VSX_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cpu_backward_kernel);
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_cpu_lengths_backward_kernel);
 REGISTER_ZVECTOR_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cpu_backward_kernel);
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_cpu_lengths_backward_kernel);
+
+REGISTER_ARCH_DISPATCH(
+    _segment_reduce_offsets_backward_stub,
+    DEFAULT,
+    &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_AVX512_DISPATCH(
+    _segment_reduce_offsets_backward_stub,
+    &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_AVX2_DISPATCH(
+    _segment_reduce_offsets_backward_stub,
+    &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_VSX_DISPATCH(
+    _segment_reduce_offsets_backward_stub,
+    &_segment_reduce_cpu_offsets_backward_kernel);
+REGISTER_ZVECTOR_DISPATCH(
+    _segment_reduce_offsets_backward_stub,
+    &_segment_reduce_cpu_offsets_backward_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/SegmentReduce.h b/aten/src/ATen/native/SegmentReduce.h
index a7cb5f8..7fb1512 100644
--- a/aten/src/ATen/native/SegmentReduce.h
+++ b/aten/src/ATen/native/SegmentReduce.h
@@ -11,15 +11,23 @@
 
 enum SegmentReductionType { MAX, MEAN, MIN, SUM, PROD};
 
-using segment_reduce_fn = Tensor (*)(
+using segment_reduce_lengths_fn = Tensor (*)(
     SegmentReductionType,
     const Tensor&,
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_fn, _segment_reduce_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_fn, _segment_reduce_lengths_stub);
 
-using segment_reduce_backward_fn = Tensor (*)(
+using segment_reduce_offsets_fn = Tensor (*)(
+    SegmentReductionType,
+    const Tensor&,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_fn, _segment_reduce_offsets_stub);
+
+using segment_reduce_lengths_backward_fn = Tensor (*)(
     const Tensor&,
     const Tensor&,
     const Tensor&,
@@ -27,7 +35,17 @@
     const Tensor&,
     int64_t,
     const c10::optional<Scalar>&);
-DECLARE_DISPATCH(segment_reduce_backward_fn, _segment_reduce_backward_stub);
+DECLARE_DISPATCH(segment_reduce_lengths_backward_fn, _segment_reduce_lengths_backward_stub);
+
+using segment_reduce_offsets_backward_fn = Tensor (*)(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    SegmentReductionType,
+    const Tensor&,
+    int64_t,
+    const c10::optional<Scalar>&);
+DECLARE_DISPATCH(segment_reduce_offsets_backward_fn, _segment_reduce_offsets_backward_stub);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/cuda/SegmentReduce.cu b/aten/src/ATen/native/cuda/SegmentReduce.cu
index ab8571d..bfaa5ca 100644
--- a/aten/src/ATen/native/cuda/SegmentReduce.cu
+++ b/aten/src/ATen/native/cuda/SegmentReduce.cu
@@ -70,7 +70,7 @@
   offsets[0].zero_();
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
+      lengths.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         at::cuda::cub::inclusive_sum(
@@ -278,23 +278,33 @@
 }
 } // namespace
 
-Tensor _segment_reduce_cuda_backward_kernel(
+Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
     const Tensor& grad_contig,
     const Tensor& output_contig,
     const Tensor& data_contig,
     SegmentReductionType reduction,
-    const Tensor& lengths_contig,
+    const Tensor& lengths_or_offsets_contig,
     int64_t axis,
-    const c10::optional<Scalar>& initial) {
-  axis = lengths_contig.dim() - 1;
-  int64_t segment_count = lengths_contig.size(axis);
-  int64_t lengths_stride_axis = lengths_contig.stride(axis);
+    const c10::optional<Scalar>& initial,
+    bool is_offsets_like) {
+  axis = lengths_or_offsets_contig.dim() - 1;
+  int64_t segment_count = is_offsets_like ?
+                          lengths_or_offsets_contig.size(axis) - 1 :
+                          lengths_or_offsets_contig.size(axis);
+  int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
 
-  auto zeros_shape = lengths_contig.sizes().vec();
-  zeros_shape[axis] = 1;
-  auto offsets = at::cat({at::zeros(zeros_shape, lengths_contig.options()), lengths_contig}, axis);
-  offsets.cumsum_(axis);
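+  // The backward kernel is dispatched with both representations (offsets and
+  // per-segment lengths), so derive whichever one was not passed in.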
+  auto offsets = lengths_or_offsets_contig;
+  auto lengths = lengths_or_offsets_contig;
+  if (is_offsets_like) {
+    lengths = lengths.diff();
+  } else {
+    // _get_complete_sum only supports 1D
+    auto zeros_shape = offsets.sizes().vec();
+    zeros_shape[axis] = 1;
+    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+    offsets.cumsum_(axis);
+  }
 
   // outer_offset is the size of the outer dimensions of output (before axis)
   // inner_offset is the size of the inner dimensions of output (after axis)
@@ -318,8 +328,8 @@
   auto offsets_stride_axis = offsets.stride(axis);
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths_contig.scalar_type(), "_segment_reduce_cuda_backward_kernel1", ([&] {
-        const auto* lengths_data = lengths_contig.data_ptr<index_t>();
+      lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
+        const auto* lengths_data = lengths.data_ptr<index_t>();
         auto* offsets_data = offsets.data_ptr<index_t>();
 
         // TODO: Switch to TensorIterator for better maintainablility and
@@ -371,27 +381,59 @@
   return grad_input;
 }
 
-Tensor _segment_reduce_cuda_kernel(
-    SegmentReductionType reduction,
-    const Tensor& data,
-    const Tensor& lengths,
-    int64_t axis,
-    const c10::optional<Scalar>& initial) {
-  // data and lengths should be contiguous from the call to .contiguous in segment_reduce_kernel
-  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
-  TORCH_CHECK(lengths.is_contiguous(), "Expected lengths to be contiguous.");
-  axis = lengths.dim() - 1;
-  int64_t segment_count = lengths.size(axis);
-  int64_t lengths_stride_axis = lengths.stride(axis);
+Tensor _segment_reduce_lengths_backward_cuda_kernel(
+  const Tensor& grad_contig,
+  const Tensor& output_contig,
+  const Tensor& data_contig,
+  SegmentReductionType reduction,
+  const Tensor& lengths_contig,
+  int64_t axis,
+  const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+    grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_backward_cuda_kernel(
+  const Tensor& grad_contig,
+  const Tensor& output_contig,
+  const Tensor& data_contig,
+  SegmentReductionType reduction,
+  const Tensor& offsets_contig,
+  int64_t axis,
+  const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
+    grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
+}
+
+Tensor _segment_reduce_lengths_offsets_cuda_kernel(
+  SegmentReductionType reduction,
+  const Tensor& data,
+  const Tensor& lengths_or_offsets,
+  int64_t axis,
+  const c10::optional<Scalar>& initial,
+  bool is_offsets_like) {
+  // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
+  TORCH_CHECK(data.is_contiguous(), "Expected data to be contiguous.");
+  TORCH_CHECK(lengths_or_offsets.is_contiguous(), "Expected lengths or offsets to be contiguous.");
+  axis = lengths_or_offsets.dim() - 1;
+  int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
+  int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
   auto output_shape = data.sizes().vec();
   output_shape[axis] = segment_count;
   auto output = at::empty(output_shape, data.options());
 
-  // _get_complete_sum only supports 1D?
-  auto zeros_shape = lengths.sizes().vec();
-  zeros_shape[axis] = 1;
-  auto offsets = at::cat({at::zeros(zeros_shape, lengths.options()), lengths}, axis);
-  offsets.cumsum_(axis);
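+  // The reduction kernel is launched with both representations (offsets and
+  // per-segment lengths), so compute whichever one was not supplied.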
+  auto offsets = lengths_or_offsets;
+  auto lengths = lengths_or_offsets;
+  if (is_offsets_like) {
+    lengths = lengths.diff();
+  } else {
+    // _get_complete_sum only supports 1D
+    auto zeros_shape = offsets.sizes().vec();
+    zeros_shape[axis] = 1;
+    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
+    offsets.cumsum_(axis);
+  }
 
   // outer_offset is the size of the outer dimensions of output (before axis)
   // inner_offset is the size of the inner dimensions of output (after axis)
@@ -416,7 +458,7 @@
   auto offsets_stride_axis = offsets.stride(axis);
 
   AT_DISPATCH_INDEX_TYPES(
-      lengths.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
+      lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
         auto* offsets_data_ptr = offsets.data_ptr<index_t>();
         auto* lengths_data_ptr = lengths.data_ptr<index_t>();
         AT_DISPATCH_FLOATING_TYPES_AND2(
@@ -549,10 +591,34 @@
   return output;
 }
 
-REGISTER_DISPATCH(_segment_reduce_stub, &_segment_reduce_cuda_kernel);
+Tensor _segment_reduce_lengths_cuda_kernel(
+  SegmentReductionType reduction,
+  const Tensor& data,
+  const Tensor& lengths,
+  int64_t axis,
+  const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_cuda_kernel(
+    reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
+}
+
+Tensor _segment_reduce_offsets_cuda_kernel(
+  SegmentReductionType reduction,
+  const Tensor& data,
+  const Tensor& offsets,
+  int64_t axis,
+  const c10::optional<Scalar>& initial) {
+  return _segment_reduce_lengths_offsets_cuda_kernel(
+    reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
+}
+
+REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
+REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
 REGISTER_DISPATCH(
-    _segment_reduce_backward_stub,
-    &_segment_reduce_cuda_backward_kernel);
+    _segment_reduce_lengths_backward_stub,
+    &_segment_reduce_lengths_backward_cuda_kernel);
+REGISTER_DISPATCH(
+  _segment_reduce_offsets_backward_stub,
+  &_segment_reduce_offsets_backward_cuda_kernel);
 
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 1b18d4d..cc88af0 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -11987,12 +11987,12 @@
   dispatch:
     CompositeExplicitAutograd: _test_warn_in_autograd
 
-- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: segment_reduce_kernel
 
-- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, int axis=0, Scalar? initial=None) -> Tensor
+- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
   variants: function
   dispatch:
     CPU, CUDA: _segment_reduce_backward_kernel
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 6b9bf35..5ab7285 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -143,6 +143,8 @@
     ("aten::_csr_to_block_csr", datetime.date(2022, 5, 20)),
     ("aten::_weight_norm_cuda_interface", datetime.date(9999, 1, 1)),
     ("aten::_weight_norm_cuda_interface_backward", datetime.date(9999, 1, 1)),
+    ("aten::segment_reduce", datetime.date(2022, 6, 30)),
+    ("aten::_segment_reduce_backward", datetime.date(2022, 6, 30)),
     # TODO: FIXME: prims shouldn't be checked
     ("prims::.*", datetime.date(9999, 1, 1)),
 ]
diff --git a/test/test_segment_reductions.py b/test/test_segment_reductions.py
index 20f871a..b91a56e 100644
--- a/test/test_segment_reductions.py
+++ b/test/test_segment_reductions.py
@@ -1,6 +1,7 @@
 # Owner(s): ["module: scatter & gather ops"]
 
 from itertools import product
+from functools import partial
 
 import numpy as np
 import torch
@@ -52,6 +53,11 @@
         lengths_dtype=torch.int,
     ):
         lengths = torch.tensor(lengths_arr, device=device, dtype=lengths_dtype)
+        # generate offsets from lengths
+        zeros_shape = list(lengths.shape)
+        zeros_shape[-1] = 1
+        offsets = torch.cat((lengths.new_zeros(zeros_shape), lengths), -1).cumsum_(-1)
+
         data = torch.tensor(
             data_arr,
             device=device,
@@ -60,52 +66,56 @@
         )
         expected_result = torch.tensor(expected_arr, device=device, dtype=dtype)
         expected_grad = torch.tensor(expected_grad_arr, device=device, dtype=dtype)
-        actual_result = torch.segment_reduce(
-            data=data,
-            reduce=reduction,
-            lengths=lengths,
-            axis=axis,
-            unsafe=unsafe,
-            initial=initial_value,
-        )
-        self.assertEqual(
-            expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
-        )
-
-        if not check_backward:
-            return
-
-        # Test backward
-        actual_result.sum().backward()
-        self.assertEqual(
-            expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
-        )
-
-        # gradcheck does not work well with bfloat16 or fp16 cpu types
-        # also there is small numerical difference with fp32
-        if dtype not in [torch.half, torch.bfloat16, torch.float]:
-            # gradcheck does not like "nan" input, setting to random 10
-            d_non_nan = np.nan_to_num(data_arr, nan=10)
-            data = torch.tensor(
-                # [10 if v == float("nan") else v for v in data],
-                d_non_nan,
-                device=device,
-                dtype=dtype,
-                requires_grad=True,
+        for mode in ['lengths', 'offsets']:
+            segment_reduce_kwargs = dict(
+                axis=axis,
+                unsafe=unsafe,
+                initial=initial_value)
+            if mode == 'lengths':
+                segment_reduce_kwargs['lengths'] = lengths
+            else:
+                segment_reduce_kwargs['offsets'] = offsets
+            actual_result = torch.segment_reduce(
+                data=data,
+                reduce=reduction,
+                **segment_reduce_kwargs
             )
-            self.assertTrue(
-                gradcheck(
-                    lambda x: torch.segment_reduce(
-                        data=x,
-                        reduce=reduction,
-                        lengths=lengths,
-                        axis=axis,
-                        unsafe=unsafe,
-                        initial=initial_value,
-                    ),
-                    (data,),
+            self.assertEqual(
+                expected_result, actual_result, rtol=1e-02, atol=1e-05, equal_nan=True
+            )
+
+            if not check_backward:
+                continue  # skip backward checks but still test the forward path for the other mode
+
+            # Test backward
+            actual_result.sum().backward()
+            self.assertEqual(
+                expected_grad, data.grad, rtol=1e-02, atol=1e-05, equal_nan=True
+            )
+            data = data.clone().detach().requires_grad_(True)
+
+            # gradcheck does not work well with bfloat16 or fp16 cpu types
+            # also there is small numerical difference with fp32
+            if dtype not in [torch.half, torch.bfloat16, torch.float]:
+                # gradcheck does not like "nan" input, setting to random 10
+                d_non_nan = np.nan_to_num(data_arr, nan=10)
+                new_data = torch.tensor(
+                    # [10 if v == float("nan") else v for v in data],
+                    d_non_nan,
+                    device=device,
+                    dtype=dtype,
+                    requires_grad=True,
                 )
-            )
+                self.assertTrue(
+                    gradcheck(
+                        lambda x: torch.segment_reduce(
+                            data=x,
+                            reduce=reduction,
+                            **segment_reduce_kwargs
+                        ),
+                        (new_data,),
+                    )
+                )
 
     @dtypes(
         *product(
@@ -384,8 +394,18 @@
             )
             self.assertEqual(actual_result, expected)
 
+            # test offsets
+            actual_result = torch.segment_reduce(
+                data=data,
+                reduce=reduce,
+                offsets=indptr,
+                axis=dim,
+                unsafe=True,
+            )
+            self.assertEqual(actual_result, expected)
+
             if val_dtype == torch.float64:
-                def fn(x):
+                def fn(x, mode='lengths'):
                     initial = 1
                     # supply initial values to prevent gradcheck from failing for 0 length segments
                     # where nan/inf are reduction identities that produce nans when calculating the numerical jacobian
@@ -393,8 +413,16 @@
                         initial = 1000
                     elif reduce == 'max':
                         initial = -1000
-                    return torch.segment_reduce(x, reduce, lengths=lengths, axis=dim, unsafe=True, initial=initial)
-                self.assertTrue(gradcheck(fn, (data.clone().detach().requires_grad_(True))))
+                    segment_reduce_args = (x, reduce)
+                    segment_reduce_kwargs = dict(axis=dim, unsafe=True, initial=initial)
+                    if mode == 'lengths':
+                        segment_reduce_kwargs[mode] = lengths
+                    elif mode == 'offsets':
+                        segment_reduce_kwargs[mode] = indptr
+                    return torch.segment_reduce(*segment_reduce_args, **segment_reduce_kwargs)
+                self.assertTrue(gradcheck(partial(fn, mode='lengths'), (data.clone().detach().requires_grad_(True))))
+                self.assertTrue(gradcheck(partial(fn, mode='offsets'), (data.clone().detach().requires_grad_(True))))
+
 
     @dtypes(
         *product(
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 02a947a..e4b59d2 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -2745,8 +2745,8 @@
 - name: nonzero(Tensor self) -> Tensor
   output_differentiability: [False]
 
-- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
-  data: _segment_reduce_backward(grad, result, data, reduce, lengths, axis, initial)
+- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+  data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial)
 
 - name: _pin_memory(Tensor self, Device? device=None) -> Tensor
   self: grad
diff --git a/torch/overrides.py b/torch/overrides.py
index 81ebab6..1410a12 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -947,7 +947,7 @@
         torch.scatter_add: lambda input, dim, index, src: -1,
         torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
         torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
-        torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, axis=0, unsafe=False: -1,
+        torch.segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
         torch.select: lambda input, dim, index: -1,
         torch.select_scatter: lambda input, src, dim, index: -1,
         torch.slice_scatter: lambda input, src, dim=0, start=None, end=None, step=1: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 5f72e3c..d9826b7 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8381,9 +8381,19 @@
     for args, reduce, initial in product(test_cases, reductions, [1, 2]):
         inp_shape, dim, lengths, unsafe = args
         lengths_t = torch.tensor(lengths, dtype=torch.long, device=device)
+        sample_input_kwargs = {'axis': dim, 'unsafe': unsafe, 'initial': initial}
+        if mode == 'lengths':
+            sample_input_kwargs['lengths'] = lengths_t
+        elif mode == 'offsets':
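+            # offsets are the segment boundaries: a leading zero concatenated with
+            # the cumulative sum of lengths along dim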
+            zeros_shape = list(lengths_t.shape)
+            zeros_shape[dim] = 1
+            offsets_t = torch.cat((lengths_t.new_zeros(zeros_shape), lengths_t), dim).cumsum_(dim)
+            sample_input_kwargs['offsets'] = offsets_t
+        else:
+            raise RuntimeError(f"mode most be one of 'offsets' or 'lengths' got '{mode}'.")
         yield SampleInput(_tensor(inp_shape),
                           args=(reduce,),
-                          kwargs={'lengths': lengths_t, 'axis': dim, 'unsafe': unsafe, 'initial': initial})
+                          kwargs=sample_input_kwargs)
 
 
 def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs):
@@ -19497,6 +19507,25 @@
             ),
         ),
     ),
+    OpInfo(
+        'segment_reduce',
+        variant_test_name='offsets',
+        dtypes=floating_types_and(torch.float16, torch.bfloat16),
+        supports_out=False,
+        # RuntimeError: derivative for aten::_segment_reduce_backward is not implemented
+        supports_gradgrad=False,
+        sample_inputs_func=partial(sample_inputs_segment_reduce, mode='offsets'),
+        skips=(
+            # FIXME: CUDA driver API confirmed a leak in
+            # __main__.TestJitCUDA.test_variant_consistency_jit_segment_reduce_cuda_float32
+            DecorateInfo(
+                unittest.skip("Skipped!"),
+                "TestJit",
+                "test_variant_consistency_jit",
+                device_type="cuda",
+            ),
+        ),
+    ),
     UnaryUfuncInfo(
         'special.bessel_j0',
         decorators=(