|  | #ifndef CAFFE2_OPERATORS_ROW_MUL_H_ | 
|  | #define CAFFE2_OPERATORS_ROW_MUL_H_ | 
|  |  | 
|  | #include "caffe2/core/context.h" | 
|  | #include "caffe2/core/logging.h" | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "caffe2/utils/math.h" | 
|  | #include "c10/util/irange.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | // A hacky version of Mul with broadcast | 
|  | // RowMul([mat, w], [output]) | 
|  | template <typename T, class Context> | 
|  | class RowMulOp : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | USE_SIMPLE_CTOR_DTOR(RowMulOp); | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | auto& mat = Input(0); | 
|  | auto& w = Input(1); | 
|  |  | 
|  | auto* output = Output(0, mat.sizes(), at::dtype<T>()); | 
|  | T* output_data = output->template mutable_data<T>(); | 
|  | const T* mat_data = mat.template data<T>(); | 
|  | const T* w_data = w.template data<T>(); | 
|  |  | 
|  | // Dimension checking | 
|  | CAFFE_ENFORCE_EQ( | 
|  | w.numel(), | 
|  | mat.dim32(0), | 
|  | "Length of w should be equal to the first dim of mat"); | 
|  |  | 
|  | auto block_size = mat.size_from_dim(1); | 
|  | for (const auto i : c10::irange(w.numel())) { | 
|  | size_t offset = i * block_size; | 
|  | for (const auto j : c10::irange(block_size)) { | 
|  | output_data[offset + j] = mat_data[offset + j] * w_data[i]; | 
|  | } | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  | }; | 
|  |  | 
|  | // A hacky version | 
|  | template <typename T, class Context> | 
|  | class ReduceTailSumOp : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | USE_SIMPLE_CTOR_DTOR(ReduceTailSumOp); | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | auto& mat = Input(0); | 
|  |  | 
|  | int N = mat.dim32(0); | 
|  | int block_size = mat.size_from_dim(1); | 
|  |  | 
|  | auto* output = Output(0, {N}, at::dtype<T>()); | 
|  | T* output_data = output->template mutable_data<T>(); | 
|  | const T* mat_data = mat.template data<T>(); | 
|  |  | 
|  | for (const auto i : c10::irange(N)) { | 
|  | output_data[i] = 0; | 
|  | size_t offset = i * block_size; | 
|  | for (const auto j : c10::irange(block_size)) { | 
|  | output_data[i] += mat_data[offset + j]; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | } | 
|  | }; | 
|  |  | 
|  | } // namespace caffe2 | 
|  |  | 
|  | #endif // CAFFE2_OPERATORS_ROW_MUL_H_ |