[c2] cuda impl for WeightScale op (#38712)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38712
as title
Test Plan: buck test;
Reviewed By: ustctf
Differential Revision: D21586705
fbshipit-source-id: 12cd34f04f074ee12b77304055f3ba6068cf38fb
diff --git a/caffe2/python/operator_test/weight_scale_test.py b/caffe2/python/operator_test/weight_scale_test.py
index 61d5171..9988ebc 100644
--- a/caffe2/python/operator_test/weight_scale_test.py
+++ b/caffe2/python/operator_test/weight_scale_test.py
@@ -31,7 +31,7 @@
stepsize=st.integers(min_value=20, max_value=50),
upper_bound_iter=st.integers(min_value=5, max_value=100),
scale=st.floats(min_value=0.01, max_value=0.99, allow_nan=False, allow_infinity=False),
- **hu.gcs_cpu_only)
+ **hu.gcs)
def test_weight_scale(self, inputs, ITER, stepsize, upper_bound_iter, scale, gc, dc):
ITER = np.array([ITER], dtype=np.int64)
op = core.CreateOperator(
@@ -41,8 +41,11 @@
iter = iter + 1
return [w * scale if iter % stepsize == 0 and iter < upper_bound_iter else w]
+ input_device_options = {'iter': hu.cpu_do}
self.assertReferenceChecks(
gc,
op,
[inputs[0], ITER],
- functools.partial(ref_weight_scale, stepsize=stepsize, upper_bound_iter=upper_bound_iter, scale=scale))
+ functools.partial(ref_weight_scale, stepsize=stepsize, upper_bound_iter=upper_bound_iter, scale=scale),
+ input_device_options=input_device_options
+ )
diff --git a/caffe2/sgd/weight_scale_op.cc b/caffe2/sgd/weight_scale_op.cc
index f9448f4..63aaf06 100644
--- a/caffe2/sgd/weight_scale_op.cc
+++ b/caffe2/sgd/weight_scale_op.cc
@@ -14,15 +14,24 @@
* limitations under the License.
*/
-#include "weight_scale_op.h"
+#include "caffe2/sgd/weight_scale_op.h"
namespace caffe2 {
-REGISTER_CPU_OPERATOR(WeightScale, WeightScaleOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(WeightScale, WeightScaleOp<CPUContext>);
OPERATOR_SCHEMA(WeightScale)
.NumInputs(2)
.NumOutputs(1)
.AllowInplace({{0, 0}, {1, 1}})
+ .DeviceInferenceFunction([](const OperatorDef& def) {
+ auto op_device =
+ def.has_device_option() ? def.device_option() : DeviceOption();
+ vector<DeviceOption> in_dev(def.input_size(), op_device);
+ vector<DeviceOption> out_dev(def.output_size(), op_device);
+ // ITER input lives on CPU
+ in_dev[1] = DeviceOption();
+ return std::make_pair(in_dev, out_dev);
+ })
.SetDoc(R"DOC(
Every `stepsize` iterations, multiply the weights by a constant `scale`:
nw = w * scale
diff --git a/caffe2/sgd/weight_scale_op.h b/caffe2/sgd/weight_scale_op.h
index f43a1b3..868c0ff 100644
--- a/caffe2/sgd/weight_scale_op.h
+++ b/caffe2/sgd/weight_scale_op.h
@@ -42,7 +42,7 @@
caffe2::math::Scale<T, T, Context>(N, scale, w, nw, context);
}
-template <typename T, class Context>
+template <class Context>
class WeightScaleOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -54,11 +54,15 @@
update_upper_bound_(OperatorBase::GetSingleArgument<int64_t>(
"upper_bound_iter",
std::numeric_limits<int64_t>::max())),
- scale_(this->template GetSingleArgument<T>("scale", 1.0f)) {}
+ scale_(this->template GetSingleArgument<float>("scale", 1.0f)) {}
bool RunOnDevice() override {
Output(OUTPUT_WEIGHTS)->ResizeLike(Input(WEIGHTS));
+ return DispatchHelper<TensorTypes<float>>::call(this, Input(WEIGHTS));
+ }
+ template <typename T>
+ bool DoRunWithType() {
const auto iter =
OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0] + 1;
@@ -77,8 +81,9 @@
protected:
int64_t stepsize_;
int64_t update_upper_bound_;
- T scale_;
+ float scale_;
INPUT_TAGS(WEIGHTS, ITER);
OUTPUT_TAGS(OUTPUT_WEIGHTS);
};
+
} // namespace caffe2
diff --git a/caffe2/sgd/weight_scale_op_gpu.cc b/caffe2/sgd/weight_scale_op_gpu.cc
new file mode 100644
index 0000000..51efb6f
--- /dev/null
+++ b/caffe2/sgd/weight_scale_op_gpu.cc
@@ -0,0 +1,44 @@
+#include "caffe2/core/common_gpu.h"
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/sgd/weight_scale_op.h"
+
+namespace caffe2 {
+REGISTER_CUDA_OPERATOR(WeightScale, WeightScaleOp<CUDAContext>);
+
+template <typename T>
+void weight_scale_update_kernel(
+ int N,
+ const T* w,
+ const T& scale,
+ int64_t iter,
+ int64_t stepsize,
+ int64_t update_upper_bound,
+ T* nw,
+ CUDAContext* context) {
+ const auto w_size = N * sizeof(float);
+ if (iter % stepsize != 0 || iter >= update_upper_bound) {
+ (void)cudaMemcpy(nw, w, w_size, cudaMemcpyDefault);
+ } else {
+ // perform the weight scaling
+ caffe2::math::Scale<T, T, CUDAContext>(N, scale, w, nw, context);
+ }
+}
+
+template <>
+template <typename T>
+bool WeightScaleOp<CUDAContext>::DoRunWithType() {
+ const auto iter =
+ OperatorBase::Input<Tensor>(ITER, CPU).template data<int64_t>()[0] + 1;
+ weight_scale_update_kernel<T>(
+ Input(WEIGHTS).size(),
+ Input(WEIGHTS).template data<T>(),
+ scale_,
+ iter,
+ stepsize_,
+ update_upper_bound_,
+ Output(OUTPUT_WEIGHTS)->template mutable_data<T>(),
+ &context_);
+ return true;
+}
+
+} // namespace caffe2