Add hashing to bucket-weighted pooling (#20673)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20673
Add option to bucket-weighted pooling to hash the bucket so that any cardinality score can be used.
Reviewed By: huginhuangfb
Differential Revision: D15003509
fbshipit-source-id: 575a149de395f18fd7759f3edb485619f8aa5363
diff --git a/caffe2/python/layers/bucket_weighted.py b/caffe2/python/layers/bucket_weighted.py
index 99ec158..d234089 100644
--- a/caffe2/python/layers/bucket_weighted.py
+++ b/caffe2/python/layers/bucket_weighted.py
@@ -21,11 +21,12 @@
class BucketWeighted(ModelLayer):
def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
- weight_optim=None, name="bucket_weighted"):
+ hash_buckets=False, weight_optim=None, name="bucket_weighted"):
super(BucketWeighted, self).__init__(model, name, input_record)
assert isinstance(input_record, schema.List), "Incorrect input type"
self.bucket_boundaries = bucket_boundaries
+ self.hash_buckets = hash_buckets
if bucket_boundaries is not None:
self.shape = len(bucket_boundaries) + 1
elif max_score > 0:
@@ -63,6 +64,10 @@
"buckets_int",
to=core.DataType.INT32
)
+ if self.hash_buckets:
+ buckets_int = net.IndexHash(
+ buckets_int, "hashed_buckets_int", seed=0, modulo=self.shape
+ )
net.Gather(
[self.bucket_w, buckets_int],
self.output_schema.bucket_weights.field_blobs())