Set the correct engine name for position weighted pooling when fp16 is used for training
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12225
Reviewed By: hyuen, xianjiec
Differential Revision: D10123465
fbshipit-source-id: e8d929d4153d1ee987ae3d1c37892525d7574d16
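For context, the core of the change is an engine-selection rule in the SparseLookup layer: when the reducer is WeightedSum and the embedding matrix is fp16, the pooling op has to be created with engine='FP16' so the backward pass updates the per-ID weights correctly. Below is a minimal sketch of that rule using a hypothetical pick_pooling_op helper purely for illustration (the real logic is inlined in the layer's add-ops path, as the diff shows); net is assumed to be a caffe2 core.Net.

def pick_pooling_op(net, reducer, version, op_input, outputs, grad_on_weights):
    # Position-weighted pooling (WeightedSum) with fp16 embeddings must run
    # on the FP16 engine, otherwise the gradient w.r.t. the per-ID weights
    # is not computed correctly during training.
    if reducer == "WeightedSum" and version == "fp16":
        return net.SparseLengthsWeightedSum(
            op_input,
            outputs,
            grad_on_weights=grad_on_weights,
            engine='FP16',
        )
    # Every other reducer/version combination keeps the default engine.
    return getattr(net, "SparseLengths" + reducer)(
        op_input,
        outputs,
        grad_on_weights=grad_on_weights,
    )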
diff --git a/caffe2/python/layers/sparse_lookup.py b/caffe2/python/layers/sparse_lookup.py
index 4c3661b..1dde928 100644
--- a/caffe2/python/layers/sparse_lookup.py
+++ b/caffe2/python/layers/sparse_lookup.py
@@ -166,6 +166,8 @@
"Train version {} is not currently supported".format(trainer_version)
)
+ self.trainer_version = trainer_version
+
return default_weight_init
def _gather_wrapper(self, net, version, in_indices, out):
@@ -208,11 +210,22 @@
if version in ['fp32', 'fp16']:
# SparseLengths* Ops will accept either fp16 or fp32 embedding
# matrix and output fp32 pooled embedding
- net.__getattr__(layer_name)(
- op_input,
- self.output_schema.field_blobs(),
- grad_on_weights=grad_on_weights,
- )
+ # A special case here is that we need the FP16 engine for
+ # SparseLengthsWeightedSum when FP16 embeddings are used, to get
+ # correct backward updates
+ if reducer == "WeightedSum" and version == "fp16":
+ net.SparseLengthsWeightedSum(
+ op_input,
+ self.output_schema.field_blobs(),
+ grad_on_weights=grad_on_weights,
+ engine='FP16',
+ )
+ else:
+ net.__getattr__(layer_name)(
+ op_input,
+ self.output_schema.field_blobs(),
+ grad_on_weights=grad_on_weights,
+ )
elif version == 'uint8rowwise':
op_input.insert(len(op_input), self.scale_bias)
net.__getattr__(layer_name + '8BitsRowwise')(
@@ -338,7 +351,18 @@
raise "Only Sum, Mean, None are supported for IdScoreList input." +\
"Trying to create with {}".format(self.reducer)
- def add_ops(self, net):
+ def add_ops(self, net, version='fp32'):
+ if _is_id_list(self.input_record):
+ self._add_ops_id_list(net, version=version)
+ elif _is_id_score_list(self.input_record):
+ self._add_ops_id_score_list(net, version=version)
+ else:
+ raise "Unsupported input type {0}".format(self.input_record)
+
+ def add_train_ops(self, net):
+ self.add_ops(net, self.trainer_version)
+
+ def add_predict_ops(self, net):
cur_scope = get_current_scope()
version = get_sparse_lookup_predictor_version(
**cur_scope.get(get_sparse_lookup_predictor_version.__name__,
@@ -350,9 +374,7 @@
'fused_uint8rowwise'}:
version = 'fp32'
- if _is_id_list(self.input_record):
- self._add_ops_id_list(net, version=version)
- elif _is_id_score_list(self.input_record):
- self._add_ops_id_score_list(net, version=version)
- else:
- raise "Unsupported input type {0}".format(self.input_record)
+ self.add_ops(net, version)
+
+ def add_eval_ops(self, net):
+ self.add_predict_ops(net)
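The rest of the change splits op emission by context: add_ops now takes an explicit version, add_train_ops replays it with the trainer_version recorded when the weight initializer was resolved, add_predict_ops keeps the scope-driven predictor-version lookup (falling back to fp32), and add_eval_ops reuses the predict path. The following is a hedged usage sketch of the new entry points; materialize_nets is a hypothetical helper and sparse_lookup stands for an already constructed SparseLookup layer, whose instantiation is not shown in this diff.

from caffe2.python import core

def materialize_nets(sparse_lookup):
    # Training net: add_train_ops forwards self.trainer_version, so fp16
    # training routes position-weighted pooling through the FP16 engine.
    train_net = core.Net("sparse_lookup_train")
    sparse_lookup.add_train_ops(train_net)

    # Predict net: add_predict_ops still resolves the version from the
    # current scope and falls back to fp32 when needed.
    predict_net = core.Net("sparse_lookup_predict")
    sparse_lookup.add_predict_ops(predict_net)

    # Eval net: add_eval_ops simply delegates to the predict path.
    eval_net = core.Net("sparse_lookup_eval")
    sparse_lookup.add_eval_ops(eval_net)

    return train_net, predict_net, eval_net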