|  | #ifndef CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_ | 
|  | #define CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_ | 
|  |  | 
|  | #include "caffe2/core/common_omp.h" | 
|  | #include "caffe2/core/context.h" | 
|  | #include "caffe2/core/logging.h" | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "caffe2/utils/math.h" | 
|  | #include "c10/util/irange.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <class Context> | 
|  | class LengthsSplitOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | template <class... Args> | 
|  | explicit LengthsSplitOp(Args&&... args) | 
|  | : Operator<Context>(std::forward<Args>(args)...), | 
|  | n_split_(OperatorBase::GetSingleArgument<int32_t>("n_split", 0)) { | 
|  | if (InputSize() == 1) { | 
|  | // If not specified, then must have this argument | 
|  | CAFFE_ENFORCE( | 
|  | OperatorBase::HasArgument("n_split"), | 
|  | "Argument `n_split` is missing and was not specified as input."); | 
|  | CAFFE_ENFORCE( | 
|  | n_split_ > 0, | 
|  | "`n_split` must contain a positive value for defined behavior."); | 
|  | } | 
|  | } | 
|  | ~LengthsSplitOp() override {} | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | const auto& L = Input(0); | 
|  | CAFFE_ENFORCE_EQ(L.dim(), 1, "Input `LENGTHS` should be a 1D vector."); | 
|  |  | 
|  | if (InputSize() > 1) { | 
|  | // We potentially have n_split specified as inputs as well | 
|  | CAFFE_ENFORCE( | 
|  | Input(1).dim() == 1 && Input(1).numel() == 1, | 
|  | "Input `n_split` should be a vector of size 1."); | 
|  |  | 
|  | const auto& input1 = Input(1); | 
|  | context_.template CopyItems<Context, CPUContext>( | 
|  | input1.dtype(), 1, input1.raw_data(), &n_split_); | 
|  | } | 
|  |  | 
|  | CAFFE_ENFORCE( | 
|  | n_split_ > 0, | 
|  | "`n_split` must contain a positive value for defined behavior."); | 
|  | const auto M = L.numel(); | 
|  |  | 
|  | auto* Y = Output(0, {M * n_split_}, at::dtype<int32_t>()); | 
|  |  | 
|  | const int32_t* Ldata = L.template data<int32_t>(); | 
|  | int32_t* Ydata = Y->template mutable_data<int32_t>(); | 
|  |  | 
|  | for (const auto i : c10::irange(M)) { | 
|  | int32_t mod = Ldata[i] % n_split_; | 
|  | int32_t res = | 
|  | mod != 0 ? math::DivUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1; | 
|  | for (const auto j : c10::irange(n_split_)) { | 
|  | Ydata[(i * n_split_) + j] = mod-- > 0 ? res : res - 1; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | private: | 
|  | int32_t n_split_; | 
|  | }; | 
|  |  | 
|  | } // namespace caffe2 | 
|  |  | 
|  | #endif // CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_ |