CUDA version of elementwise power + rename to Pow + gradient

Summary: Renamed ElementwisePower to Pow for better discoverability. Added CUDA version and Gradient + tests.

Reviewed By: kennyhorror

Differential Revision: D4665550

fbshipit-source-id: dd33d8ad3917d71504e363ab397af50d38a63b1f
diff --git a/caffe2/operators/elementwise_power_op.cc b/caffe2/operators/elementwise_power_op.cc
deleted file mode 100644
index fbdf699..0000000
--- a/caffe2/operators/elementwise_power_op.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "caffe2/operators/elementwise_power_op.h"
-
-namespace caffe2 {
-namespace {
-
-REGISTER_CPU_OPERATOR(
-    ElementwisePower,
-    UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowCPUFunctor>);
-
-OPERATOR_SCHEMA(ElementwisePower)
-    .NumInputs(1)
-    .NumOutputs(1)
-    .Arg("exponent", "The exponent of the power function.")
-    .AllowInplace({{0, 0}})
-    .IdenticalTypeAndShape()
-    .SetDoc(R"DOC(
-ElementwisePower takes input data (Tensor<T>) and an argument exponent, and
-produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
-is applied to the data tensor elementwise.
-)DOC")
-    .Input(0, "X", "1D input tensor")
-    .Output(0, "Y", "1D input tensor");
-
-} // namespace
-} // namespace caffe2
diff --git a/caffe2/operators/elementwise_power_op.h b/caffe2/operators/elementwise_power_op.h
deleted file mode 100644
index 09f906e..0000000
--- a/caffe2/operators/elementwise_power_op.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-
-#include "caffe2/core/operator.h"
-#include "caffe2/operators/elementwise_op.h"
-
-namespace caffe2 {
-
-struct PowCPUFunctor {
-  explicit PowCPUFunctor(OperatorBase& op) {
-    exponent_ = op.GetSingleArgument<float>("exponent", 0);
-  }
-
-  template <typename T>
-  inline void
-  operator()(const int n, const T* x, T* y, CPUContext* device_context) {
-    math::Powx<T, CPUContext>(n, x, exponent_, y, device_context);
-  }
-
-  float exponent_;
-};
-
-} // namespace caffe2
diff --git a/caffe2/operators/math_ops.cc b/caffe2/operators/math_ops.cc
index 6277e4d..9f9ee74 100644
--- a/caffe2/operators/math_ops.cc
+++ b/caffe2/operators/math_ops.cc
@@ -1,4 +1,4 @@
-#include "caffe2/operators/elementwise_op.h"
+#include "caffe2/operators/math_ops.h"
 #include "caffe2/utils/math.h"
 
 
@@ -87,5 +87,59 @@
 };
 REGISTER_GRADIENT(Sqr, GetSqrGradient);
 
+REGISTER_CPU_OPERATOR(
+    Pow,
+    UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowFunctor>);
+
+OPERATOR_SCHEMA(Pow)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .Arg("exponent", "The exponent of the power function.")
+    .AllowInplace({{0, 0}})
+    .IdenticalTypeAndShape()
+    .SetDoc(R"DOC(
+Pow takes input data (Tensor<T>) and an argument exponent, and
+produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
+is applied to the data tensor elementwise.
+)DOC")
+    .Input(0, "X", "Input tensor of any shape")
+    .Output(0, "Y", "Output tensor (same size as X)");
+
+class GetPowGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  vector<OperatorDef> GetGradientDefs() override {
+    ArgumentHelper arg_helper(def_);
+    float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
+    Argument scale_arg;
+    scale_arg.set_name("scale");
+    scale_arg.set_f(exponent);
+    Argument pow_arg;
+    pow_arg.set_name("exponent");
+    pow_arg.set_f(exponent - 1);
+    return vector<OperatorDef>{CreateOperatorDef(
+                                   "Pow",
+                                   "",
+                                   std::vector<string>{I(0)},
+                                   std::vector<string>{GI(0)},
+                                   std::vector<Argument>{pow_arg}),
+                               CreateOperatorDef(
+                                   "Mul",
+                                   "",
+                                   std::vector<string>{GI(0), GO(0)},
+                                   std::vector<string>{GI(0)}),
+                               CreateOperatorDef(
+                                   "Scale",
+                                   "",
+                                   std::vector<string>{GI(0)},
+                                   std::vector<string>{GI(0)},
+                                   std::vector<Argument>{scale_arg})};
+  }
+  virtual bool CopyArguments() const override {
+    return false;
+  }
+};
+
+REGISTER_GRADIENT(Pow, GetPowGradient);
+
 } // namespace
 } // namespace caffe2
diff --git a/caffe2/operators/math_ops.cu b/caffe2/operators/math_ops.cu
index c3ba59f..50ae3de 100644
--- a/caffe2/operators/math_ops.cu
+++ b/caffe2/operators/math_ops.cu
@@ -1,7 +1,5 @@
 #include "caffe2/core/context_gpu.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/operators/elementwise_op.h"
-#include "caffe2/utils/math.h"
+#include "caffe2/operators/math_ops.h"
 
 namespace caffe2 {
 
@@ -30,4 +28,7 @@
     Sqr,
     UnaryElementwiseOp<TensorTypes<float>, CUDAContext, SqrCUDAFunctor>);
 }
+REGISTER_CUDA_OPERATOR(
+    Pow,
+    UnaryElementwiseWithArgsOp<TensorTypes<float>, CUDAContext, PowFunctor>);
 }
diff --git a/caffe2/operators/math_ops.h b/caffe2/operators/math_ops.h
new file mode 100644
index 0000000..bc7386e
--- /dev/null
+++ b/caffe2/operators/math_ops.h
@@ -0,0 +1,29 @@
+#ifndef CAFFE2_OPERATORS_MATH_OP_H_
+#define CAFFE2_OPERATORS_MATH_OP_H_
+
+#include "caffe2/core/common_omp.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/operators/elementwise_op.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+struct PowFunctor {
+  explicit PowFunctor(OperatorBase& op) {
+    exponent_ = op.GetSingleArgument<float>("exponent", 0);
+  }
+
+  template <typename T, class Context>
+  inline void
+  operator()(const int n, const T* x, T* y, Context* device_context) {
+    math::Powx<float, Context>(n, x, exponent_, y, device_context);
+  }
+
+  float exponent_;
+};
+}
+
+#endif
diff --git a/caffe2/python/hypothesis_test_util.py b/caffe2/python/hypothesis_test_util.py
index cd31b80..0b03d5a 100644
--- a/caffe2/python/hypothesis_test_util.py
+++ b/caffe2/python/hypothesis_test_util.py
@@ -526,9 +526,10 @@
                         output_blob_name, shapes, types, output)
                 outs.append(output)
             if grad_reference and output_to_grad:
-                self._assertGradReferenceChecks(
-                    op, inputs, reference_outputs,
-                    output_to_grad, grad_reference)
+                with core.DeviceScope(device_option):
+                    self._assertGradReferenceChecks(
+                        op, inputs, reference_outputs,
+                        output_to_grad, grad_reference)
             return outs
 
     def assertValidationChecks(
diff --git a/caffe2/python/operator_test/elementwise_power_op_test.py b/caffe2/python/operator_test/elementwise_power_op_test.py
deleted file mode 100644
index eea915d..0000000
--- a/caffe2/python/operator_test/elementwise_power_op_test.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-from caffe2.python import core
-from hypothesis import given
-from hypothesis import strategies as st
-import caffe2.python.hypothesis_test_util as hu
-
-import unittest
-
-
-class TestElementwisePowerOp(hu.HypothesisTestCase):
-
-    @given(X=hu.tensor(),
-           exponent=st.floats(min_value=-1.0, max_value=1.0),
-           **hu.gcs_cpu_only)
-    def test_elementwise_power(self, X, exponent, gc, dc):
-        def elementwise_power(X):
-            return (X ** exponent,)
-
-        op = core.CreateOperator(
-            "ElementwisePower", ["X"], ["Y"], exponent=exponent)
-        self.assertDeviceChecks(dc, op, [X], [0])
-        self.assertReferenceChecks(gc, op, [X], elementwise_power)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/caffe2/python/operator_test/pow_op_test.py b/caffe2/python/operator_test/pow_op_test.py
new file mode 100644
index 0000000..3c29f73
--- /dev/null
+++ b/caffe2/python/operator_test/pow_op_test.py
@@ -0,0 +1,35 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core
+from hypothesis import given
+from hypothesis import strategies as st
+import caffe2.python.hypothesis_test_util as hu
+
+import unittest
+
+
+class TestPowOp(hu.HypothesisTestCase):
+
+    @given(X=hu.tensor(),
+           exponent=st.floats(min_value=2.0, max_value=3.0),
+           **hu.gcs)
+    def test_elementwise_power(self, X, exponent, gc, dc):
+        def powf(X):
+            return (X ** exponent,)
+
+        def powf_grad(g_out, outputs, fwd_inputs):
+            return (exponent * (fwd_inputs[0] ** (exponent - 1)) * g_out,)
+
+        op = core.CreateOperator(
+            "Pow", ["X"], ["Y"], exponent=exponent)
+
+        self.assertReferenceChecks(gc, op, [X], powf,
+                                   output_to_grad="Y",
+                                   grad_reference=powf_grad),
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index bfc2705..bccd23d 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -369,11 +369,35 @@
     y[i] = x[i] * (*alpha);
   }
 }
+
+template <typename T>
+__global__ void PowKernel(const int n, const T* x, const T exponent, T* y) {
+  CUDA_1D_KERNEL_LOOP(i, n) {
+    y[i] = powf(x[i], exponent);
+  }
+}
 }  // namespace
 
 template <>
+void Powx<float, CUDAContext>(
+    const int N,
+    const float* a,
+    const float b,
+    float* y,
+    CUDAContext* context) {
+  PowKernel<<<
+      CAFFE_GET_BLOCKS(N),
+      CAFFE_CUDA_NUM_THREADS,
+      0,
+      context->cuda_stream()>>>(N, a, b, y);
+}
+
+template <>
 void Scale<float, CUDAContext>(
-    const int n, const float alpha, const float *x, float* y,
+    const int n,
+    const float alpha,
+    const float* x,
+    float* y,
     CUDAContext* context) {
   ScaleKernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
                        0, context->cuda_stream()>>>(n, alpha, x, y);