|  | #ifndef CAFFE2_OPERATORS_MEAN_OPS_H_ | 
|  | #define CAFFE2_OPERATORS_MEAN_OPS_H_ | 
|  |  | 
|  | #include "caffe2/core/common_omp.h" | 
|  | #include "caffe2/core/context.h" | 
|  | #include "caffe2/core/logging.h" | 
|  | #include "caffe2/core/operator.h" | 
|  | #include "caffe2/core/types.h" | 
|  | #include "caffe2/utils/math.h" | 
|  | #include "caffe2/utils/proto_utils.h" | 
|  | #include "c10/util/irange.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <class Context> | 
|  | class MeanOp final : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  | USE_SIMPLE_CTOR_DTOR(MeanOp) | 
|  |  | 
|  | template <typename T> | 
|  | bool DoRunWithType() { | 
|  | auto& input0 = Input(0); | 
|  |  | 
|  | auto* output = Output(0, input0.sizes(), at::dtype<T>()); | 
|  | output->CopyFrom(input0, true /*async*/); | 
|  |  | 
|  | if (InputSize() == 1) { | 
|  | return true; | 
|  | } | 
|  |  | 
|  | // Dimension checking | 
|  | for (const auto i : c10::irange(1, InputSize())) { | 
|  | if (output->sizes() != Input(i).sizes()) { | 
|  | CAFFE_THROW( | 
|  | "Check failed: output->sizes() == Input(i).sizes().", | 
|  | "Description: Input #", | 
|  | i, | 
|  | ", input dimension:", | 
|  | Input(i).sizes(), | 
|  | " should match output dimension: ", | 
|  | output->sizes()); | 
|  | } | 
|  | } | 
|  |  | 
|  | T* output_data = output->template mutable_data<T>(); | 
|  | for (const auto i : c10::irange(1, InputSize())) { | 
|  | math::Add( | 
|  | output->numel(), | 
|  | output_data, | 
|  | Input(i).template data<T>(), | 
|  | output_data, | 
|  | &context_); | 
|  | } | 
|  |  | 
|  | math::Scale( | 
|  | output->numel(), | 
|  | 1.0f / InputSize(), | 
|  | output_data, | 
|  | output_data, | 
|  | &context_); | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | if (Input(0).template IsType<float>()) { | 
|  | return DoRunWithType<float>(); | 
|  | } else if (Input(0).template IsType<double>()) { | 
|  | return DoRunWithType<double>(); | 
|  | } else { | 
|  | CAFFE_THROW( | 
|  | "Mean operator only supports 32-bit float or 64-bit double, but", | 
|  | " input was of type ", | 
|  | Input(0).dtype().name()); | 
|  | } | 
|  | } | 
|  | }; | 
|  |  | 
|  | template <class Context> | 
|  | class MeanGradientOp : public Operator<Context> { | 
|  | public: | 
|  | USE_OPERATOR_CONTEXT_FUNCTIONS; | 
|  |  | 
|  | template <class... Args> | 
|  | explicit MeanGradientOp(Args&&... args) | 
|  | : Operator<Context>(std::forward<Args>(args)...) {} | 
|  |  | 
|  | template <typename T> | 
|  | bool DoRunWithType() { | 
|  | auto& dY = Input(0); | 
|  | const auto* dY_data = dY.template data<T>(); | 
|  | int size = dY.numel(); | 
|  |  | 
|  | int num_inputs = OutputSize(); | 
|  | float scale = 1.0f / num_inputs; | 
|  |  | 
|  | // dX0 = scale * dY | 
|  |  | 
|  | auto* dX0 = Output(0, dY.sizes(), at::dtype<T>()); | 
|  | math::Scale( | 
|  | size, scale, dY_data, dX0->template mutable_data<T>(), &context_); | 
|  |  | 
|  | // Copy the rest dX | 
|  | for (const auto i : c10::irange(1, num_inputs)) { | 
|  | auto* cur_dX = Output(i); | 
|  | cur_dX->ResizeLike(dY); | 
|  | cur_dX->CopyFrom(*dX0, true /*async*/); | 
|  | } | 
|  |  | 
|  | return true; | 
|  | } | 
|  |  | 
|  | bool RunOnDevice() override { | 
|  | if (Input(0).template IsType<float>()) { | 
|  | return DoRunWithType<float>(); | 
|  | } else if (Input(0).template IsType<double>()) { | 
|  | return DoRunWithType<double>(); | 
|  | } else { | 
|  | CAFFE_THROW( | 
|  | "Mean operator only supports 32-bit float or 64-bit double, but", | 
|  | " input was of type ", | 
|  | Input(0).dtype().name()); | 
|  | } | 
|  | } | 
|  | }; | 
|  |  | 
|  | } // namespace caffe2 | 
|  |  | 
|  | #endif // CAFFE2_OPERATORS_MEAN_OPS_H_ |