RowWiseSparseAdam operator

Summary: Added RowWise functionality for SparseAdam, which saves roughly 2/3 of the memory usage by keeping only one first-moment and one second-moment term for each row of the parameter tensor, rather than one for each individual parameter.
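
To make the "roughly 2/3" figure concrete, here is a minimal back-of-the-envelope
sketch (not part of the patch; the embedding shape and float32 byte counts are
illustrative assumptions):

    # Hypothetical float32 embedding table: 1M rows of width 64.
    rows, cols = 1_000_000, 64
    param_bytes = rows * cols * 4

    # SparseAdam keeps full-size first and second moment tensors.
    sparse_adam_state = 2 * rows * cols * 4   # ~512 MB in this example

    # RowWiseSparseAdam keeps one scalar per row for each moment.
    row_wise_state = 2 * rows * 4             # ~8 MB in this example

    total_before = param_bytes + sparse_adam_state
    total_after = param_bytes + row_wise_state
    print(sparse_adam_state / total_before)   # ~0.67: moments are ~2/3 of the footprint
    print(total_after / total_before)         # ~0.34: roughly 2/3 of memory saved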

Differential Revision: D6679342

fbshipit-source-id: ce6fb27e35ce41a890c66f6089cd2748d10e7a44
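
For reference, a minimal usage sketch of the new flag from the Python optimizer
API (assuming the existing build_adam helper, which forwards keyword arguments
to AdamOptimizer; the model below is a hypothetical placeholder):

    from caffe2.python import model_helper, optimizer

    # Hypothetical model; in practice it would own a large embedding-style
    # parameter whose gradient arrives as a core.GradientSlice. Per the assert
    # added in optimizer.py below, rowWise=True is only supported for such
    # sparse gradients.
    model = model_helper.ModelHelper(name="rowwise_adam_example")

    optimizer.build_adam(
        model,
        base_learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        rowWise=True,  # new flag: one first/second moment pair per parameter row
    )
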
diff --git a/caffe2/python/operator_test/adam_test.py b/caffe2/python/operator_test/adam_test.py
index 0ac0a81..7c73784 100755
--- a/caffe2/python/operator_test/adam_test.py
+++ b/caffe2/python/operator_test/adam_test.py
@@ -43,6 +43,19 @@
             (np.sqrt(mom2_out) + epsilon)
         return param_out, mom1_out, mom2_out
 
+    @staticmethod
+    def ref_row_wise_adam(param, mom1, mom2, grad, LR, ITER,
+                          beta1, beta2, epsilon):
+        t = ITER + 1
+        corrected_local_rate = LR * np.sqrt(1 - np.power(beta2, t)) / \
+            (1 - np.power(beta1, t))
+        mom1_out = (beta1 * mom1) + (1 - beta1) * np.mean(grad)
+        mom2_out = (beta2 * mom2) + (1 - beta2) * np.mean(np.square(grad))
+        param_out = param + corrected_local_rate * mom1_out / \
+            (np.sqrt(mom2_out) + epsilon)
+        return (param_out, mom1_out, mom2_out)
+
+
     @given(inputs=hu.tensors(n=4),
            ITER=st.integers(min_value=0, max_value=10000),
            LR=st.floats(min_value=0.01, max_value=0.99,
@@ -142,7 +155,87 @@
             ref_sparse,
             input_device_options=input_device_options)
 
+    @given(inputs=hu.tensors(n=2),
+           ITER=st.integers(min_value=0, max_value=10000),
+           LR=st.floats(min_value=0.01, max_value=0.99,
+                        allow_nan=False, allow_infinity=False),
+           beta1=st.floats(min_value=0.01, max_value=0.99,
+                           allow_nan=False, allow_infinity=False),
+           beta2=st.floats(min_value=0.01, max_value=0.99,
+                           allow_nan=False, allow_infinity=False),
+           epsilon=st.floats(min_value=0.01, max_value=0.99,
+                             allow_nan=False, allow_infinity=False),
+           data_strategy=st.data(),
+           **hu.gcs)
+    def test_row_wise_sparse_adam(self, inputs, ITER, LR, beta1, beta2, epsilon,
+                                  data_strategy, gc, dc):
+        param, grad = inputs
+        ITER = np.array([ITER], dtype=np.int64)
+        LR = np.array([LR], dtype=np.float32)
 
-if __name__ == "__main__":
-    import unittest
-    unittest.main()
+        # Create 1D row-wise first and second moment tensors.
+        mom1 = data_strategy.draw(
+            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
+                        elements=hu.elements_of_type(dtype=np.float32))
+        )
+        mom2 = data_strategy.draw(
+            hu.tensor1d(min_len=param.shape[0], max_len=param.shape[0],
+                        elements=hu.elements_of_type(dtype=np.float32))
+        )
+        mom1 = np.absolute(mom1)
+        mom2 = np.absolute(mom2)
+
+        # Create an indexing array containing values which index into grad
+        indices = data_strategy.draw(
+            hu.tensor(
+                max_dim=1,
+                min_value=1,
+                max_value=grad.shape[0],
+                dtype=np.int64,
+                elements=st.sampled_from(np.arange(grad.shape[0])),
+            ),
+        )
+
+        # Note that unlike SparseAdam, RowWiseSparseAdam uses a moment
+        # tensor that is strictly 1-dimensional and equal in length to the
+        # first dimension of the parameters, so indices must also be
+        # 1-dimensional.
+        indices = indices.flatten()
+
+        hypothesis.note('indices.shape: %s' % str(indices.shape))
+
+        # Verify that the generated indices are unique
+        hypothesis.assume(np.array_equal(np.unique(indices), np.sort(indices)))
+
+        # Sparsify grad
+        grad = grad[indices]
+
+        op = core.CreateOperator(
+            "RowWiseSparseAdam",
+            ["param", "mom1", "mom2", "indices", "grad", "lr", "iter"],
+            ["param", "mom1", "mom2"],
+            beta1=beta1, beta2=beta2, epsilon=epsilon)
+
+        def ref_row_wise_sparse(param, mom1, mom2, indices, grad, LR, ITER):
+            param_out = np.copy(param)
+            mom1_out = np.copy(mom1)
+            mom2_out = np.copy(mom2)
+            for i, index in enumerate(indices):
+                param_out[index], mom1_out[index], mom2_out[index] = \
+                    self.ref_row_wise_adam(param[index], mom1[index], mom2[index],
+                                           grad[i], LR, ITER,
+                                           beta1, beta2, epsilon)
+            return (param_out, mom1_out, mom2_out)
+
+        # Iter lives on the CPU
+        input_device_options = {'iter': hu.cpu_do}
+
+        self.assertReferenceChecks(
+            gc, op,
+            [param, mom1, mom2, indices, grad, LR, ITER],
+            ref_row_wise_sparse,
+            input_device_options=input_device_options)
+
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
diff --git a/caffe2/python/optimizer.py b/caffe2/python/optimizer.py
index 4d4629b..e692df3 100644
--- a/caffe2/python/optimizer.py
+++ b/caffe2/python/optimizer.py
@@ -592,7 +592,7 @@
 
 class AdamOptimizer(Optimizer):
     def __init__(self, alpha=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
-                 policy='fixed', sparse_dedup_aggregator=None,
+                 policy='fixed', sparse_dedup_aggregator=None, rowWise=False,
                  engine='', **kwargs):
         super(AdamOptimizer, self).__init__()
         self.alpha = alpha
@@ -601,6 +601,7 @@
         self.epsilon = epsilon
         self.policy = policy
         self.sparse_dedup_aggregator = sparse_dedup_aggregator
+        self.rowWise = rowWise
         self.engine = engine
         self.init_kwargs = kwargs
 
@@ -618,22 +619,49 @@
             **(self.init_kwargs)
         )
 
-        m1 = param_init_net.ConstantFill(
-            [param],
-            param + "_first_moment",
-            value=0.0
-        )
-        m2 = param_init_net.ConstantFill(
-            [param],
-            param + "_second_moment",
-            value=0.0
-        )
+        if self.rowWise:
+            shapes, types = workspace.InferShapesAndTypes([param_init_net])
+            m1 = param_init_net.ConstantFill(
+                [],
+                param + "_avg_first_moment",
+                shape=[shapes[param][0]],
+                value=0.0
+            )
+            m2 = param_init_net.ConstantFill(
+                [],
+                param + "_avg_second_moment",
+                shape=[shapes[param][0]],
+                value=0.0
+            )
+
+        else:
+            m1 = param_init_net.ConstantFill(
+                [param],
+                param + "_first_moment",
+                value=0.0
+            )
+            m2 = param_init_net.ConstantFill(
+                [param],
+                param + "_second_moment",
+                value=0.0
+            )
+
         self._aux_params.shared.append(iteration)
         self._aux_params.local.append(m1)
         self._aux_params.local.append(m2)
+
+        if self.rowWise:
+            assert isinstance(grad, core.GradientSlice),\
+                'For SparseAdam with rowWise=True, the gradient must be '\
+                'a GradientSlice. Please ensure that rowWise is not enabled '\
+                'for the dense Adam optimizer, as it is not supported.'
         if isinstance(grad, core.GradientSlice):
             grad = self.dedup(net, self.sparse_dedup_aggregator, grad)
-            net.SparseAdam(
+            if self.rowWise:
+                op = 'RowWiseSparseAdam'
+            else:
+                op = 'SparseAdam'
+            net.__getattr__(op)(
                 [param, m1, m2, grad.indices, grad.values, lr, iteration],
                 [param, m1, m2],
                 beta1=self.beta1,
diff --git a/caffe2/sgd/adam_op.cc b/caffe2/sgd/adam_op.cc
index a4debfd..426a276 100644
--- a/caffe2/sgd/adam_op.cc
+++ b/caffe2/sgd/adam_op.cc
@@ -81,7 +81,40 @@
     .Arg("beta2", "Default 0.999")
     .Arg("epsilon", "Default 1e-5");
 
+REGISTER_CPU_OPERATOR(
+    RowWiseSparseAdam,
+    RowWiseSparseAdamOp<float, CPUContext>);
+OPERATOR_SCHEMA(RowWiseSparseAdam)
+    .NumInputs(7)
+    .NumOutputs(3)
+    .EnforceInplace({{0, 0}, {1, 1}, {2, 2}})
+    .SetDoc(R"DOC(
+
+Computes a modified Adam update for the sparse case.
+Given inputs (param, moment1, moment2, indices, grad, lr, iter), runs the
+Adam update on (param, moment1[indices], moment2[indices], grad, lr, iter)
+and returns (new_param, new_moment1, new_moment2), where moment1 and moment2
+are 1D tensors with length equal to the number of rows in param:
+len(moment1) == len(moment2) == shape(param)[0]. Each element of moment1 and
+moment2 is applied to an entire row of param, and the new moment1 and moment2
+values are calculated by averaging the gradient (and its square) across the row.
+
+)DOC")
+    .Input(0, "param", "Parameters to be updated")
+    .Input(1, "moment_1", "First moment history")
+    .Input(2, "moment_2", "Second moment history")
+    .Input(3, "indices", "Sparse indices")
+    .Input(4, "grad", "Gradient computed")
+    .Input(5, "lr", "Learning rate")
+    .Input(6, "iter", "Iteration number")
+    .Output(0, "output_param", "Updated parameters")
+    .Output(1, "output_moment_1", "Updated first moment")
+    .Output(2, "output_moment_2", "Updated second moment")
+    .Arg("beta1", "Default 0.9")
+    .Arg("beta2", "Default 0.999")
+    .Arg("epsilon", "Default 1e-5");
+
 SHOULD_NOT_DO_GRADIENT(Adam);
 SHOULD_NOT_DO_GRADIENT(SparseAdam);
-
+SHOULD_NOT_DO_GRADIENT(RowWiseSparseAdam);
 }
diff --git a/caffe2/sgd/adam_op.h b/caffe2/sgd/adam_op.h
index 75459eb..d7a05c3 100644
--- a/caffe2/sgd/adam_op.h
+++ b/caffe2/sgd/adam_op.h
@@ -134,7 +134,8 @@
     // Enforce shapes
     CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_1).size());
     CAFFE_ENFORCE_EQ(Input(PARAM).size(), Input(MOMENT_2).size());
-    CAFFE_ENFORCE_EQ(Input(PARAM).size_from_dim(1),
+    CAFFE_ENFORCE_EQ(
+        Input(PARAM).size_from_dim(1),
         Input(GRAD).size_from_dim(Input(INDICES).ndim()));
     CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
 
@@ -228,4 +229,121 @@
   INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
   OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
 };
-}
+
+template <typename T, class Context>
+class RowWiseSparseAdamOp final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  RowWiseSparseAdamOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        beta1_(OperatorBase::GetSingleArgument<float>("beta1", 0.9f)),
+        beta2_(OperatorBase::GetSingleArgument<float>("beta2", 0.999f)),
+        epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)) {}
+
+  bool RunOnDevice() override {
+    // Enforce shapes
+    CAFFE_ENFORCE_EQ(Input(PARAM).dims()[0], Input(MOMENT_1).size());
+    CAFFE_ENFORCE_EQ(Input(PARAM).dims()[0], Input(MOMENT_2).size());
+    CAFFE_ENFORCE_EQ(
+        Input(PARAM).size_from_dim(1),
+        Input(GRAD).size_from_dim(Input(INDICES).ndim()));
+    CAFFE_ENFORCE_EQ(Input(LR).size(), 1);
+
+    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
+        this, Input(INDICES));
+  }
+
+  template <typename SIndex>
+  bool DoRunWithType() {
+    const auto* lr = Input(LR).template data<T>();
+    const auto iter =
+        OperatorBase::Input<TensorCPU>(ITER).template data<int64_t>()[0];
+
+    const auto t = iter + 1;
+    const auto correction =
+        std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
+
+    auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
+    auto n = Input(GRAD).size() / block_size;
+
+    const auto* paramIn = Input(PARAM).template data<T>();
+    const auto* indices = Input(INDICES).template data<SIndex>();
+    const auto* gradIn = Input(GRAD).template data<T>();
+    const auto* moment1In = Input(MOMENT_1).template data<T>();
+    const auto* moment2In = Input(MOMENT_2).template data<T>();
+    auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<T>();
+    auto* moment1Out = Output(OUTPUT_MOMENT_1)->template mutable_data<T>();
+    auto* moment2Out = Output(OUTPUT_MOMENT_2)->template mutable_data<T>();
+
+    for (auto i = 0; i < n; ++i) {
+      auto idx = indices[i];
+
+      if (block_size == 1) {
+        float gi = gradIn[i];
+        float mi = moment1Out[idx] =
+            moment1In[idx] * beta1_ + gi * (1 - beta1_);
+        float vi = moment2Out[idx] =
+            moment2In[idx] * beta2_ + gi * gi * (1 - beta2_);
+        paramOut[idx] =
+            paramIn[idx] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+
+      } else {
+        auto offsetI = i * block_size;
+        auto offsetIdx = idx * block_size;
+
+#ifndef NDEBUG
+        CAFFE_ENFORCE_GE(
+            Input(PARAM).size(),
+            block_size + offsetIdx,
+            this->debug_def().input(PARAM),
+            ", out of bound,  idx:",
+            idx,
+            " for input i:",
+            i,
+            " and block size:",
+            block_size);
+        CAFFE_ENFORCE_GE(
+            Input(GRAD).size(),
+            block_size + offsetI,
+            this->debug_def().input(GRAD),
+            ", out of bound idx, idx:",
+            idx,
+            " for input i:",
+            i);
+#endif
+
+        const float* w = paramIn + offsetIdx;
+        const float* g = gradIn + offsetI;
+        const float* m1 = moment1In + idx;
+        const float* m2 = moment2In + idx;
+        float* nw = paramOut + offsetIdx;
+        float* nm1 = moment1Out + idx;
+        float* nm2 = moment2Out + idx;
+
+        float m1_sum = 0.;
+        float m2_sum = 0.;
+        for (auto j = 0; j < block_size; ++j) {
+          float gj = g[j];
+          m1_sum += gj;
+          m2_sum += gj * gj;
+        }
+        float mi = nm1[0] =
+            m1[0] * beta1_ + (m1_sum / block_size) * (1 - beta1_);
+        float vi = nm2[0] =
+            m2[0] * beta2_ + (m2_sum / block_size) * (1 - beta2_);
+        for (auto j = 0; j < block_size; ++j) {
+          nw[j] = w[j] + lr[0] * correction * mi / (std::sqrt(vi) + epsilon_);
+        }
+      }
+    }
+    return true;
+  }
+
+ protected:
+  T beta1_;
+  T beta2_;
+  T epsilon_;
+  INPUT_TAGS(PARAM, MOMENT_1, MOMENT_2, INDICES, GRAD, LR, ITER);
+  OUTPUT_TAGS(OUTPUT_PARAM, OUTPUT_MOMENT_1, OUTPUT_MOMENT_2);
+};
+} // namespace caffe2