add typecast and assertion for histogram computing
as title
diff --git a/caffe2/python/modeling/compute_histogram_for_blobs.py b/caffe2/python/modeling/compute_histogram_for_blobs.py
index 32c1e78..72d6933 100644
--- a/caffe2/python/modeling/compute_histogram_for_blobs.py
+++ b/caffe2/python/modeling/compute_histogram_for_blobs.py
@@ -33,9 +33,11 @@
else:
self._field_name_suffix = '_curr_normalized_hist'
- self._num_buckets = num_buckets
- self._lower_bound = lower_bound
- self._upper_bound = upper_bound
+ self._num_buckets = int(num_buckets)
+ assert self._num_buckets > 0, (
+ "num_buckets need to be greater than 0, got {}".format(num_buckets))
+ self._lower_bound = float(lower_bound)
+ self._upper_bound = float(upper_bound)
def modify_net(self, net, init_net=None, grad_map=None, blob_to_device=None):
for blob_name in self._blobs:
@@ -44,8 +46,10 @@
raise Exception('blob {0} is not defined in net {1}'.format(
blob, net.Name()))
+ blob_float = net.Cast(blob, net.NextScopedBlob(prefix=blob +
+ '_float'), to=core.DataType.FLOAT)
curr_hist, acc_hist = net.AccumulateHistogram(
- [blob],
+ [blob_float],
[net.NextScopedBlob(prefix=blob + '_curr_hist'),
net.NextScopedBlob(prefix=blob + '_acc_hist')],
num_buckets=self._num_buckets,
diff --git a/caffe2/python/modeling/compute_histogram_for_blobs_test.py b/caffe2/python/modeling/compute_histogram_for_blobs_test.py
index 6720908..aec100d 100644
--- a/caffe2/python/modeling/compute_histogram_for_blobs_test.py
+++ b/caffe2/python/modeling/compute_histogram_for_blobs_test.py
@@ -70,7 +70,7 @@
self.assertEqual(fc1_w_curr_normalized_hist.size, num_buckets + 2)
self.assertAlmostEqual(np.linalg.norm(
fc1_w_curr_normalized_hist - cur_hist), 0.0, delta=1e-5)
- self.assertEqual(len(model.net.Proto().op), 10)
+ self.assertEqual(len(model.net.Proto().op), 12)
assert 'fc1_w' + net_modifier.field_name_suffix() in\
model.net.output_record().field_blobs(),\