CUDA version of elementwise power + rename to Pow + gradient
Summary: Renamed ElementwisePower to Pow for better discoverability. Added CUDA version and Gradient + tests.
Reviewed By: kennyhorror
Differential Revision: D4665550
fbshipit-source-id: dd33d8ad3917d71504e363ab397af50d38a63b1f
diff --git a/caffe2/operators/elementwise_power_op.cc b/caffe2/operators/elementwise_power_op.cc
deleted file mode 100644
index fbdf699..0000000
--- a/caffe2/operators/elementwise_power_op.cc
+++ /dev/null
@@ -1,25 +0,0 @@
-#include "caffe2/operators/elementwise_power_op.h"
-
-namespace caffe2 {
-namespace {
-
-REGISTER_CPU_OPERATOR(
- ElementwisePower,
- UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowCPUFunctor>);
-
-OPERATOR_SCHEMA(ElementwisePower)
- .NumInputs(1)
- .NumOutputs(1)
- .Arg("exponent", "The exponent of the power function.")
- .AllowInplace({{0, 0}})
- .IdenticalTypeAndShape()
- .SetDoc(R"DOC(
-ElementwisePower takes input data (Tensor<T>) and an argument exponent, and
-produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
-is applied to the data tensor elementwise.
-)DOC")
- .Input(0, "X", "1D input tensor")
- .Output(0, "Y", "1D input tensor");
-
-} // namespace
-} // namespace caffe2
diff --git a/caffe2/operators/elementwise_power_op.h b/caffe2/operators/elementwise_power_op.h
deleted file mode 100644
index 09f906e..0000000
--- a/caffe2/operators/elementwise_power_op.h
+++ /dev/null
@@ -1,22 +0,0 @@
-#pragma once
-
-#include "caffe2/core/operator.h"
-#include "caffe2/operators/elementwise_op.h"
-
-namespace caffe2 {
-
-struct PowCPUFunctor {
- explicit PowCPUFunctor(OperatorBase& op) {
- exponent_ = op.GetSingleArgument<float>("exponent", 0);
- }
-
- template <typename T>
- inline void
- operator()(const int n, const T* x, T* y, CPUContext* device_context) {
- math::Powx<T, CPUContext>(n, x, exponent_, y, device_context);
- }
-
- float exponent_;
-};
-
-} // namespace caffe2
diff --git a/caffe2/operators/math_ops.cc b/caffe2/operators/math_ops.cc
index 6277e4d..9f9ee74 100644
--- a/caffe2/operators/math_ops.cc
+++ b/caffe2/operators/math_ops.cc
@@ -1,4 +1,4 @@
-#include "caffe2/operators/elementwise_op.h"
+#include "caffe2/operators/math_ops.h"
#include "caffe2/utils/math.h"
@@ -87,5 +87,59 @@
};
REGISTER_GRADIENT(Sqr, GetSqrGradient);
+REGISTER_CPU_OPERATOR(
+ Pow,
+ UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowFunctor>);
+
+OPERATOR_SCHEMA(Pow)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("exponent", "The exponent of the power function.")
+ .AllowInplace({{0, 0}})
+ .IdenticalTypeAndShape()
+ .SetDoc(R"DOC(
+Pow takes input data (Tensor<T>) and an argument exponent, and
+produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
+is applied to the data tensor elementwise.
+)DOC")
+ .Input(0, "X", "Input tensor of any shape")
+ .Output(0, "Y", "Output tensor (same size as X)");
+
+class GetPowGradient : public GradientMakerBase {
+ using GradientMakerBase::GradientMakerBase;
+ vector<OperatorDef> GetGradientDefs() override {
+ ArgumentHelper arg_helper(def_);
+ float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
+ Argument scale_arg;
+ scale_arg.set_name("scale");
+ scale_arg.set_f(exponent);
+ Argument pow_arg;
+ pow_arg.set_name("exponent");
+ pow_arg.set_f(exponent - 1);
+ return vector<OperatorDef>{CreateOperatorDef(
+ "Pow",
+ "",
+ std::vector<string>{I(0)},
+ std::vector<string>{GI(0)},
+ std::vector<Argument>{pow_arg}),
+ CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), GO(0)},
+ std::vector<string>{GI(0)}),
+ CreateOperatorDef(
+ "Scale",
+ "",
+ std::vector<string>{GI(0)},
+ std::vector<string>{GI(0)},
+ std::vector<Argument>{scale_arg})};
+ }
+ virtual bool CopyArguments() const override {
+ return false;
+ }
+};
+
+REGISTER_GRADIENT(Pow, GetPowGradient);
+
} // namespace
} // namespace caffe2
diff --git a/caffe2/operators/math_ops.cu b/caffe2/operators/math_ops.cu
index c3ba59f..50ae3de 100644
--- a/caffe2/operators/math_ops.cu
+++ b/caffe2/operators/math_ops.cu
@@ -1,7 +1,5 @@
#include "caffe2/core/context_gpu.h"
-#include "caffe2/core/operator.h"
-#include "caffe2/operators/elementwise_op.h"
-#include "caffe2/utils/math.h"
+#include "caffe2/operators/math_ops.h"
namespace caffe2 {
@@ -30,4 +28,7 @@
Sqr,
UnaryElementwiseOp<TensorTypes<float>, CUDAContext, SqrCUDAFunctor>);
}
+REGISTER_CUDA_OPERATOR(
+ Pow,
+ UnaryElementwiseWithArgsOp<TensorTypes<float>, CUDAContext, PowFunctor>);
}
diff --git a/caffe2/operators/math_ops.h b/caffe2/operators/math_ops.h
new file mode 100644
index 0000000..bc7386e
--- /dev/null
+++ b/caffe2/operators/math_ops.h
@@ -0,0 +1,29 @@
+#ifndef CAFFE2_OPERATORS_MATH_OP_H_
+#define CAFFE2_OPERATORS_MATH_OP_H_
+
+#include "caffe2/core/common_omp.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/tensor.h"
+#include "caffe2/operators/elementwise_op.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+struct PowFunctor {
+ explicit PowFunctor(OperatorBase& op) {
+ exponent_ = op.GetSingleArgument<float>("exponent", 0);
+ }
+
+ template <typename T, class Context>
+ inline void
+ operator()(const int n, const T* x, T* y, Context* device_context) {
+ math::Powx<float, Context>(n, x, exponent_, y, device_context);
+ }
+
+ float exponent_;
+};
+}
+
+#endif
diff --git a/caffe2/python/hypothesis_test_util.py b/caffe2/python/hypothesis_test_util.py
index cd31b80..0b03d5a 100644
--- a/caffe2/python/hypothesis_test_util.py
+++ b/caffe2/python/hypothesis_test_util.py
@@ -526,9 +526,10 @@
output_blob_name, shapes, types, output)
outs.append(output)
if grad_reference and output_to_grad:
- self._assertGradReferenceChecks(
- op, inputs, reference_outputs,
- output_to_grad, grad_reference)
+ with core.DeviceScope(device_option):
+ self._assertGradReferenceChecks(
+ op, inputs, reference_outputs,
+ output_to_grad, grad_reference)
return outs
def assertValidationChecks(
diff --git a/caffe2/python/operator_test/elementwise_power_op_test.py b/caffe2/python/operator_test/elementwise_power_op_test.py
deleted file mode 100644
index eea915d..0000000
--- a/caffe2/python/operator_test/elementwise_power_op_test.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-
-from caffe2.python import core
-from hypothesis import given
-from hypothesis import strategies as st
-import caffe2.python.hypothesis_test_util as hu
-
-import unittest
-
-
-class TestElementwisePowerOp(hu.HypothesisTestCase):
-
- @given(X=hu.tensor(),
- exponent=st.floats(min_value=-1.0, max_value=1.0),
- **hu.gcs_cpu_only)
- def test_elementwise_power(self, X, exponent, gc, dc):
- def elementwise_power(X):
- return (X ** exponent,)
-
- op = core.CreateOperator(
- "ElementwisePower", ["X"], ["Y"], exponent=exponent)
- self.assertDeviceChecks(dc, op, [X], [0])
- self.assertReferenceChecks(gc, op, [X], elementwise_power)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/caffe2/python/operator_test/pow_op_test.py b/caffe2/python/operator_test/pow_op_test.py
new file mode 100644
index 0000000..3c29f73
--- /dev/null
+++ b/caffe2/python/operator_test/pow_op_test.py
@@ -0,0 +1,35 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from caffe2.python import core
+from hypothesis import given
+from hypothesis import strategies as st
+import caffe2.python.hypothesis_test_util as hu
+
+import unittest
+
+
+class TestPowOp(hu.HypothesisTestCase):
+
+ @given(X=hu.tensor(),
+ exponent=st.floats(min_value=2.0, max_value=3.0),
+ **hu.gcs)
+ def test_elementwise_power(self, X, exponent, gc, dc):
+ def powf(X):
+ return (X ** exponent,)
+
+ def powf_grad(g_out, outputs, fwd_inputs):
+ return (exponent * (fwd_inputs[0] ** (exponent - 1)) * g_out,)
+
+ op = core.CreateOperator(
+ "Pow", ["X"], ["Y"], exponent=exponent)
+
+ self.assertReferenceChecks(gc, op, [X], powf,
+ output_to_grad="Y",
+ grad_reference=powf_grad),
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index bfc2705..bccd23d 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -369,11 +369,35 @@
y[i] = x[i] * (*alpha);
}
}
+
+template <typename T>
+__global__ void PowKernel(const int n, const T* x, const T exponent, T* y) {
+ CUDA_1D_KERNEL_LOOP(i, n) {
+ y[i] = powf(x[i], exponent);
+ }
+}
} // namespace
template <>
+void Powx<float, CUDAContext>(
+ const int N,
+ const float* a,
+ const float b,
+ float* y,
+ CUDAContext* context) {
+ PowKernel<<<
+ CAFFE_GET_BLOCKS(N),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(N, a, b, y);
+}
+
+template <>
void Scale<float, CUDAContext>(
- const int n, const float alpha, const float *x, float* y,
+ const int n,
+ const float alpha,
+ const float* x,
+ float* y,
CUDAContext* context) {
ScaleKernel<float><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
0, context->cuda_stream()>>>(n, alpha, x, y);