Set the correct engine name for position weighted pooling when fp16 is used for training

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12225

Reviewed By: hyuen, xianjiec

Differential Revision: D10123465

fbshipit-source-id: e8d929d4153d1ee987ae3d1c37892525d7574d16
diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py
index 4c3661b..1dde928 100644
--- a/caffe2/python/layers/sparse_lookup.py
+++ b/caffe2/python/layers/sparse_lookup.py
@@ -166,6 +166,8 @@
                 "Train version {} is not currently supported".format(trainer_version)
             )
 
+        self.trainer_version = trainer_version
+
         return default_weight_init
 
     def _gather_wrapper(self, net, version, in_indices, out):
@@ -208,11 +210,22 @@
         if version in ['fp32', 'fp16']:
             # SparseLengths* Ops will accept either fp16 or fp32 embedding
             # matrix and output fp32 pooled embedding
-            net.__getattr__(layer_name)(
-                op_input,
-                self.output_schema.field_blobs(),
-                grad_on_weights=grad_on_weights,
-            )
+            # A special case here is that we need FP16 engine for
+            # SparseLengthsWeightedSum when FP16 embeedings are used for
+            # correct backward updates
+            if reducer == "WeightedSum" and version == "fp16":
+                net.SparseLengthsWeightedSum(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                    engine='FP16',
+                )
+            else:
+                net.__getattr__(layer_name)(
+                    op_input,
+                    self.output_schema.field_blobs(),
+                    grad_on_weights=grad_on_weights,
+                )
         elif version == 'uint8rowwise':
             op_input.insert(len(op_input), self.scale_bias)
             net.__getattr__(layer_name + '8BitsRowwise')(
@@ -338,7 +351,18 @@
             raise "Only Sum, Mean, None are supported for IdScoreList input." +\
                 "Trying to create with {}".format(self.reducer)
 
-    def add_ops(self, net):
+    def add_ops(self, net, version='fp32'):
+        if _is_id_list(self.input_record):
+            self._add_ops_id_list(net, version=version)
+        elif _is_id_score_list(self.input_record):
+            self._add_ops_id_score_list(net, version=version)
+        else:
+            raise "Unsupported input type {0}".format(self.input_record)
+
+    def add_train_ops(self, net):
+        self.add_ops(net, self.trainer_version)
+
+    def add_predict_ops(self, net):
         cur_scope = get_current_scope()
         version = get_sparse_lookup_predictor_version(
             **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
@@ -350,9 +374,7 @@
                                                    'fused_uint8rowwise'}:
             version = 'fp32'
 
-        if _is_id_list(self.input_record):
-            self._add_ops_id_list(net, version=version)
-        elif _is_id_score_list(self.input_record):
-            self._add_ops_id_score_list(net, version=version)
-        else:
-            raise "Unsupported input type {0}".format(self.input_record)
+        self.add_ops(net, version)
+
+    def add_eval_ops(self, net):
+        self.add_predict_ops(net)