float support for square root divide
Summary: to support an operation needed by D5507205
Reviewed By: xianjiec
Differential Revision: D5512522
fbshipit-source-id: a9b3a668c28eff71d1e106dbbb572184df4a7638
diff --git a/caffe2/operators/square_root_divide_op.cc b/caffe2/operators/square_root_divide_op.cc
index 375937b..de993a10 100644
--- a/caffe2/operators/square_root_divide_op.cc
+++ b/caffe2/operators/square_root_divide_op.cc
@@ -2,9 +2,7 @@
namespace caffe2 {
-REGISTER_CPU_OPERATOR(
- SquareRootDivide,
- SquareRootDivideOp<int32_t, CPUContext>);
+REGISTER_CPU_OPERATOR(SquareRootDivide, SquareRootDivideOp<CPUContext>);
OPERATOR_SCHEMA(SquareRootDivide)
.NumInputs(2)
.NumOutputs(1)
@@ -18,15 +16,15 @@
Example:
Data = [
- [1.0, 2.0],
- [3.0, 4.0]
+ [2.0, 4.0],
+ [9.0, 12.0]
]
SCALE = [4, 9]
OUTPUT = [
- [2.0, 4.0],
- [9.0, 12.0]
+ [1.0, 2.0],
+ [3.0, 4.0]
]
)DOC");
diff --git a/caffe2/operators/square_root_divide_op.h b/caffe2/operators/square_root_divide_op.h
index df018bf..dce63a8 100644
--- a/caffe2/operators/square_root_divide_op.h
+++ b/caffe2/operators/square_root_divide_op.h
@@ -7,7 +7,7 @@
namespace caffe2 {
-template <typename TScale, class Context>
+template <class Context>
class SquareRootDivideOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -23,6 +23,12 @@
private:
template <typename TData>
bool DoRunWithType() {
+ return DispatchHelper<TensorTypes2<float, int32_t, int64_t>, TData>::call(
+ this, Input(SCALE));
+ }
+
+ template <typename TData, typename TScale>
+ bool DoRunWithType2() {
auto& data = Input(DATA);
auto& scale = Input(SCALE);
auto* Y = Output(0);
@@ -33,7 +39,7 @@
auto* scalePtr = scale.template data<TScale>();
auto* dataPtr = data.template data<TData>();
auto* yPtr = Y->template mutable_data<TData>();
- for (int i = 0; i < batchSize; ++i) {
+ for (auto i = 0; i < batchSize; ++i) {
auto scale = scalePtr[i];
CAFFE_ENFORCE(scale >= 0, scale, " < 0");
auto multiplier = scale == 0 ? 1.0 : 1 / std::sqrt(scale);
diff --git a/caffe2/python/operator_test/square_root_divide_op_test.py b/caffe2/python/operator_test/square_root_divide_op_test.py
index 74bde30..253184a 100644
--- a/caffe2/python/operator_test/square_root_divide_op_test.py
+++ b/caffe2/python/operator_test/square_root_divide_op_test.py
@@ -17,19 +17,21 @@
data_min_size=4, data_max_size=10,
examples_min_number=1, examples_max_number=4,
dtype=np.float32, elements=None):
- dims_ = st.tuples(
+ params_ = st.tuples(
st.integers(min_value=examples_min_number,
max_value=examples_max_number),
st.integers(min_value=data_min_size,
max_value=data_max_size),
+ st.sampled_from([np.float32, np.int32, np.int64])
)
- return dims_.flatmap(
- lambda dims: st.tuples(
- hu.arrays([dims[0], dims[1]], dtype=dtype),
+ return params_.flatmap(
+ lambda param_: st.tuples(
+ hu.arrays([param_[0], param_[1]], dtype=dtype),
hu.arrays(
- [dims[0]], np.int32,
- st.integers(min_value=5, max_value=10),
- )
+ [param_[0]], dtype=param_[2],
+ elements=(st.floats(0.0, 10000.0) if param_[2] in [np.float32]
+ else st.integers(0, 10000)),
+ ),
)
)
@@ -68,6 +70,7 @@
grad_reference=grad,
)
+
if __name__ == "__main__":
import unittest
unittest.main()