Fix number of indices and block_size in SparseAdam
Summary:
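Fix the number of indices and block_size in SparseAdam to support gradients of any dimension: block_size is now computed from PARAM (size / dim(0)), and the number of indices is derived as GRAD.size() / block_size.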
Closes https://github.com/caffe2/caffe2/pull/249
Reviewed By: asaadaldien
Differential Revision: D5125714
Pulled By: akyrola
fbshipit-source-id: 84134049cb9a77e58562272ea351222befe27fca
diff --git a/caffe2/sgd/adam_op.h b/caffe2/sgd/adam_op.h
index ccc75f4..12ae713 100644
--- a/caffe2/sgd/adam_op.h
+++ b/caffe2/sgd/adam_op.h
@@ -136,8 +136,8 @@
const auto correction =
std::sqrt(T(1.) - std::pow(beta2_, t)) / (T(1.) - std::pow(beta1_, t));
- auto n = Input(GRAD).dim(0);
- auto block_size = Input(GRAD).size() / n;
+ auto block_size = Input(PARAM).size() / Input(PARAM).dim(0);
+ auto n = Input(GRAD).size() / block_size;
const auto* paramIn = Input(PARAM).template data<T>();
const auto* indices = Input(INDICES).template data<SIndex>();