Add float support for the scale input of the SquareRootDivide operator

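Summary: Dispatch on the type of the SCALE input at runtime (float, int32_t, or int64_t) instead of fixing it to int32_t at operator registration, to support an operation needed by D5507205. Also corrects the example in the operator documentation and extends the Python test to cover all three scale types.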

Reviewed By: xianjiec

Differential Revision: D5512522

fbshipit-source-id: a9b3a668c28eff71d1e106dbbb572184df4a7638
diff --git a/caffe2/operators/square_root_divide_op.cc b/caffe2/operators/square_root_divide_op.cc
index 375937b..de993a10 100644
--- a/caffe2/operators/square_root_divide_op.cc
+++ b/caffe2/operators/square_root_divide_op.cc
@@ -2,9 +2,7 @@
 
 namespace caffe2 {
 
-REGISTER_CPU_OPERATOR(
-    SquareRootDivide,
-    SquareRootDivideOp<int32_t, CPUContext>);
+REGISTER_CPU_OPERATOR(SquareRootDivide, SquareRootDivideOp<CPUContext>);
 OPERATOR_SCHEMA(SquareRootDivide)
     .NumInputs(2)
     .NumOutputs(1)
@@ -18,15 +16,15 @@
 Example:
 
   Data = [
-    [1.0, 2.0],
-    [3.0, 4.0]
+    [2.0, 4.0],
+    [9.0, 12.0]
   ]
 
   SCALE = [4, 9]
 
   OUTPUT = [
-    [2.0, 4.0],
-    [9.0, 12.0]
+    [1.0, 2.0],
+    [3.0, 4.0]
   ]
 
 )DOC");
diff --git a/caffe2/operators/square_root_divide_op.h b/caffe2/operators/square_root_divide_op.h
index df018bf..dce63a8 100644
--- a/caffe2/operators/square_root_divide_op.h
+++ b/caffe2/operators/square_root_divide_op.h
@@ -7,7 +7,7 @@
 
 namespace caffe2 {
 
-template <typename TScale, class Context>
+template <class Context>
 class SquareRootDivideOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -23,6 +23,12 @@
  private:
   template <typename TData>
   bool DoRunWithType() {
+    return DispatchHelper<TensorTypes2<float, int32_t, int64_t>, TData>::call(
+        this, Input(SCALE));
+  }
+
+  template <typename TData, typename TScale>
+  bool DoRunWithType2() {
     auto& data = Input(DATA);
     auto& scale = Input(SCALE);
     auto* Y = Output(0);
@@ -33,7 +39,7 @@
     auto* scalePtr = scale.template data<TScale>();
     auto* dataPtr = data.template data<TData>();
     auto* yPtr = Y->template mutable_data<TData>();
-    for (int i = 0; i < batchSize; ++i) {
+    for (auto i = 0; i < batchSize; ++i) {
       auto scale = scalePtr[i];
       CAFFE_ENFORCE(scale >= 0, scale, " < 0");
       auto multiplier = scale == 0 ? 1.0 : 1 / std::sqrt(scale);
diff --git a/caffe2/python/operator_test/square_root_divide_op_test.py b/caffe2/python/operator_test/square_root_divide_op_test.py
index 74bde30..253184a 100644
--- a/caffe2/python/operator_test/square_root_divide_op_test.py
+++ b/caffe2/python/operator_test/square_root_divide_op_test.py
@@ -17,19 +17,21 @@
         data_min_size=4, data_max_size=10,
         examples_min_number=1, examples_max_number=4,
         dtype=np.float32, elements=None):
-    dims_ = st.tuples(
+    params_ = st.tuples(
         st.integers(min_value=examples_min_number,
                     max_value=examples_max_number),
         st.integers(min_value=data_min_size,
                     max_value=data_max_size),
+        st.sampled_from([np.float32, np.int32, np.int64])
     )
-    return dims_.flatmap(
-        lambda dims: st.tuples(
-            hu.arrays([dims[0], dims[1]], dtype=dtype),
+    return params_.flatmap(
+        lambda param_: st.tuples(
+            hu.arrays([param_[0], param_[1]], dtype=dtype),
             hu.arrays(
-                [dims[0]], np.int32,
-                st.integers(min_value=5, max_value=10),
-            )
+                [param_[0]], dtype=param_[2],
+                elements=(st.floats(0.0, 10000.0) if param_[2] in [np.float32]
+                          else st.integers(0, 10000)),
+            ),
         )
     )
 
@@ -68,6 +70,7 @@
             grad_reference=grad,
         )
 
+
 if __name__ == "__main__":
     import unittest
     unittest.main()