[c2] cuda impl for WeightScale op (#38712)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38712

as title

Test Plan: buck test;

Reviewed By: ustctf

Differential Revision: D21586705

fbshipit-source-id: 12cd34f04f074ee12b77304055f3ba6068cf38fb
diff --git a/caffe2/python/operator_test/weight_scale_test.py b/caffe2/python/operator_test/weight_scale_test.py
index 61d5171..9988ebc 100644
--- a/caffe2/python/operator_test/weight_scale_test.py
+++ b/caffe2/python/operator_test/weight_scale_test.py
@@ -31,7 +31,7 @@
            stepsize=st.integers(min_value=20, max_value=50),
            upper_bound_iter=st.integers(min_value=5, max_value=100),
            scale=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False),
-           **hu.gcs_cpu_only)
+           **hu.gcs)
     def test_weight_scale(self, inputs, ITER, stepsize, upper_bound_iter, scale, gc, dc):
         ITER = np.array([ITER], dtype=np.int64)
         op = core.CreateOperator(
@@ -41,8 +41,11 @@
             iter = iter + 1
             return [w * scale if iter % stepsize == 0 and iter < upper_bound_iter else w]
 
+        input_device_options = {'iter': hu.cpu_do}
         self.assertReferenceChecks(
             gc,
             op,
             [inputs[0], ITER],
-            functools.partial(ref_weight_scale, stepsize=stepsize, upper_bound_iter=upper_bound_iter, scale=scale))
+            functools.partial(ref_weight_scale, stepsize=stepsize, upper_bound_iter=upper_bound_iter, scale=scale),
+            input_device_options=input_device_options
+        )
diff --git a/caffe2/sgd/weight_scale_op.cc b/caffe2/sgd/weight_scale_op.cc
index f9448f4..63aaf06 100644
--- a/caffe2/sgd/weight_scale_op.cc
+++ b/caffe2/sgd/weight_scale_op.cc
@@ -14,15 +14,24 @@
  * limitations under the License.
  */
 
-#include "weight_scale_op.h"
+#include "caffe2/sgd/weight_scale_op.h"
 
 namespace caffe2 {
 
-REGISTER_CPU_OPERATOR(WeightScale, WeightScaleOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(WeightScale, WeightScaleOp<CPUContext>);
 OPERATOR_SCHEMA(WeightScale)
     .NumInputs(2)
     .NumOutputs(1)
     .AllowInplace({{0, 0}, {1, 1}})
+    .DeviceInferenceFunction([](const OperatorDef& def) {
+      auto op_device =
+          def.has_device_option() ? def.device_option() : DeviceOption();
+      vector<DeviceOption> in_dev(def.input_size(), op_device);
+      vector<DeviceOption> out_dev(def.output_size(), op_device);
+      // ITER input lives on CPU
+      in_dev[1] = DeviceOption();
+      return std::make_pair(in_dev, out_dev);
+    })
     .SetDoc(R"DOC(
 Every `stepsize` iterations, multiply the weights by a constant `scale`:
     nw = w * scale
diff --git a/caffe2/sgd/weight_scale_op.h b/caffe2/sgd/weight_scale_op.h
index f43a1b3..868c0ff 100644
--- a/caffe2/sgd/weight_scale_op.h
+++ b/caffe2/sgd/weight_scale_op.h
@@ -42,7 +42,7 @@
   caffe2::math::Scale<T, T, Context>(N, scale, w, nw, context);
 }
 
-template <typename T, class Context>
+template <class Context>
 class WeightScaleOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -54,11 +54,15 @@
         update_upper_bound_(OperatorBase::GetSingleArgument<int64_t>(
             "upper_bound_iter",
             std::numeric_limits<int64_t>::max())),
-        scale_(this->template GetSingleArgument<T>("scale", 1.0f)) {}
+        scale_(this->template GetSingleArgument<float>("scale", 1.0f)) {}
 
   bool RunOnDevice() override {
     Output(OUTPUT_WEIGHTS)->ResizeLike(Input(WEIGHTS));
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(WEIGHTS));
+  }
 
+  template <typename T>
+  bool DoRunWithType() {
     const auto iter =
         OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0] + 1;
 
@@ -77,8 +81,9 @@
  protected:
   int64_t stepsize_;
   int64_t update_upper_bound_;
-  T scale_;
+  float scale_;
   INPUT_TAGS(WEIGHTS, ITER);
   OUTPUT_TAGS(OUTPUT_WEIGHTS);
 };
+
 } // namespace caffe2
diff --git a/caffe2/sgd/weight_scale_op_gpu.cc b/caffe2/sgd/weight_scale_op_gpu.cc
new file mode 100644
index 0000000..51efb6f
--- /dev/null
+++ b/caffe2/sgd/weight_scale_op_gpu.cc
@@ -0,0 +1,44 @@
+#include "caffe2/core/common_gpu.h"
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/sgd/weight_scale_op.h"
+
+namespace caffe2 {
+REGISTER_CUDA_OPERATOR(WeightScale, WeightScaleOp<CUDAContext>);
+
+template <typename T>
+void weight_scale_update_kernel(
+    int N,
+    const T* w,
+    const T& scale,
+    int64_t iter,
+    int64_t stepsize,
+    int64_t update_upper_bound,
+    T* nw,
+    CUDAContext* context) {
+  const auto w_size = N * sizeof(float);
+  if (iter % stepsize != 0 || iter >= update_upper_bound) {
+    (void)cudaMemcpy(nw, w, w_size, cudaMemcpyDefault);
+  } else {
+    // perform the weight scaling
+    caffe2::math::Scale<T, T, CUDAContext>(N, scale, w, nw, context);
+  }
+}
+
+template <>
+template <typename T>
+bool WeightScaleOp<CUDAContext>::DoRunWithType() {
+  const auto iter =
+      OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0] + 1;
+  weight_scale_update_kernel<T>(
+      Input(WEIGHTS).size(),
+      Input(WEIGHTS).template data<T>(),
+      scale_,
+      iter,
+      stepsize_,
+      update_upper_bound_,
+      Output(OUTPUT_WEIGHTS)->template mutable_data<T>(),
+      &context_);
+  return true;
+}
+
+} // namespace caffe2