Implement the Pow operator (merging the existing pow with a scalar exponent and the new pow with a tensor exponent into a single operator).

Summary: The old pow operator has been removed from math_ops.cc, math_ops.cu and math_ops.h, and a new operator supporting both scalar and tensor exponents has been added in pow_op.cc, pow_op.h and elementwise_op.cu.

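A minimal usage sketch (illustrative only; blob names are arbitrary and the calls mirror the Python tests added below):

    import numpy as np
    from caffe2.python import core, workspace

    X = np.random.rand(2, 3).astype(np.float32)
    E = np.random.rand(2, 3).astype(np.float32) + 2.0
    workspace.FeedBlob("X", X)
    workspace.FeedBlob("E", E)

    # Scalar exponent passed as an argument (old behaviour, unchanged).
    workspace.RunOperatorOnce(
        core.CreateOperator("Pow", ["X"], ["Y"], exponent=2.0))

    # Exponent given as a second input tensor (new in this diff).
    workspace.RunOperatorOnce(
        core.CreateOperator("Pow", ["X", "E"], ["Z"]))
    np.testing.assert_array_almost_equal(
        workspace.FetchBlob("Z"), np.power(X, E))
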
Reviewed By: houseroad

Differential Revision: D6893040

fbshipit-source-id: 30f614beea6f859fee25ce4f85573142885dde45
diff --git a/caffe2/operators/elementwise_op.cu b/caffe2/operators/elementwise_op.cu
index adbc0e4..f8fb042 100644
--- a/caffe2/operators/elementwise_op.cu
+++ b/caffe2/operators/elementwise_op.cu
@@ -113,6 +113,16 @@
 CUDA_FUNCTOR(Xor, CUDA_XOR, BoolTypes, FixedType<bool>);
 #undef CUDA_XOR
 
+// pow, log and other math functions are provided by the CUDA math library
+// (declared in math.h).
+#define CUDA_POW(x, y) (pow(x, y))
+CUDA_FUNCTOR(
+    Pow,
+    CUDA_POW,
+    TensorTypes<float> /*NumericTypes*/,
+    SameTypeAsInput);
+#undef CUDA_POW
+
 __global__ void NotKernel(const int n, const bool* x, bool* y) {
   CUDA_1D_KERNEL_LOOP(i, n) {
     y[i] = !x[i];
diff --git a/caffe2/operators/math_ops.cc b/caffe2/operators/math_ops.cc
index f84852c..36b2f86 100644
--- a/caffe2/operators/math_ops.cc
+++ b/caffe2/operators/math_ops.cc
@@ -82,65 +82,4 @@
     .IdenticalTypeAndShape();
 SHOULD_NOT_DO_GRADIENT(Sign);
 
-REGISTER_CPU_OPERATOR(
-    Pow,
-    UnaryElementwiseWithArgsOp<TensorTypes<float>, CPUContext, PowFunctor>);
-
-OPERATOR_SCHEMA(Pow)
-    .NumInputs(1)
-    .NumOutputs(1)
-    .Arg("exponent", "The exponent of the power function.")
-    .AllowInplace({{0, 0}})
-    .IdenticalTypeAndShape()
-    .SetDoc(R"DOC(
-Pow takes input data (Tensor<T>) and an argument exponent, and
-produces one output data (Tensor<T>) where the function `f(x) = x^exponent`,
-is applied to the data tensor elementwise.
-)DOC")
-    .Input(0, "X", "Input tensor of any shape")
-    .Output(0, "Y", "Output tensor (same size as X)");
-
-class GetPowGradient : public GradientMakerBase {
-  using GradientMakerBase::GradientMakerBase;
-  vector<OperatorDef> GetGradientDefs() override {
-    ArgumentHelper arg_helper(def_);
-    float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
-    Argument scale_arg;
-    scale_arg.set_name("scale");
-    scale_arg.set_f(exponent);
-    Argument pow_arg;
-    pow_arg.set_name("exponent");
-    if (I(0) != O(0)) {
-      pow_arg.set_f(exponent - 1);
-    } else {
-      LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
-      constexpr float kEps = 1e-12f;
-      CAFFE_ENFORCE(std::fabs(exponent) > kEps);
-      pow_arg.set_f((exponent - 1) / exponent);
-    }
-    return vector<OperatorDef>{CreateOperatorDef(
-                                   "Pow",
-                                   "",
-                                   std::vector<string>{I(0)},
-                                   std::vector<string>{GI(0)},
-                                   std::vector<Argument>{pow_arg}),
-                               CreateOperatorDef(
-                                   "Mul",
-                                   "",
-                                   std::vector<string>{GI(0), GO(0)},
-                                   std::vector<string>{GI(0)}),
-                               CreateOperatorDef(
-                                   "Scale",
-                                   "",
-                                   std::vector<string>{GI(0)},
-                                   std::vector<string>{GI(0)},
-                                   std::vector<Argument>{scale_arg})};
-  }
-  virtual bool CopyArguments() const override {
-    return false;
-  }
-};
-
-REGISTER_GRADIENT(Pow, GetPowGradient);
-
 } // namespace caffe2
diff --git a/caffe2/operators/math_ops.cu b/caffe2/operators/math_ops.cu
index 377e941..f98a3bf 100644
--- a/caffe2/operators/math_ops.cu
+++ b/caffe2/operators/math_ops.cu
@@ -52,7 +52,4 @@
 REGISTER_CUDA_OPERATOR(
     Sign,
     UnaryElementwiseOp<TensorTypes<float>, CUDAContext, SignCUDAFunctor>);
-REGISTER_CUDA_OPERATOR(
-    Pow,
-    UnaryElementwiseWithArgsOp<TensorTypes<float>, CUDAContext, PowFunctor>);
 }
diff --git a/caffe2/operators/math_ops.h b/caffe2/operators/math_ops.h
index e06af24..75a3d7b 100644
--- a/caffe2/operators/math_ops.h
+++ b/caffe2/operators/math_ops.h
@@ -25,21 +25,4 @@
 #include "caffe2/operators/elementwise_op.h"
 #include "caffe2/utils/math.h"
 
-namespace caffe2 {
-
-struct PowFunctor {
-  explicit PowFunctor(OperatorBase& op) {
-    exponent_ = op.GetSingleArgument<float>("exponent", 0);
-  }
-
-  template <typename T, class Context>
-  inline void
-  operator()(const int n, const T* x, T* y, Context* device_context) {
-    math::Powx<float, Context>(n, x, exponent_, y, device_context);
-  }
-
-  float exponent_;
-};
-}
-
 #endif
diff --git a/caffe2/operators/pow_op.cc b/caffe2/operators/pow_op.cc
new file mode 100644
index 0000000..e7a3dd9
--- /dev/null
+++ b/caffe2/operators/pow_op.cc
@@ -0,0 +1,323 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/pow_op.h"
+#include "caffe2/utils/math.h"
+// NumericTypes and SameTypeAsInput are defined in
+// "caffe2/operators/elementwise_op.h", which is already included via pow_op.h.
+#include <Eigen/Core>
+
+namespace caffe2 {
+
+#define EIGEN_POW(x, y) (x.pow(y))
+
+struct EigenPowFunctor {
+  template <int b_is_scalar, typename T1, typename T2, typename R>
+  inline void Run(size_t n, const T1* a, const T2* b, R* out, CPUContext*) {
+    if (b_is_scalar) {
+      EigenVectorArrayMap<R>(out, n) =
+          EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (b[0]));
+    } else {
+      EigenVectorArrayMap<R>(out, n) = EIGEN_POW(
+          (ConstEigenVectorArrayMap<T1>(a, n)),
+          (ConstEigenVectorArrayMap<T2>(b, n)));
+    }
+  }
+  template <typename T1, typename T2, typename R>
+  void RunWithBroadcast(
+      const T1* a,
+      const T2* b,
+      R* out,
+      size_t pre,
+      size_t n,
+      CPUContext*) {
+    EigenArrayMap<R>(out, n, pre) = EIGEN_POW(
+        (ConstEigenArrayMap<T1>(a, n, pre)),
+        (ConstEigenVectorArrayMap<T2>(b, n)).rowwise().replicate(pre));
+    /*
+    // The colwise() form below only supports elementary ops (+, -, *, /)
+    // and does not support functions such as pow, exp and log:
+    EIGEN_POW(
+       (ConstEigenArrayMap<T>(a, n, pre).colwise()),
+       (ConstEigenVectorArrayMap<T>(b, n)));
+     */
+  }
+  template <typename T1, typename T2, typename R>
+  void RunWithBroadcast2(
+      const T1* a,
+      const T2* b,
+      R* out,
+      size_t pre,
+      size_t n,
+      size_t post,
+      CPUContext*) {
+    for (int i = 0; i < pre; ++i) {
+      EigenArrayMap<R>(out + i * n * post, post, n) = EIGEN_POW(
+          (ConstEigenArrayMap<T1>(a + i * n * post, post, n)),
+          (Eigen::Map<const Eigen::Array<T2, 1, Eigen::Dynamic>>(b, n))
+              .colwise()
+              .replicate(post));
+      /*
+      // The rowwise() form below only supports elementary ops (+, -, *, /)
+      // and does not support functions such as pow, exp and log:
+      EIGEN_POW(
+        (ConstEigenArrayMap<T>(a + i * n * post, post, n).rowwise()),
+        (Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>>(b, n)));
+      */
+    }
+  }
+};
+
+REGISTER_CPU_OPERATOR(
+    Pow,
+    PowOp<
+        TensorTypes<float>, /*NumericTypes,*/
+        CPUContext,
+        EigenPowFunctor,
+        SameTypeAsInput>)
+
+OPERATOR_SCHEMA(Pow)
+    .NumInputs(1, 2)
+    .NumOutputs(1)
+    .Arg("exponent", "The exponent of the power function.")
+    .AllowInplace({{0, 0}, {1, 0}})
+    .SetDoc(R"DOC(
+Pow takes input data (Tensor<T>) and an exponent, provided either as a scalar
+`exponent` argument or as a second input tensor. It produces one output
+(Tensor<T>), where the function `f(x) = x^exponent` is applied to the data
+tensor elementwise.
+)DOC")
+    .Input(0, "X", "Input tensor of any shape")
+    .Input(1, "exponent", "The exponent of the power function.")
+    .Output(0, "Y", "Output tensor (same size as X)");
+
+class GetPowGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  vector<OperatorDef> GetGradientDefs() override {
+    ArgumentHelper arg_helper(def_);
+    if (arg_helper.HasArgument("exponent")) { // second input is a scalar
+      // function f(w,a) = w^a
+      // gradient operator with respect to first input tensor
+      // df/dw = a * w^(a-1) (all operations are component-wise)
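+      // The three ops emitted below compute
+      //   GI(0) = exponent * I(0)^(exponent - 1) * GO(0)
+      // via Pow, Mul (by GO(0)) and Scale (by exponent); the Pow argument is
+      // adjusted when the forward op ran in-place.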
+      float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
+      Argument scale_arg;
+      scale_arg.set_name("scale");
+      scale_arg.set_f(exponent);
+      Argument pow_arg;
+      pow_arg.set_name("exponent");
+      if (I(0) != O(0)) {
+        pow_arg.set_f(exponent - 1);
+      } else {
+        LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
+        constexpr float kEps = 1e-12f;
+        CAFFE_ENFORCE(std::fabs(exponent) > kEps);
+        pow_arg.set_f((exponent - 1) / exponent);
+      }
+      return vector<OperatorDef>{CreateOperatorDef(
+                                     "Pow",
+                                     "",
+                                     std::vector<string>{I(0)},
+                                     std::vector<string>{GI(0)},
+                                     std::vector<Argument>{pow_arg}),
+                                 CreateOperatorDef(
+                                     "Mul",
+                                     "",
+                                     std::vector<string>{GI(0), GO(0)},
+                                     std::vector<string>{GI(0)}),
+                                 CreateOperatorDef(
+                                     "Scale",
+                                     "",
+                                     std::vector<string>{GI(0)},
+                                     std::vector<string>{GI(0)},
+                                     std::vector<Argument>{scale_arg})};
+      /*
+      // Alternative gradient computation
+      return vector<OperatorDef>{CreateOperatorDef(
+                                     "Div",
+                                     "",
+                                     std::vector<string>{O(0), I(0)},
+                                     std::vector<string>{GI(0)}),
+                                 CreateOperatorDef(
+                                     "Mul",
+                                     "",
+                                     std::vector<string>{GI(0), GO(0)},
+                                     std::vector<string>{GI(0)}),
+                                 CreateOperatorDef(
+                                     "Scale",
+                                     "",
+                                     std::vector<string>{GI(0)},
+                                     std::vector<string>{GI(0)},
+                                     std::vector<Argument>{scale_arg})};
+      */
+    } else { // second input is a tensor
+      CAFFE_ENFORCE(
+          Def().input(0) != Def().output(0) &&
+              Def().input(1) != Def().output(0),
+          "Gradient computation cannot be carried out if Pow uses in-place "
+          "computation: ",
+          ProtoDebugString(Def()));
+      vector<OperatorDef> grad_ops;
+      Argument one_arg;
+      one_arg.set_name("value");
+      one_arg.set_f(1);
+      Argument broadcast, axis, axis_str, order;
+      bool bflag = ArgumentHelper::HasArgument(Def(), "broadcast");
+
+      if (bflag) {
+        if (ArgumentHelper::HasArgument(Def(), "broadcast")) {
+          broadcast = GetArgument(Def(), "broadcast");
+        } else {
+          broadcast = MakeArgument<int>("broadcast", 0);
+        }
+        if (ArgumentHelper::HasArgument(Def(), "axis")) {
+          axis = GetArgument(Def(), "axis");
+        } else {
+          axis = MakeArgument<int>("axis", -1);
+        }
+        if (ArgumentHelper::HasArgument(Def(), "axis_str")) {
+          axis_str = GetArgument(Def(), "axis_str");
+        } else {
+          axis_str = MakeArgument<string>("axis_str", "");
+        }
+        if (ArgumentHelper::HasArgument(Def(), "order")) {
+          order = GetArgument(Def(), "order");
+        } else {
+          order = MakeArgument<string>("order", "NCHW");
+        }
+      }
+
+      // function f(w,a) = w^a
+      // gradient operator with respect to first input tensor
+      // df/dw = a * w^(a-1) (all operations are component-wise)
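+      // The op chain below computes GI(0) = I(1) * I(0)^(I(1) - 1) * GO(0)
+      // (broadcasting I(1) over I(0) when broadcast is enabled):
+      //   ConstantFill: GI(1) = 1 (same shape as I(1))
+      //   Sub:          GI(1) = I(1) - GI(1) = I(1) - 1
+      //   Pow:          GI(0) = I(0)^GI(1)
+      //   Mul:          GI(0) = GI(0) * GO(0)
+      //   Mul:          GI(0) = GI(0) * I(1)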
+      grad_ops.push_back(CreateOperatorDef(
+          "ConstantFill",
+          "",
+          std::vector<string>{I(1)},
+          std::vector<string>{GI(1)},
+          std::vector<Argument>{one_arg}));
+      grad_ops.push_back(CreateOperatorDef(
+          "Sub",
+          "",
+          std::vector<string>{I(1), GI(1)},
+          std::vector<string>{GI(1)}));
+      if (bflag) {
+        grad_ops.push_back(CreateOperatorDef(
+            "Pow",
+            "",
+            std::vector<string>{I(0), GI(1)},
+            std::vector<string>{GI(0)},
+            vector<Argument>{broadcast, axis, axis_str, order}));
+      } else {
+        grad_ops.push_back(CreateOperatorDef(
+            "Pow",
+            "",
+            std::vector<string>{I(0), GI(1)},
+            std::vector<string>{GI(0)}));
+      }
+
+      grad_ops.push_back(CreateOperatorDef(
+          "Mul",
+          "",
+          std::vector<string>{GI(0), GO(0)},
+          std::vector<string>{GI(0)}));
+      if (bflag) {
+        grad_ops.push_back(CreateOperatorDef(
+            "Mul",
+            "",
+            std::vector<string>{GI(0), I(1)},
+            std::vector<string>{GI(0)},
+            vector<Argument>{broadcast, axis, axis_str, order}));
+      } else {
+        grad_ops.push_back(CreateOperatorDef(
+            "Mul",
+            "",
+            std::vector<string>{GI(0), I(1)},
+            std::vector<string>{GI(0)}));
+      }
+      /*
+      // Alternative gradient computation (no broadcast support)
+      grad_ops.push_back(CreateOperatorDef(
+                           "Div",
+                           "",
+                           std::vector<string>{O(0), I(0)},
+                           std::vector<string>{GI(0)}));
+      grad_ops.push_back(CreateOperatorDef(
+                           "Mul",
+                           "",
+                           std::vector<string>{GI(0), GO(0)},
+                           std::vector<string>{GI(0)}));
+      grad_ops.push_back(CreateOperatorDef(
+                           "Mul",
+                           "",
+                           std::vector<string>{GI(0), I(1)},
+                           std::vector<string>{GI(0)}));
+      */
+      // gradient operator with respect to second input tensor
+      // df/da =  w^a * ln w (all operations are component-wise)
+      /*
+      // reset GI(1) to zero
+      Argument zero_arg;
+      zero_arg.set_name("value");
+      zero_arg.set_f(0);
+      grad_ops.push_back(CreateOperatorDef(
+          "ConstantFill",
+          "",
+          std::vector<string>{I(1)},
+          std::vector<string>{GI(1)},
+          std::vector<Argument>{zero_arg}));
+      */
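+      // The ops below compute GO(0) * O(0) * log(I(0)) in the
+      // *_autogen_pre_red blob and, when broadcasting, reduce it back to the
+      // shape of I(1) with SumReduceLike to obtain GI(1).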
+      grad_ops.push_back(CreateOperatorDef(
+          "Log",
+          "",
+          std::vector<string>{I(0)},
+          std::vector<string>{GI(1) + "_autogen_pre_red"}));
+      grad_ops.push_back(CreateOperatorDef(
+          "Mul",
+          "",
+          std::vector<string>{GI(1) + "_autogen_pre_red", O(0)},
+          std::vector<string>{GI(1) + "_autogen_pre_red"}));
+      if (bflag) {
+        grad_ops.push_back(CreateOperatorDef(
+            "Mul",
+            "",
+            std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
+            std::vector<string>{GI(1) + "_autogen_pre_red"}));
+        grad_ops.push_back(CreateOperatorDef(
+            "SumReduceLike",
+            "",
+            vector<string>{GI(1) + "_autogen_pre_red", I(1)},
+            vector<string>{GI(1)},
+            vector<Argument>{axis, axis_str, order}));
+      } else {
+        grad_ops.push_back(CreateOperatorDef(
+            "Mul",
+            "",
+            std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
+            std::vector<string>{GI(1)}));
+      }
+
+      return grad_ops;
+    }
+  }
+
+  // The gradient ops build their own arguments above, so do not copy the
+  // forward op's arguments (exponent, broadcast, ...) onto them.
+  bool CopyArguments() const override {
+    return false;
+  }
+};
+
+REGISTER_GRADIENT(Pow, GetPowGradient);
+
+} // namespace caffe2
diff --git a/caffe2/operators/pow_op.h b/caffe2/operators/pow_op.h
new file mode 100644
index 0000000..579210a
--- /dev/null
+++ b/caffe2/operators/pow_op.h
@@ -0,0 +1,149 @@
+/**
+ * Copyright (c) 2018-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_POW_OP_H_
+#define CAFFE2_OPERATORS_POW_OP_H_
+
+#include "caffe2/core/common_omp.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/utils/math.h"
+// NumericTypes and SameTypeAsInput are defined in the header included below.
+#include "caffe2/operators/elementwise_op.h"
+
+namespace caffe2 {
+
+template <
+    typename InputTypes,
+    class Context,
+    class Functor,
+    class TypeMap = SameTypeAsInput>
+class PowOp : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+  PowOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
+        OP_SINGLE_ARG(int, "axis", axis_, -1),
+        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
+        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
+        functor_() {
+    if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
+      exponent_ = this->template GetSingleArgument<float>(
+          "exponent", 0); // based on pow_ops.h
+    } else if (InputSize() == 2) { // BinaryElementwiseOp
+      // Figure out the correct axis to use.
+      if (enable_broadcast_) {
+        if (axis_ != -1) {
+          // Get axis from an explicit axis argument.
+          CAFFE_ENFORCE_EQ(
+              axis_str_.size(),
+              0,
+              "Args axis and axis_str cannot be used simultaneously.");
+        } else if (axis_str_.size()) {
+          // Get the axis index semantically.
+          CAFFE_ENFORCE_EQ(
+              axis_str_.size(), 1, "Unsupported axis string", axis_str_);
+          size_t semantic_axis_ = order_.find(axis_str_);
+          CAFFE_ENFORCE_NE(
+              semantic_axis_,
+              string::npos,
+              "Unrecognizable axis string ",
+              axis_str_,
+              " from order string ",
+              order_);
+          axis_ = semantic_axis_;
+        }
+      } else {
+        CAFFE_ENFORCE(
+            axis_ == -1 && axis_str_.size() == 0,
+            "Do not specify axis or axis_str if broadcast is not enabled.");
+      }
+    } else {
+      CAFFE_THROW(
+          "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
+    }
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<InputTypes>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
+      const auto& A = Input(0);
+      auto* C = Output(0);
+      C->ResizeLike(A);
+      const T* Adata = A.template data<T>();
+      auto* Cdata =
+          C->template mutable_data<typename TypeMap::template type<T>>();
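+      // The scalar exponent is passed as a one-element array; the functor's
+      // b_is_scalar template argument being true makes it read only b[0].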
+      functor_.template Run<true, T, float, T>(
+          A.size(), Adata, &exponent_, Cdata, &context_);
+    } else if (InputSize() == 2) { // BinaryElementwiseOp
+      const auto& A = Input(0);
+      const auto& B = Input(1);
+      auto* C = Output(0);
+      CAFFE_ENFORCE(
+          &B != C || !enable_broadcast_,
+          "In-place is allowed only with the first tensor when broadcasting");
+      C->ResizeLike(A);
+      const T* Adata = A.template data<T>();
+      const T* Bdata = B.template data<T>();
+      auto* Cdata =
+          C->template mutable_data<typename TypeMap::template type<T>>();
+      if (!enable_broadcast_) {
+        CAFFE_ENFORCE_EQ(
+            A.dims(),
+            B.dims(),
+            "Dimension mismatch - did you forget to set broadcast=1?");
+        functor_.template Run<false, T, T, T>(
+            A.size(), Adata, Bdata, Cdata, &context_);
+      } else if (B.size() == 1) {
+        functor_.template Run<true, T, T, T>(
+            A.size(), Adata, Bdata, Cdata, &context_);
+      } else {
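+        // calculate_broadcast_sizes splits A's shape into (pre, n, post) so
+        // that B's n elements broadcast over the middle block; post == 1
+        // means B lines up with A's trailing dimensions.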
+        size_t pre, n, post;
+        std::tie(pre, n, post) = calculate_broadcast_sizes(A, B, axis_);
+        if (post == 1) {
+          functor_.template RunWithBroadcast<T, T, T>(
+              Adata, Bdata, Cdata, pre, n, &context_);
+        } else {
+          functor_.template RunWithBroadcast2<T, T, T>(
+              Adata, Bdata, Cdata, pre, n, post, &context_);
+        }
+      }
+    } else {
+      CAFFE_THROW(
+          "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
+    }
+    return true;
+  }
+
+ private:
+  bool enable_broadcast_;
+  int axis_;
+  string axis_str_;
+  string order_;
+  float exponent_;
+  Functor functor_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_POW_OP_H_
diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py
index d1cfa80..7492455 100644
--- a/caffe2/python/hypothesis_test.py
+++ b/caffe2/python/hypothesis_test.py
@@ -1107,6 +1107,27 @@
             reference=log_ref)
         self.assertGradientChecks(gc, op, [input_tensor], 0, [0])
 
+    @given(input_tensors=hu.tensors(
+               n=2, elements=st.floats(min_value=2.0, max_value=3.0,
+                                       allow_nan=False, allow_infinity=False)),
+           **hu.gcs_cpu_only)
+    def test_powt(self, input_tensors, gc, dc):
+        X1, X2 = input_tensors
+
+        op = core.CreateOperator(
+            "Pow",
+            ["X1", "X2"],
+            ["output"]
+        )
+
+        def powt_ref(X1, X2):
+            return (np.power(X1, X2),)
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=[X1, X2],
+            reference=powt_ref)
+        self.assertGradientChecks(gc, op, [X1, X2], 0, [0])
+
     def test_blobs_dequeue_timeout(self):
         op = core.CreateOperator(
             "CreateBlobsQueue",
diff --git a/caffe2/python/operator_test/elementwise_op_broadcast_test.py b/caffe2/python/operator_test/elementwise_op_broadcast_test.py
index e4e4411..9265453 100644
--- a/caffe2/python/operator_test/elementwise_op_broadcast_test.py
+++ b/caffe2/python/operator_test/elementwise_op_broadcast_test.py
@@ -184,6 +184,58 @@
         self.assertGradientChecks(gc, op, [X, Y], 1, [0])
 
     @given(**hu.gcs)
+    def test_broadcast_powt(self, gc, dc):
+        # Set broadcast and no axis, i.e. broadcasting last dimensions.
+        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+        Y = np.random.rand(4, 5).astype(np.float32) + 2.0
+
+        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1)
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("Y", Y)
+        workspace.RunOperatorOnce(op)
+        out = workspace.FetchBlob("out")
+        np.testing.assert_array_almost_equal(out, np.power(X, Y))
+        self.assertDeviceChecks(dc, op, [X, Y], [0])
+        self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+        # broadcasting intermediate dimensions
+        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+        Y = np.random.rand(3, 4).astype(np.float32) + 2.0
+        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=1)
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("Y", Y)
+        workspace.RunOperatorOnce(op)
+        out = workspace.FetchBlob("out")
+        np.testing.assert_array_almost_equal(out, np.power(X, Y[:, :, np.newaxis]))
+        self.assertDeviceChecks(dc, op, [X, Y], [0])
+        self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+        # broadcasting the first dimension
+        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+        Y = np.random.rand(2).astype(np.float32) + 2.0
+        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=0)
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("Y", Y)
+        workspace.RunOperatorOnce(op)
+        out = workspace.FetchBlob("out")
+        np.testing.assert_array_almost_equal(
+            out, np.power(X, Y[:, np.newaxis, np.newaxis, np.newaxis]))
+        self.assertDeviceChecks(dc, op, [X, Y], [0])
+        self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+        # broadcasting with single elem dimensions at both ends
+        X = np.random.rand(2, 3, 4, 5).astype(np.float32)
+        Y = np.random.rand(1, 4, 1).astype(np.float32) + 2.0
+        op = core.CreateOperator("Pow", ["X", "Y"], "out", broadcast=1, axis=1)
+        workspace.FeedBlob("X", X)
+        workspace.FeedBlob("Y", Y)
+        workspace.RunOperatorOnce(op)
+        out = workspace.FetchBlob("out")
+        np.testing.assert_array_almost_equal(out, np.power(X, Y))
+        self.assertDeviceChecks(dc, op, [X, Y], [0])
+        self.assertGradientChecks(gc, op, [X, Y], 1, [0])
+
+    @given(**hu.gcs)
     def test_broadcast_scalar(self, gc, dc):
         # broadcasting constant
         X = np.random.rand(2, 3, 4, 5).astype(np.float32)
diff --git a/caffe2/python/operator_test/elementwise_ops_test.py b/caffe2/python/operator_test/elementwise_ops_test.py
index 3e02594..10fb7cc 100644
--- a/caffe2/python/operator_test/elementwise_ops_test.py
+++ b/caffe2/python/operator_test/elementwise_ops_test.py
@@ -75,6 +75,31 @@
         self.assertGradientChecks(
             gc, op, [X], 0, [0], stepsize=1e-4, threshold=1e-2)
 
+    @given(n=st.integers(2, 10), m=st.integers(4, 6),
+           d=st.integers(2, 3), **hu.gcs)
+    def test_powt(self, n, m, d, gc, dc):
+        X = np.random.rand(n, m, d).astype(np.float32)
+        Y = np.random.rand(n, m, d).astype(np.float32) + 2.0
+
+        def powt_op(X, Y):
+            return [np.power(X, Y)]
+
+        op = core.CreateOperator(
+            "Pow",
+            ["X", "Y"],
+            ["Z"]
+        )
+
+        self.assertReferenceChecks(
+            device_option=gc,
+            op=op,
+            inputs=[X, Y],
+            reference=powt_op,
+        )
+
+        self.assertGradientChecks(
+            gc, op, [X, Y], 0, [0], stepsize=1e-4, threshold=1e-2)
+
     @given(n=st.integers(5, 6), m=st.integers(4, 6), **hu.gcs)
     def test_sqr(self, n, m, gc, dc):
         X = np.random.rand(n, m).astype(np.float32)