make sure the output of sparse lookup layer is float
Summary: currently, if reducer=None, the output is fp16
Differential Revision: D5773560
fbshipit-source-id: 24d7e5fae366d70352582e9a1ee14c7613753b7a
diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py
index c9eba7d..bf262fc 100644
--- a/caffe2/python/layers/sparse_lookup.py
+++ b/caffe2/python/layers/sparse_lookup.py
@@ -22,7 +22,7 @@
def get_sparse_lookup_predictor_version(version):
- assert version in ('fp16', 'uint8rowwise'),\
+ assert version in {'fp32', 'fp16', 'uint8rowwise'},\
"Unexpected version of sparse_lookup layer {0}".format(version)
return version
@@ -91,8 +91,6 @@
initializer=self.scale_bias_init,
optimizer=model.NoOptim)
-
-
self.output_schema = schema.Scalar(
(np.float32, inner_shape),
self.get_next_blob_reference('output'),
@@ -115,15 +113,21 @@
return [weight]
def _gather_wrapper(self, net, version, in_indices, out):
- if version == 'fp16':
- return net.Gather([self.w, in_indices], out, engine='fp16')
+ # Gather can work on all kinds of input data types, and output
+ # data with the same type. Convert the output of Gather to float,
+ # because the follow-up Ops expect fp32.
+ if version == 'fp32':
+ return net.Gather([self.w, in_indices], out)
+ elif version == 'fp16':
+ gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
+ return net.HalfToFloat(gathered_w, out)
elif version == 'uint8rowwise':
- gathered_w = net.Gather([self.w, in_indices],
- engine='fp16')
+ gathered_w = net.Gather([self.w, in_indices], 'gathered_w')
gathered_scale_bias = net.Gather(
[self.scale_bias, in_indices],
- engine='fp16')
+ 'gathered_scale_bias'
+ )
return net.Rowwise8BitQuantizedToFloat(
[gathered_w, gathered_scale_bias], out)
@@ -134,12 +138,17 @@
def _sparse_lengths_weighted_reducer(
self, in_indices, weights, reducer,
net, version, grad_on_weights=0):
- op_input = [self.w,
- weights,
- in_indices,
- self.input_record.lengths()]
+ op_input = [
+ self.w,
+ weights,
+ in_indices,
+ self.input_record.lengths()
+ ]
layer_name = 'SparseLengths' + reducer
- if version == 'fp16':
+
+ if version in ['fp32', 'fp16']:
+ # SparseLengths* Ops with engine='fp16' will accept either
+ # fp16 or fp32 embedding matrix and output fp32 pooled embedding
net.__getattr__(layer_name)(
op_input,
self.output_schema.field_blobs(),
@@ -154,10 +163,8 @@
raise "Unsupported version of operator in SparseLookUp " +\
"layer: {0}".format(version)
-
-
# deal with sparse features of id_list type
- def _add_ops_id_list(self, net, version='fp16'):
+ def _add_ops_id_list(self, net, version):
assert self.reducer in self._id_list_supported_reducers, (
"Unsupported reducer: {} for ID_LIST".format(self.reducer)
)
@@ -167,7 +174,9 @@
self.input_record.lengths()]
layer_name = 'SparseLengths' + self.reducer
- if version == 'fp16':
+ if version in ['fp32', 'fp16']:
+ # SparseLengths* Ops with engine='fp16' will accept either
+ # fp16 or fp32 embedding matrix and output fp32 pooled embedding
net.__getattr__(layer_name)(
op_input,
self.output_schema.field_blobs(),
@@ -200,7 +209,7 @@
else:
table_rows = self._gather_wrapper(
- net, version, self.input_record.items(), 1)
+ net, version, self.input_record.items(), 'table_rows')
segment_ids = net.LengthsToSegmentIds(
self.input_record.lengths(),
@@ -211,9 +220,8 @@
engine='fp16',
)
-
# deal with sparse features of id_score_list type
- def _add_ops_id_score_list(self, net, version='fp16'):
+ def _add_ops_id_score_list(self, net, version):
assert self.reducer in self._id_score_list_supported_reducers, (
"Unsupported reducer: {} for ID_SCORE_LIST".format(self.reducer)
)
@@ -229,7 +237,8 @@
self.input_record.lengths()]
layer_name = 'SparseLengths' + self.reducer
- if version == 'fp16':
+
+ if version in ['fp32', 'fp16']:
net.__getattr__(layer_name)(
op_input,
self.output_schema.field_blobs(),
@@ -242,7 +251,6 @@
raise "Unsupported version of operator in SparseLookUp " +\
"layer: {0}".format(version)
-
elif self.reducer == 'PositionWeighted':
self._sparse_lengths_weighted_reducer(
self.input_record.keys(),
@@ -262,7 +270,7 @@
cur_scope = get_current_scope()
version = get_sparse_lookup_predictor_version(
**cur_scope.get(get_sparse_lookup_predictor_version.__name__,
- {'version': 'fp16'}))
+ {'version': 'fp32'}))
if schema.equal_schemas(self.input_record, IdList):
self._add_ops_id_list(net, version=version)