Implement the Pow operator (merging the existing pow with a scalar exponent and the new pow with a tensor exponent into a single operator).
Summary: The old pow operator has been deleted from math_ops.cc, math_ops.cu and math_ops.h, while the new operator supporting both scalar and tensor exponents has been added in pow_op.cc, pow_op.h and elementwise_op.cu.
Reviewed By: houseroad
Differential Revision: D6893040
fbshipit-source-id: 30f614beea6f859fee25ce4f85573142885dde45
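
A minimal usage sketch of the merged operator (Python, mirroring the tests added
in this diff; blob names and the sample data are illustrative):

    import numpy as np
    from caffe2.python import core, workspace

    X = np.random.rand(2, 3).astype(np.float32)
    Y = np.random.rand(2, 3).astype(np.float32) + 2.0
    workspace.FeedBlob("X", X)
    workspace.FeedBlob("Y", Y)

    # Scalar exponent: one input plus the "exponent" argument (previous behaviour).
    workspace.RunOperatorOnce(
        core.CreateOperator("Pow", ["X"], ["out_scalar"], exponent=2.0))

    # Tensor exponent: two inputs, X ** Y computed elementwise (new behaviour;
    # broadcast=1/axis are also accepted, as exercised in the broadcast tests below).
    workspace.RunOperatorOnce(
        core.CreateOperator("Pow", ["X", "Y"], ["out_tensor"]))

    np.testing.assert_allclose(workspace.FetchBlob("out_scalar"), X ** 2.0, rtol=1e-5)
    np.testing.assert_allclose(workspace.FetchBlob("out_tensor"), np.power(X, Y), rtol=1e-5)
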
diff --git a/caffe2/operators/elementwise_op.cu b/caffe2/operators/elementwise_op.cu
index adbc0e4..f8fb042 100644
--- a/caffe2/operators/elementwise_op.cu
+++ b/caffe2/operators/elementwise_op.cu
@@ -113,6 +113,16 @@
CUDA_FUNCTOR(Xor, CUDA_XOR, BoolTypes, FixedType<bool>);
#undef CUDA_XOR
+// pow, log and other math functions are provided by the CUDA math
+// library (header math.h)
+#define CUDA_POW(x, y) (pow(x, y))
+CUDA_FUNCTOR(
+ Pow,
+ CUDA_POW,
+ TensorTypes<float> /*NumericTypes*/,
+ SameTypeAsInput);
+#undef CUDA_POW
+
__global__ void NotKernel(const int n, const bool* x, bool* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = !x[i];
diff --git a/caffe2/operators/math_ops.cc b/caffe2/operators/math_ops.cc
index f84852c..36b2f86 100644
--- a/caffe2/operators/math_ops.cc
+++ b/caffe2/operators/math_ops.cc
@@ -82,65 +82,4 @@
.IdenticalTypeAndShape();
SHOULD_NOT_DO_GRADIENT(Sign);
-REGISTER_CPU_OPERATOR(
- Pow,
- UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowFunctor>);
-
-OPERATOR_SCHEMA(Pow)
- .NumInputs(1)
- .NumOutputs(1)
- .Arg("exponent", "The exponent of the power function.")
- .AllowInplace({{0, 0}})
- .IdenticalTypeAndShape()
- .SetDoc(R"DOC(
-Pow takes input data (Tensor<T>) and an argument exponent, and
-produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
-is applied to the data tensor elementwise.
-)DOC")
- .Input(0, "X", "Input tensor of any shape")
- .Output(0, "Y", "Output tensor (same size as X)");
-
-class GetPowGradient : public GradientMakerBase {
- using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
- ArgumentHelper arg_helper(def_);
- float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
- Argument scale_arg;
- scale_arg.set_name("scale");
- scale_arg.set_f(exponent);
- Argument pow_arg;
- pow_arg.set_name("exponent");
- if (I(0) != O(0)) {
- pow_arg.set_f(exponent - 1);
- } else {
- LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
- constexpr float kEps = 1e-12f;
- CAFFE_ENFORCE(std::fabs(exponent) > kEps);
- pow_arg.set_f((exponent - 1) / exponent);
- }
- return vector<OperatorDef>{CreateOperatorDef(
- "Pow",
- "",
- std::vector<string>{I(0)},
- std::vector<string>{GI(0)},
- std::vector<Argument>{pow_arg}),
- CreateOperatorDef(
- "Mul",
- "",
- std::vector<string>{GI(0), GO(0)},
- std::vector<string>{GI(0)}),
- CreateOperatorDef(
- "Scale",
- "",
- std::vector<string>{GI(0)},
- std::vector<string>{GI(0)},
- std::vector<Argument>{scale_arg})};
- }
- virtual bool CopyArguments() const override {
- return false;
- }
-};
-
-REGISTER_GRADIENT(Pow, GetPowGradient);
-
} // namespace caffe2
diff --git a/caffe2/operators/math_ops.cu b/caffe2/operators/math_ops.cu
index 377e941..f98a3bf 100644
--- a/caffe2/operators/math_ops.cu
+++ b/caffe2/operators/math_ops.cu
@@ -52,7 +52,4 @@
REGISTER_CUDA_OPERATOR(
Sign,
UnaryElementwiseOp<TensorTypes<float>, CUDAContext, SignCUDAFunctor>);
-REGISTER_CUDA_OPERATOR(
- Pow,
- UnaryElementwiseWithArgsOp<TensorTypes<float>, CUDAContext, PowFunctor>);
}
diff --git a/caffe2/operators/math_ops.h b/caffe2/operators/math_ops.h
index e06af24..75a3d7b 100644
--- a/caffe2/operators/math_ops.h
+++ b/caffe2/operators/math_ops.h
@@ -25,21 +25,4 @@
#include "caffe2/operators/elementwise_op.h"
#include "caffe2/utils/math.h"
-namespace caffe2 {
-
-struct PowFunctor {
- explicit PowFunctor(OperatorBase& op) {
- exponent_ = op.GetSingleArgument<float>("exponent", 0);
- }
-
- template <typename T, class Context>
- inline void
- operator()(const int n, const T* x, T* y, Context* device_context) {
- math::Powx<float, Context>(n, x, exponent_, y, device_context);
- }
-
- float exponent_;
-};
-}
-
#endif
diff --git a/caffe2/operators/pow_op.cc b/caffe2/operators/pow_op.cc
new file mode 100644
index 0000000..e7a3dd9
--- /dev/null
+++ b/caffe2/operators/pow_op.cc
@@ -0,0 +1,323 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/pow_op.h"
+#include "caffe2/utils/math.h"
+// NumericTypes and SameTypeAsInput are defined in elementwise_op.h,
+// which is already included through pow_op.h.
+//#include "caffe2/operators/elementwise_op.h"
+#include <Eigen/Core>
+
+namespace caffe2 {
+
+#define EIGEN_POW(x, y) (x.pow(y))
+
+struct EigenPowFunctor {
+ template <int b_is_scalar, typename T1, typename T2, typename R>
+ inline void Run(size_t n, const T1* a, const T2* b, R* out, CPUContext*) {
+ if (b_is_scalar) {
+ EigenVectorArrayMap<R>(out, n) =
+ EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (b[0]));
+ } else {
+ EigenVectorArrayMap<R>(out, n) = EIGEN_POW(
+ (ConstEigenVectorArrayMap<T1>(a, n)),
+ (ConstEigenVectorArrayMap<T2>(b, n)));
+ }
+ }
+ template <typename T1, typename T2, typename R>
+ void RunWithBroadcast(
+ const T1* a,
+ const T2* b,
+ R* out,
+ size_t pre,
+ size_t n,
+ CPUContext*) {
+ EigenArrayMap<R>(out, n, pre) = EIGEN_POW(
+ (ConstEigenArrayMap<T1>(a, n, pre)),
+ (ConstEigenVectorArrayMap<T2>(b, n)).rowwise().replicate(pre));
+ /*
+    // The code below only allows elementary ops such as +, -, * and /,
+    // and does not allow functions such as pow, exp and log:
+ EIGEN_POW(
+ (ConstEigenArrayMap<T>(a, n, pre).colwise()),
+ (ConstEigenVectorArrayMap<T>(b, n)));
+ */
+ }
+ template <typename T1, typename T2, typename R>
+ void RunWithBroadcast2(
+ const T1* a,
+ const T2* b,
+ R* out,
+ size_t pre,
+ size_t n,
+ size_t post,
+ CPUContext*) {
+ for (int i = 0; i < pre; ++i) {
+ EigenArrayMap<R>(out + i * n * post, post, n) = EIGEN_POW(
+ (ConstEigenArrayMap<T1>(a + i * n * post, post, n)),
+ (Eigen::Map<const Eigen::Array<T2, 1, Eigen::Dynamic>>(b, n))
+ .colwise()
+ .replicate(post));
+ /*
+      // The code below only allows elementary ops such as +, -, * and /,
+      // and does not allow functions such as pow, exp and log:
+      EIGEN_POW(
+ (ConstEigenArrayMap<T>(a + i * n * post, post, n).rowwise()),
+ (Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>>(b, n)));
+ */
+ }
+ }
+};
+
+REGISTER_CPU_OPERATOR(
+ Pow,
+ PowOp<
+ TensorTypes<float>, /*NumericTypes,*/
+ CPUContext,
+ EigenPowFunctor,
+ SameTypeAsInput>)
+
+OPERATOR_SCHEMA(Pow)
+ .NumInputs(1, 2)
+ .NumOutputs(1)
+ .Arg("exponent", "The exponent of the power function.")
+ .AllowInplace({{0, 0}, {1, 0}})
+ .SetDoc(R"DOC(
+Pow takes input data (Tensor<T>) and an exponent, given either as a scalar
+argument or as a second input tensor. It produces one output (Tensor<T>) in
+which the function `f(x) = x^exponent` is applied to the data elementwise.
+)DOC")
+ .Input(0, "X", "Input tensor of any shape")
+ .Input(1, "exponent", "The exponent of the power function.")
+ .Output(0, "Y", "Output tensor (same size as X)");
+
+class GetPowGradient : public GradientMakerBase {
+ using GradientMakerBase::GradientMakerBase;
+ vector<OperatorDef> GetGradientDefs() override {
+ ArgumentHelper arg_helper(def_);
+ if (arg_helper.HasArgument("exponent")) { // second input is a scalar
+ // function f(w,a) = w^a
+ // gradient operator with respect to first input tensor
+ // df/dw = a * w^(a-1) (all operations are component-wise)
+ float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
+ Argument scale_arg;
+ scale_arg.set_name("scale");
+ scale_arg.set_f(exponent);
+ Argument pow_arg;
+ pow_arg.set_name("exponent");
+ if (I(0) != O(0)) {
+ pow_arg.set_f(exponent - 1);
+ } else {
+ LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
+ constexpr float kEps = 1e-12f;
+ CAFFE_ENFORCE(std::fabs(exponent) > kEps);
+ pow_arg.set_f((exponent - 1) / exponent);
+ }
+ return vector<OperatorDef>{CreateOperatorDef(
+ "Pow",
+ "",
+ std::vector<string>{I(0)},
+ std::vector<string>{GI(0)},
+ std::vector<Argument>{pow_arg}),
+ CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), GO(0)},
+ std::vector<string>{GI(0)}),
+ CreateOperatorDef(
+ "Scale",
+ "",
+ std::vector<string>{GI(0)},
+ std::vector<string>{GI(0)},
+ std::vector<Argument>{scale_arg})};
+ /*
+ // Alternative gradient computation
+ return vector<OperatorDef>{CreateOperatorDef(
+ "Div",
+ "",
+ std::vector<string>{O(0), I(0)},
+ std::vector<string>{GI(0)}),
+ CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), GO(0)},
+ std::vector<string>{GI(0)}),
+ CreateOperatorDef(
+ "Scale",
+ "",
+ std::vector<string>{GI(0)},
+ std::vector<string>{GI(0)},
+ std::vector<Argument>{scale_arg})};
+ */
+ } else { // second input is a tensor
+ CAFFE_ENFORCE(
+ Def().input(0) != Def().output(0) &&
+ Def().input(1) != Def().output(0),
+ "Gradient computation cannot be carried out if Pow uses in-place "
+ "computation: ",
+ ProtoDebugString(Def()));
+ vector<OperatorDef> grad_ops;
+ Argument one_arg;
+ one_arg.set_name("value");
+ one_arg.set_f(1);
+ Argument broadcast, axis, axis_str, order;
+ bool bflag = ArgumentHelper::HasArgument(Def(), "broadcast");
+
+ if (bflag) {
+ if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
+ broadcast = GetArgument(Def(), "broadcast");
+ } else {
+ broadcast = MakeArgument<int>("broadcast", 0);
+ }
+ if (ArgumentHelper::HasArgument(Def(), "axis")) {
+ axis = GetArgument(Def(), "axis");
+ } else {
+ axis = MakeArgument<int>("axis", -1);
+ }
+ if (ArgumentHelper::HasArgument(Def(), "axis_str")) {
+ axis_str = GetArgument(Def(), "axis_str");
+ } else {
+ axis_str = MakeArgument<string>("axis_str", "");
+ }
+ if (ArgumentHelper::HasArgument(Def(), "order")) {
+ order = GetArgument(Def(), "order");
+ } else {
+ order = MakeArgument<string>("order", "NCHW");
+ }
+ }
+
+ // function f(w,a) = w^a
+ // gradient operator with respect to first input tensor
+ // df/dw = a * w^(a-1) (all operations are component-wise)
+ grad_ops.push_back(CreateOperatorDef(
+ "ConstantFill",
+ "",
+ std::vector<string>{I(1)},
+ std::vector<string>{GI(1)},
+ std::vector<Argument>{one_arg}));
+ grad_ops.push_back(CreateOperatorDef(
+ "Sub",
+ "",
+ std::vector<string>{I(1), GI(1)},
+ std::vector<string>{GI(1)}));
+ if (bflag) {
+ grad_ops.push_back(CreateOperatorDef(
+ "Pow",
+ "",
+ std::vector<string>{I(0), GI(1)},
+ std::vector<string>{GI(0)},
+ vector<Argument>{broadcast, axis, axis_str, order}));
+ } else {
+ grad_ops.push_back(CreateOperatorDef(
+ "Pow",
+ "",
+ std::vector<string>{I(0), GI(1)},
+ std::vector<string>{GI(0)}));
+ }
+
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), GO(0)},
+ std::vector<string>{GI(0)}));
+ if (bflag) {
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), I(1)},
+ std::vector<string>{GI(0)},
+ vector<Argument>{broadcast, axis, axis_str, order}));
+ } else {
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), I(1)},
+ std::vector<string>{GI(0)}));
+ }
+ /*
+ // Alternative gradient computation (no broadcast support)
+ grad_ops.push_back(CreateOperatorDef(
+ "Div",
+ "",
+ std::vector<string>{O(0), I(0)},
+ std::vector<string>{GI(0)}));
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), GO(0)},
+ std::vector<string>{GI(0)}));
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(0), I(1)},
+ std::vector<string>{GI(0)}));
+ */
+      // gradient operator with respect to the second input tensor
+ // df/da = w^a * ln w (all operations are component-wise)
+ /*
+ // reset GI(1) to zero
+ Argument zero_arg;
+ zero_arg.set_name("value");
+ zero_arg.set_f(0);
+ grad_ops.push_back(CreateOperatorDef(
+ "ConstantFill",
+ "",
+ std::vector<string>{I(1)},
+ std::vector<string>{GI(1)},
+ std::vector<Argument>{zero_arg}));
+ */
+ grad_ops.push_back(CreateOperatorDef(
+ "Log",
+ "",
+ std::vector<string>{I(0)},
+ std::vector<string>{GI(1) + "_autogen_pre_red"}));
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(1) + "_autogen_pre_red", O(0)},
+ std::vector<string>{GI(1) + "_autogen_pre_red"}));
+ if (bflag) {
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
+ std::vector<string>{GI(1) + "_autogen_pre_red"}));
+ grad_ops.push_back(CreateOperatorDef(
+ "SumReduceLike",
+ "",
+ vector<string>{GI(1) + "_autogen_pre_red", I(1)},
+ vector<string>{GI(1)},
+ vector<Argument>{axis, axis_str, order}));
+ } else {
+ grad_ops.push_back(CreateOperatorDef(
+ "Mul",
+ "",
+ std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
+ std::vector<string>{GI(1)}));
+ }
+
+ return grad_ops;
+ }
+ }
+
+  // Arguments are constructed explicitly above and must not be copied over.
+ bool CopyArguments() const override {
+ return false;
+ }
+};
+
+REGISTER_GRADIENT(Pow, GetPowGradient);
+
+} // namespace caffe2
diff --git a/caffe2/operators/pow_op.h b/caffe2/operators/pow_op.h
new file mode 100644
index 0000000..579210a
--- /dev/null
+++ b/caffe2/operators/pow_op.h
@@ -0,0 +1,149 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_POW_OP_H_
+#define CAFFE2_OPERATORS_POW_OP_H_
+
+#include "caffe2/core/common_omp.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/utils/math.h"
+// definition of NumericTypes and SameTypeAsInput is in below header file
+#include "caffe2/operators/elementwise_op.h"
+
+namespace caffe2 {
+
+template <
+ typename InputTypes,
+ class Context,
+ class Functor,
+ class TypeMap = SameTypeAsInput>
+class PowOp : public Operator<Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+ PowOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<Context>(operator_def, ws),
+ OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
+ OP_SINGLE_ARG(int, "axis", axis_, -1),
+ OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
+ OP_SINGLE_ARG(string, "order", order_, "NCHW"),
+ functor_() {
+ if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
+ exponent_ = this->template GetSingleArgument<float>(
+ "exponent", 0); // based on pow_ops.h
+ } else if (InputSize() == 2) { // BinaryElementwiseOp
+ // Figure out the correct axis to use.
+ if (enable_broadcast_) {
+ if (axis_ != -1) {
+ // Get axis from an explicit axis argument.
+ CAFFE_ENFORCE_EQ(
+ axis_str_.size(),
+ 0,
+ "Args axis and axis_str cannot be used simultaneously.");
+ } else if (axis_str_.size()) {
+ // Get the axis index semantically.
+ CAFFE_ENFORCE_EQ(
+ axis_str_.size(), 1, "Unsupported axis string", axis_str_);
+ size_t semantic_axis_ = order_.find(axis_str_);
+ CAFFE_ENFORCE_NE(
+ semantic_axis_,
+ string::npos,
+ "Unrecognizable axis string ",
+ axis_str_,
+ " from order string ",
+ order_);
+ axis_ = semantic_axis_;
+ }
+ } else {
+ CAFFE_ENFORCE(
+ axis_ == -1 && axis_str_.size() == 0,
+ "Do not specify axis or axis_str if broadcast is not enabled.");
+ }
+ } else {
+ CAFFE_THROW(
+ "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
+ }
+ }
+
+ bool RunOnDevice() override {
+ return DispatchHelper<InputTypes>::call(this, Input(0));
+ }
+
+ template <typename T>
+ bool DoRunWithType() {
+ if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
+ const auto& A = Input(0);
+ auto* C = Output(0);
+ C->ResizeLike(A);
+ const T* Adata = A.template data<T>();
+ auto* Cdata =
+ C->template mutable_data<typename TypeMap::template type<T>>();
+ functor_.template Run<true, T, float, T>(
+ A.size(), Adata, &exponent_, Cdata, &context_);
+ } else if (InputSize() == 2) { // BinaryElementwiseOp
+ const auto& A = Input(0);
+ const auto& B = Input(1);
+ auto* C = Output(0);
+ CAFFE_ENFORCE(
+ &B != C || !enable_broadcast_,
+ "In-place is allowed only with the first tensor when broadcasting");
+ C->ResizeLike(A);
+ const T* Adata = A.template data<T>();
+ const T* Bdata = B.template data<T>();
+ auto* Cdata =
+ C->template mutable_data<typename TypeMap::template type<T>>();
+ if (!enable_broadcast_) {
+ CAFFE_ENFORCE_EQ(
+ A.dims(),
+ B.dims(),
+ "Dimension mismatch - did you forget to set broadcast=1?");
+ functor_.template Run<false, T, T, T>(
+ A.size(), Adata, Bdata, Cdata, &context_);
+ } else if (B.size() == 1) {
+ functor_.template Run<true, T, T, T>(
+ A.size(), Adata, Bdata, Cdata, &context_);
+ } else {
+ size_t pre, n, post;
+ std::tie(pre, n, post) = calculate_broadcast_sizes(A, B, axis_);
+ if (post == 1) {
+ functor_.template RunWithBroadcast<T, T, T>(
+ Adata, Bdata, Cdata, pre, n, &context_);
+ } else {
+ functor_.template RunWithBroadcast2<T, T, T>(
+ Adata, Bdata, Cdata, pre, n, post, &context_);
+ }
+ }
+ } else {
+ CAFFE_THROW(
+ "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
+ }
+ return true;
+ }
+
+ private:
+ bool enable_broadcast_;
+ int axis_;
+ string axis_str_;
+ string order_;
+ float exponent_;
+ Functor functor_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_POW_OP_H_
diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py
index d1cfa80..7492455 100644
--- a/caffe2/python/hypothesis_test.py
+++ b/caffe2/python/hypothesis_test.py
@@ -1107,6 +1107,27 @@
reference=log_ref)
self.assertGradientChecks(gc, op, [input_tensor], 0, [0])
+ @given(input_tensors=hu.tensors(n=2, elements=st.floats(min_value=2.0, max_value=3.0, allow_nan=False, allow_infinity=False)),
+ **hu.gcs_cpu_only)
+ def test_powt(self, input_tensors, gc, dc):
+ X1, X2 = input_tensors
+
+ op = core.CreateOperator(
+ "Pow",
+ ["X1", "X2"],
+ ["output"]
+ )
+
+ def powt_ref(X1, X2):
+ return (np.power(X1,X2),)
+
+ self.assertReferenceChecks(
+ device_option=gc,
+ op=op,
+ inputs=[X1, X2],
+ reference=powt_ref)
+ self.assertGradientChecks(gc, op, [X1, X2], 0, [0])
+
def test_blobs_dequeue_timeout(self):
op = core.CreateOperator(
"CreateBlobsQueue",
diff --git a/caffe2/python/operator_test/elementwise_op_broadcast_test.py b/caffe2/python/operator_test/elementwise_op_broadcast_test.py
index e4e4411..9265453 100644
--- a/caffe2/python/operator_test/elementwise_op_broadcast_test.py
+++ b/caffe2/python/operator_test/elementwise_op_broadcast_test.py
@@ -184,6 +184,58 @@
self.assertGradientChecks(gc, op, [X, Y], 1, [0])
@given(**hu.gcs)
+ def test_broadcast_powt(self, gc, dc):
+ # Set broadcast and no axis, i.e. broadcasting last dimensions.
+ X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+ Y = np.random.rand(4, 5).astype(np.float32) + 2.0
+
+ op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1)
+ workspace.FeedBlob("X", X)
+ workspace.FeedBlob("Y", Y)
+ workspace.RunOperatorOnce(op)
+ out = workspace.FetchBlob("out")
+ np.testing.assert_array_almost_equal(out, np.power(X, Y))
+ self.assertDeviceChecks(dc, op, [X, Y], [0])
+ self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+ # broadcasting intermediate dimensions
+ X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+ Y = np.random.rand(3, 4).astype(np.float32) + 2.0
+ op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=1)
+ workspace.FeedBlob("X", X)
+ workspace.FeedBlob("Y", Y)
+ workspace.RunOperatorOnce(op)
+ out = workspace.FetchBlob("out")
+ np.testing.assert_array_almost_equal(out, np.power(X, Y[:, :, np.newaxis]))
+ self.assertDeviceChecks(dc, op, [X, Y], [0])
+ self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+ # broadcasting the first dimension
+ X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+ Y = np.random.rand(2).astype(np.float32) + 2.0
+ op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=0)
+ workspace.FeedBlob("X", X)
+ workspace.FeedBlob("Y", Y)
+ workspace.RunOperatorOnce(op)
+ out = workspace.FetchBlob("out")
+ np.testing.assert_array_almost_equal(
+ out, np.power(X, Y[:, np.newaxis, np.newaxis, np.newaxis]))
+ self.assertDeviceChecks(dc, op, [X, Y], [0])
+ self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+ # broadcasting with single elem dimensions at both ends
+ X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+ Y = np.random.rand(1, 4, 1).astype(np.float32) + 2.0
+ op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=1)
+ workspace.FeedBlob("X", X)
+ workspace.FeedBlob("Y", Y)
+ workspace.RunOperatorOnce(op)
+ out = workspace.FetchBlob("out")
+ np.testing.assert_array_almost_equal(out, np.power(X, Y))
+ self.assertDeviceChecks(dc, op, [X, Y], [0])
+ self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+ @given(**hu.gcs)
def test_broadcast_scalar(self, gc, dc):
# broadcasting constant
X = np.random.rand(2, 3, 4, 5).astype(np.float32)
diff --git a/caffe2/python/operator_test/elementwise_ops_test.py b/caffe2/python/operator_test/elementwise_ops_test.py
index 3e02594..10fb7cc 100644
--- a/caffe2/python/operator_test/elementwise_ops_test.py
+++ b/caffe2/python/operator_test/elementwise_ops_test.py
@@ -75,6 +75,31 @@
self.assertGradientChecks(
gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2)
+ @given(n=st.integers(2, 10), m=st.integers(4, 6),
+ d=st.integers(2, 3), **hu.gcs)
+ def test_powt(self, n, m, d, gc, dc):
+ X = np.random.rand(n, m, d).astype(np.float32)
+ Y = np.random.rand(n, m, d).astype(np.float32) + 2.0
+
+ def powt_op(X, Y):
+ return [np.power(X, Y)]
+
+ op = core.CreateOperator(
+ "Pow",
+ ["X", "Y"],
+ ["Z"]
+ )
+
+ self.assertReferenceChecks(
+ device_option=gc,
+ op=op,
+ inputs=[X, Y],
+ reference=powt_op,
+ )
+
+ self.assertGradientChecks(
+ gc, op, [X, Y], 0, [0], stepsize=1e-4, threshold=1e-2)
+
@given(n=st.integers(5, 6), m=st.integers(4, 6), **hu.gcs)
def test_sqr(self, n, m, gc, dc):
X = np.random.rand(n, m).astype(np.float32)