make sure the output of sparse lookup layer is float

Summary: currently, if reduer=Nonoe, the output if fp16

Differential Revision: D5773560

fbshipit-source-id: 24d7e5fae366d70352582e9a1ee14c7613753b7a
diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py
index c9eba7d..bf262fc 100644
--- a/caffe2/python/layers/sparse_lookup.py
+++ b/caffe2/python/layers/sparse_lookup.py
@@ -22,7 +22,7 @@
 
 
 def get_sparse_lookup_predictor_version(version):
-    assert version in ('fp16', 'uint8rowwise'),\
+    assert version in {'fp32', 'fp16', 'uint8rowwise'},\
         "Unexpected version of sparse_lookup layer {0}".format(version)
     return version
 
@@ -91,8 +91,6 @@
             initializer=self.scale_bias_init,
             optimizer=model.NoOptim)
 
-
-
         self.output_schema = schema.Scalar(
             (np.float32, inner_shape),
             self.get_next_blob_reference('output'),
@@ -115,15 +113,21 @@
         return [weight]
 
     def _gather_wrapper(self, net, version, in_indices, out):
-        if version == 'fp16':
-            return net.Gather([self.w, in_indices], out, engine='fp16')
+        # Gather can work on all kinds of input data types, and output
+        # data with the same type. Convert the output of Gather to float,
+        # because the follow-up Ops expect fp32.
+        if version == 'fp32':
+            return net.Gather([self.w, in_indices], out)
+        elif version == 'fp16':
+            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
 
+            return net.HalfToFloat(gathered_w, out)
         elif version == 'uint8rowwise':
-            gathered_w = net.Gather([self.w, in_indices],
-                                    engine='fp16')
+            gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
             gathered_scale_bias = net.Gather(
                 [self.scale_bias, in_indices],
-                engine='fp16')
+                'gathered_scale_bias'
+            )
 
             return net.Rowwise8BitQuantizedToFloat(
                 [gathered_w, gathered_scale_bias], out)
@@ -134,12 +138,17 @@
     def _sparse_lengths_weighted_reducer(
             self, in_indices, weights, reducer,
             net, version, grad_on_weights=0):
-        op_input = [self.w,
-                    weights,
-                    in_indices,
-                    self.input_record.lengths()]
+        op_input = [
+            self.w,
+            weights,
+            in_indices,
+            self.input_record.lengths()
+        ]
         layer_name = 'SparseLengths' + reducer
-        if version == 'fp16':
+
+        if version in ['fp32', 'fp16']:
+            # SparseLengths* Ops with engine='fp16' will accept either
+            # fp16 or fp32 embedding matrix and output fp32 pooled embedding
             net.__getattr__(layer_name)(
                 op_input,
                 self.output_schema.field_blobs(),
@@ -154,10 +163,8 @@
             raise "Unsupported version of operator in SparseLookUp " +\
                 "layer: {0}".format(version)
 
-
-
     # deal with sparse features of id_list type
-    def _add_ops_id_list(self, net, version='fp16'):
+    def _add_ops_id_list(self, net, version):
         assert self.reducer in self._id_list_supported_reducers, (
             "Unsupported reducer: {} for ID_LIST".format(self.reducer)
         )
@@ -167,7 +174,9 @@
                         self.input_record.lengths()]
 
             layer_name = 'SparseLengths' + self.reducer
-            if version == 'fp16':
+            if version in ['fp32', 'fp16']:
+                # SparseLengths* Ops with engine='fp16' will accept either
+                # fp16 or fp32 embedding matrix and output fp32 pooled embedding
                 net.__getattr__(layer_name)(
                     op_input,
                     self.output_schema.field_blobs(),
@@ -200,7 +209,7 @@
 
         else:
             table_rows = self._gather_wrapper(
-                net, version, self.input_record.items(), 1)
+                net, version, self.input_record.items(), 'table_rows')
 
             segment_ids = net.LengthsToSegmentIds(
                 self.input_record.lengths(),
@@ -211,9 +220,8 @@
                 engine='fp16',
             )
 
-
     # deal with sparse features of id_score_list type
-    def _add_ops_id_score_list(self, net, version='fp16'):
+    def _add_ops_id_score_list(self, net, version):
         assert self.reducer in self._id_score_list_supported_reducers, (
             "Unsupported reducer: {} for ID_SCORE_LIST".format(self.reducer)
         )
@@ -229,7 +237,8 @@
                         self.input_record.lengths()]
 
             layer_name = 'SparseLengths' + self.reducer
-            if version == 'fp16':
+
+            if version in ['fp32', 'fp16']:
                 net.__getattr__(layer_name)(
                     op_input,
                     self.output_schema.field_blobs(),
@@ -242,7 +251,6 @@
                 raise "Unsupported version of operator in SparseLookUp " +\
                     "layer: {0}".format(version)
 
-
         elif self.reducer == 'PositionWeighted':
             self._sparse_lengths_weighted_reducer(
                 self.input_record.keys(),
@@ -262,7 +270,7 @@
         cur_scope = get_current_scope()
         version = get_sparse_lookup_predictor_version(
             **cur_scope.get(get_sparse_lookup_predictor_version.__name__,
-                            {'version': 'fp16'}))
+                            {'version': 'fp32'}))
 
         if schema.equal_schemas(self.input_record, IdList):
             self._add_ops_id_list(net, version=version)