[mlf][efficiency] add tensor inference function to last-n collector op (#46693)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46693
Adds a TensorInferenceFunction to the LastNWindowCollector operator so that the shapes and types of its outputs can be inferred statically (e.g. via workspace.InferShapesAndTypes) instead of being unknown to shape inference.
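
For context, a minimal sketch of how the new inference function is exercised from Python; it mirrors the new unit test, and the blob names and `num_to_collect` value are illustrative:

```python
# With the TensorInferenceFunction in place, InferShapesAndTypes can report the
# collector's output shape before the net runs.
import numpy as np
from caffe2.python import core, workspace

collect_net = core.Net("collect_net")
collect_net.GivenTensorFill(
    [], "input", shape=[3, 2], values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
)
workspace.CreateBlob("output")
workspace.FeedBlob("next", np.array(0, dtype=np.int32))
collect_net.LastNWindowCollector(
    ["output", "next", "input"],
    ["output", "next"],
    num_to_collect=7,
)
shapes, types = workspace.InferShapesAndTypes([collect_net])
print(shapes["output"])  # -> [7, 2]: num_to_collect rows, trailing dims of "input"
```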
Test Plan: unit tests (test_last_n_window_ops_shape_inference and test_last_n_window_ops_shape_inference_4d_input in caffe2/python/operator_test/dataset_ops_test.py)
Reviewed By: hx89
Differential Revision: D23946770
fbshipit-source-id: f7c3d4a1b4ef3b0e5f56e5a9a30f5003ce9f40b0
diff --git a/caffe2/operators/last_n_window_collector.cc b/caffe2/operators/last_n_window_collector.cc
index 1b141b6..8b14c83 100644
--- a/caffe2/operators/last_n_window_collector.cc
+++ b/caffe2/operators/last_n_window_collector.cc
@@ -142,6 +142,34 @@
.NumInputs({3, 4, 5})
.NumOutputs(2, 3)
.EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
+ .TensorInferenceFunction([](const OperatorDef& def,
+ const vector<TensorShape>& in) {
+ auto output_size = def.output_size();
+ vector<TensorShape> out(output_size);
+ const ArgumentHelper helper(def);
+ const auto num_to_collect =
+ helper.GetSingleArgument<int>("num_to_collect", -1);
+
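+      // Output 0: num_to_collect rows, with trailing dimensions copied from
+      // the data input (in[2]).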
+ const auto data_dims = GetDimsVector(in[2]);
+ vector<int64_t> last_n_shape(data_dims.size());
+ last_n_shape[0] = num_to_collect;
+ std::copy(data_dims.begin() + 1, data_dims.end(), last_n_shape.begin() + 1);
+ out[0] = CreateTensorShape(last_n_shape, in[2].data_type());
+
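+      // Output 1 is the in-place cursor and keeps the shape of input 1.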
+ out[1] = in[1];
+
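+      // Optional output 2 is a one-element int64 tensor counting visited rows.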
+ if (output_size > 2) {
+ vector<int64_t> num_visited_shape(1);
+ num_visited_shape[0] = 1;
+ out[2] = CreateTensorShape(num_visited_shape, TensorProto::INT64);
+ }
+
+ return out;
+ })
.SetDoc(R"DOC(
Collect the last N rows from input data. The purpose is to keep track of data
across batches, so for example suppose the LastNWindowCollector is called
diff --git a/caffe2/python/operator_test/dataset_ops_test.py b/caffe2/python/operator_test/dataset_ops_test.py
index 96d93dc..a7e0157 100644
--- a/caffe2/python/operator_test/dataset_ops_test.py
+++ b/caffe2/python/operator_test/dataset_ops_test.py
@@ -1,32 +1,32 @@
-
-
-
-
-import numpy as np
-from caffe2.python import core, workspace, dataset
-from caffe2.python.dataset import Const
-from caffe2.python.schema import (
- List, Field, Struct, Scalar, Map, from_blob_list, FetchRecord, NewRecord,
- FeedRecord
-)
-from caffe2.python.test_util import TestCase
-
-import numpy.testing as npt
-
+import functools
+import operator
import string
-from hypothesis import given
import hypothesis.strategies as st
+import numpy as np
+import numpy.testing as npt
+from caffe2.python import core, dataset, workspace
+from caffe2.python.dataset import Const
+from caffe2.python.schema import (
+ FeedRecord,
+ FetchRecord,
+ Field,
+ List,
+ Map,
+ NewRecord,
+ Scalar,
+ Struct,
+ from_blob_list,
+)
+from caffe2.python.test_util import TestCase
+from hypothesis import given
def _assert_arrays_equal(actual, ref, err_msg):
- if ref.dtype.kind in ('S', 'O', 'U'):
+ if ref.dtype.kind in ("S", "O", "U"):
np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
else:
- np.testing.assert_allclose(
- actual, ref, atol=1e-4,
- rtol=1e-4, err_msg=err_msg
- )
+ np.testing.assert_allclose(actual, ref, atol=1e-4, rtol=1e-4, err_msg=err_msg)
def _assert_records_equal(actual, ref):
@@ -34,11 +34,12 @@
assert isinstance(ref, Field)
b1 = actual.field_blobs()
b2 = ref.field_blobs()
- assert (len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % (
- len(b1), len(b2)
+ assert len(b1) == len(b2), "Records have different lengths: %d vs. %d" % (
+ len(b1),
+ len(b2),
)
for name, d1, d2 in zip(ref.field_names(), b1, b2):
- _assert_arrays_equal(d1, d2, err_msg='Mismatch in field %s.' % name)
+ _assert_arrays_equal(d1, d2, err_msg="Mismatch in field %s." % name)
@st.composite
@@ -47,7 +48,7 @@
st.lists(
st.integers(min_value=1, max_value=10),
min_size=num_records,
- max_size=num_records
+ max_size=num_records,
)
)
@@ -58,7 +59,7 @@
st.integers(min_value=1, max_value=100),
min_size=sparse_maps_total_length,
max_size=sparse_maps_total_length,
- unique=True
+ unique=True,
)
)
@@ -66,7 +67,7 @@
st.lists(
st.integers(min_value=1, max_value=10),
min_size=sparse_maps_total_length,
- max_size=sparse_maps_total_length
+ max_size=sparse_maps_total_length,
)
)
@@ -77,7 +78,7 @@
st.lists(
st.integers(min_value=1, max_value=9223372036854775807),
min_size=total_sparse_values_lengths,
- max_size=total_sparse_values_lengths
+ max_size=total_sparse_values_lengths,
)
)
@@ -95,7 +96,7 @@
st.lists(
st.integers(min_value=1, max_value=10),
min_size=num_records,
- max_size=num_records
+ max_size=num_records,
)
)
@@ -106,14 +107,12 @@
st.integers(min_value=1, max_value=100),
min_size=total_length,
max_size=total_length,
- unique=True
+ unique=True,
)
)
float_values = draw(
- st.lists(st.floats(),
- min_size=total_length,
- max_size=total_length)
+ st.lists(st.floats(), min_size=total_length, max_size=total_length)
)
return [float_lengths, float_keys, float_values]
@@ -123,22 +122,20 @@
def _dataset(draw, min_elements=3, max_elements=10, **kwargs):
schema = Struct(
# Dense Features Map
- ('floats', Map(
- Scalar(np.int32), Scalar(np.float32)
- )),
+ ("floats", Map(Scalar(np.int32), Scalar(np.float32))),
# Sparse Features Map
- ('int_lists', Map(
- Scalar(np.int32),
- List(Scalar(np.int64)),
- )),
+ (
+ "int_lists",
+ Map(
+ Scalar(np.int32),
+ List(Scalar(np.int64)),
+ ),
+ ),
# Complex Type
- ('text', Scalar(str)),
+ ("text", Scalar(str)),
)
- num_records = draw(
- st.integers(min_value=min_elements,
- max_value=max_elements)
- )
+ num_records = draw(st.integers(min_value=min_elements, max_value=max_elements))
raw_dense_features_map_contents = draw(_dense_features_map(num_records))
@@ -149,13 +146,17 @@
st.lists(
st.text(alphabet=string.ascii_lowercase),
min_size=num_records,
- max_size=num_records
+ max_size=num_records,
)
)
]
# Concatenate all raw contents to a single one
- contents_raw = raw_dense_features_map_contents + raw_sparse_features_map_contents + raw_text_contents
+ contents_raw = (
+ raw_dense_features_map_contents
+ + raw_sparse_features_map_contents
+ + raw_text_contents
+ )
contents = from_blob_list(schema, contents_raw)
@@ -172,31 +173,28 @@
dataset_fields = schema.field_names()
-
for pack_to_single_shared_ptr in (True, False):
- net = core.Net('pack_unpack_net')
+ net = core.Net("pack_unpack_net")
batch = NewRecord(net, contents)
FeedRecord(batch, contents)
packed = net.PackRecords(
- batch.field_blobs(), 1,
+ batch.field_blobs(),
+ 1,
fields=dataset_fields,
- pack_to_single_shared_ptr=pack_to_single_shared_ptr
+ pack_to_single_shared_ptr=pack_to_single_shared_ptr,
)
unpacked = packed.UnPackRecords(
- [], len(dataset_fields),
- fields=dataset_fields
+ [], len(dataset_fields), fields=dataset_fields
)
workspace.RunNetOnce(net)
- for initial_tensor, unpacked_tensor in zip(
- batch.field_blobs(), unpacked
- ):
+ for initial_tensor, unpacked_tensor in zip(batch.field_blobs(), unpacked):
npt.assert_array_equal(
workspace.FetchBlob(initial_tensor),
- workspace.FetchBlob(unpacked_tensor)
+ workspace.FetchBlob(unpacked_tensor),
)
def test_dataset_ops(self):
@@ -207,35 +205,38 @@
"""
schema = Struct(
# fixed size vector, which will be stored as a matrix when batched
- ('dense', Scalar((np.float32, 3))),
+ ("dense", Scalar((np.float32, 3))),
# could represent a feature map from feature ID to float value
- ('floats', Map(
- Scalar(np.int32), Scalar(np.float32)
- )),
+ ("floats", Map(Scalar(np.int32), Scalar(np.float32))),
# could represent a multi-valued categorical feature map
- ('int_lists', Map(
- Scalar(np.int32),
- List(Scalar(np.int64)),
- )),
+ (
+ "int_lists",
+ Map(
+ Scalar(np.int32),
+ List(Scalar(np.int64)),
+ ),
+ ),
# could represent a multi-valued, weighted categorical feature map
(
- 'id_score_pairs', Map(
+ "id_score_pairs",
+ Map(
Scalar(np.int32),
Map(
Scalar(np.int64),
Scalar(np.float32),
- keys_name='ids',
- values_name='scores'
+ keys_name="ids",
+ values_name="scores",
),
- )
+ ),
),
# additional scalar information
(
- 'metadata', Struct(
- ('user_id', Scalar(np.int64)),
- ('user_embed', Scalar((np.float32, 2))),
- ('query', Scalar(str)),
- )
+ "metadata",
+ Struct(
+ ("user_id", Scalar(np.int64)),
+ ("user_embed", Scalar((np.float32, 2))),
+ ("query", Scalar(str)),
+ ),
),
)
"""
@@ -244,26 +245,24 @@
written as a tensor.
"""
expected_fields = [
- ('dense', (np.float32, 3)),
- ('floats:lengths', np.int32),
- ('floats:values:keys', np.int32),
- ('floats:values:values', np.float32),
- ('int_lists:lengths', np.int32),
- ('int_lists:values:keys', np.int32),
- ('int_lists:values:values:lengths', np.int32),
- ('int_lists:values:values:values', np.int64),
- ('id_score_pairs:lengths', np.int32),
- ('id_score_pairs:values:keys', np.int32),
- ('id_score_pairs:values:values:lengths', np.int32),
- ('id_score_pairs:values:values:values:ids', np.int64),
- ('id_score_pairs:values:values:values:scores', np.float32),
- ('metadata:user_id', np.int64),
- ('metadata:user_embed', (np.float32, 2)),
- ('metadata:query', str),
+ ("dense", (np.float32, 3)),
+ ("floats:lengths", np.int32),
+ ("floats:values:keys", np.int32),
+ ("floats:values:values", np.float32),
+ ("int_lists:lengths", np.int32),
+ ("int_lists:values:keys", np.int32),
+ ("int_lists:values:values:lengths", np.int32),
+ ("int_lists:values:values:values", np.int64),
+ ("id_score_pairs:lengths", np.int32),
+ ("id_score_pairs:values:keys", np.int32),
+ ("id_score_pairs:values:values:lengths", np.int32),
+ ("id_score_pairs:values:values:values:ids", np.int64),
+ ("id_score_pairs:values:values:values:scores", np.float32),
+ ("metadata:user_id", np.int64),
+ ("metadata:user_embed", (np.float32, 2)),
+ ("metadata:query", str),
]
- zipped = zip(
- expected_fields, schema.field_names(), schema.field_types()
- )
+ zipped = zip(expected_fields, schema.field_names(), schema.field_types())
for (ref_name, ref_type), name, dtype in zipped:
self.assertEquals(ref_name, name)
self.assertEquals(np.dtype(ref_type), dtype)
@@ -295,7 +294,7 @@
# metadata
[123, 234, 456], # user_id
[[0.2, 0.8], [0.5, 0.5], [0.7, 0.3]], # user_embed
- ['dog posts', 'friends who like to', 'posts about ca'], # query
+ ["dog posts", "friends who like to", "posts about ca"], # query
]
# convert the above content to ndarrays, checking against the schema
contents = from_blob_list(schema, contents_raw)
@@ -305,8 +304,8 @@
Then, a Writer is used to append these entries to the dataset.
"""
ds = dataset.Dataset(schema)
- net = core.Net('init')
- with core.NameScope('init'):
+ net = core.Net("init")
+ with core.NameScope("init"):
ds.init_empty(net)
content_blobs = NewRecord(net, contents)
@@ -337,7 +336,7 @@
[11.1], # id score pairs
[123],
[[0.2, 0.8]],
- ['dog posts'], # metadata
+ ["dog posts"], # metadata
),
(
[[2.1, 2.2, 2.3]], # dense
@@ -355,7 +354,7 @@
[21.1, 22.1, 22.2],
[234],
[[0.5, 0.5]],
- ['friends who like to'], # metadata
+ ["friends who like to"], # metadata
),
(
[[3.1, 3.2, 3.3]], # dense
@@ -373,11 +372,11 @@
[31.1, 31.2, 32.1, 32.2, 32.3], # id score list
[456],
[[0.7, 0.3]],
- ['posts about ca'], # metadata
+ ["posts about ca"], # metadata
),
# after the end of the dataset, we will keep getting empty vectors
- ([], ) * 16,
- ([], ) * 16,
+ ([],) * 16,
+ ([],) * 16,
]
entries = [from_blob_list(schema, e) for e in entries_raw]
"""
@@ -385,8 +384,8 @@
We will run `read` net multiple times and assert that we are reading the
entries the way we stated above.
"""
- read_init_net = core.Net('read_init')
- read_next_net = core.Net('read_next')
+ read_init_net = core.Net("read_init")
+ read_next_net = core.Net("read_next")
reader = ds.reader(read_init_net)
should_continue, batch = reader.read_record(read_next_net)
@@ -407,11 +406,11 @@
Where we will process the dataset a little and store it in a second
dataset. We can reuse the same Reader since it supports reset.
"""
- reset_net = core.Net('reset_net')
+ reset_net = core.Net("reset_net")
reader.reset(reset_net)
read_step, batch = reader.execution_step()
""" We will add the line number * 1000 to the feature ids. """
- process_net = core.Net('process')
+ process_net = core.Net("process")
line_no = Const(process_net, 0, dtype=np.int32)
const_one = Const(process_net, 1000, dtype=np.int32)
process_net.Add([line_no, const_one], [line_no])
@@ -419,19 +418,19 @@
process_net.Print(field, [])
process_net.Add([field, line_no], field, broadcast=1, axis=0)
""" Lets create a second dataset and append to it. """
- ds2 = dataset.Dataset(schema, name='dataset2')
+ ds2 = dataset.Dataset(schema, name="dataset2")
ds2.init_empty(reset_net)
writer = ds2.writer(reset_net)
writer.write_record(process_net, batch)
# commit is not necessary for DatasetWriter but will add it for
# generality of the example
- commit_net = core.Net('commit')
+ commit_net = core.Net("commit")
writer.commit(commit_net)
""" Time to create and run a plan which will do the processing """
- plan = core.Plan('process')
- plan.AddStep(core.execution_step('reset', reset_net))
+ plan = core.Plan("process")
+ plan.AddStep(core.execution_step("reset", reset_net))
plan.AddStep(read_step.AddNet(process_net))
- plan.AddStep(core.execution_step('commit', commit_net))
+ plan.AddStep(core.execution_step("commit", commit_net))
workspace.RunPlan(plan)
"""
Now we should have dataset2 populated.
@@ -446,18 +445,18 @@
You can create a new schema from pieces of another schema and reuse
the same data.
"""
- subschema = Struct(('top_level', schema.int_lists.values))
+ subschema = Struct(("top_level", schema.int_lists.values))
int_list_contents = contents.int_lists.values.field_names()
self.assertEquals(len(subschema.field_names()), len(int_list_contents))
"""
7. Random Access a dataset
"""
- read_init_net = core.Net('read_init')
- read_next_net = core.Net('read_next')
+ read_init_net = core.Net("read_init")
+ read_next_net = core.Net("read_next")
idx = np.array([2, 1, 0])
- indices_blob = Const(read_init_net, idx, name='indices')
+ indices_blob = Const(read_init_net, idx, name="indices")
reader = ds.random_reader(read_init_net, indices_blob)
reader.computeoffset(read_init_net)
@@ -480,11 +479,11 @@
8. Random Access a dataset with loop_over = true
"""
- read_init_net = core.Net('read_init')
- read_next_net = core.Net('read_next')
+ read_init_net = core.Net("read_init")
+ read_next_net = core.Net("read_next")
idx = np.array([2, 1, 0])
- indices_blob = Const(read_init_net, idx, name='indices')
+ indices_blob = Const(read_init_net, idx, name="indices")
reader = ds.random_reader(read_init_net, indices_blob, loop_over=True)
reader.computeoffset(read_init_net)
@@ -506,11 +505,11 @@
before shuffling the chunks.
"""
- read_init_net = core.Net('read_init')
- read_next_net = core.Net('read_next')
+ read_init_net = core.Net("read_init")
+ read_next_net = core.Net("read_next")
reader = ds.random_reader(read_init_net)
- reader.sort_and_shuffle(read_init_net, 'int_lists:lengths', 1, 2)
+ reader.sort_and_shuffle(read_init_net, "int_lists:lengths", 1, 2)
reader.computeoffset(read_init_net)
should_continue, batch = reader.read_record(read_next_net)
@@ -531,7 +530,7 @@
"""
Trim a dataset
"""
- trim_net = core.Net('trim_ds')
+ trim_net = core.Net("trim_ds")
ds.trim(trim_net, multiple_of=2)
workspace.RunNetOnce(trim_net)
trimmed = FetchRecord(ds.content())
@@ -540,67 +539,108 @@
self.assertEquals(EXPECTED_SIZES, actual_sizes)
def test_last_n_window_ops(self):
- collect_net = core.Net('collect_net')
+ collect_net = core.Net("collect_net")
collect_net.GivenTensorFill(
[],
- 'input',
+ "input",
shape=[3, 2],
values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
)
- input_array =\
- np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2)
+ input_array = np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2)
- workspace.CreateBlob('output')
- workspace.FeedBlob('next', np.array(0, dtype=np.int32))
+ workspace.CreateBlob("output")
+ workspace.FeedBlob("next", np.array(0, dtype=np.int32))
collect_net.LastNWindowCollector(
- ['output', 'next', 'input'],
- ['output', 'next'],
+ ["output", "next", "input"],
+ ["output", "next"],
num_to_collect=7,
)
- plan = core.Plan('collect_data')
- plan.AddStep(
- core.execution_step('collect_data', [collect_net],
- num_iter=1)
- )
+ plan = core.Plan("collect_data")
+ plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=1))
workspace.RunPlan(plan)
- reference_result = workspace.FetchBlob('output')
+ reference_result = workspace.FetchBlob("output")
npt.assert_array_equal(input_array, reference_result)
- plan = core.Plan('collect_data')
- plan.AddStep(
- core.execution_step('collect_data', [collect_net],
- num_iter=2)
- )
+ plan = core.Plan("collect_data")
+ plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=2))
workspace.RunPlan(plan)
- reference_result = workspace.FetchBlob('output')
- npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]],
- reference_result)
+ reference_result = workspace.FetchBlob("output")
+ npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]], reference_result)
- plan = core.Plan('collect_data')
- plan.AddStep(
- core.execution_step('collect_data', [collect_net],
- num_iter=3)
- )
+ plan = core.Plan("collect_data")
+ plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=3))
workspace.RunPlan(plan)
- reference_result = workspace.FetchBlob('output')
- npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]],
- reference_result)
+ reference_result = workspace.FetchBlob("output")
+ npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]], reference_result)
+
+ def test_last_n_window_ops_shape_inference(self):
+ collect_net = core.Net("collect_net")
+ collect_net.GivenTensorFill(
+ [],
+ "input",
+ shape=[3, 2],
+ values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+ )
+
+ workspace.CreateBlob("output")
+ workspace.FeedBlob("next", np.array(0, dtype=np.int32))
+ collect_net.LastNWindowCollector(
+ ["output", "next", "input"],
+ ["output", "next"],
+ num_to_collect=7,
+ )
+ (shapes, types) = workspace.InferShapesAndTypes([collect_net])
+ workspace.RunNetOnce(collect_net)
+
+ self.assertTrue(
+ np.array_equal(
+ shapes["output"], np.array([7, workspace.blobs["output"].shape[1]])
+ )
+ )
+
+ def test_last_n_window_ops_shape_inference_4d_input(self):
+ input_shape = [3, 2, 4, 5]
+ collect_net = core.Net("collect_net")
+ collect_net.GivenTensorFill(
+ [],
+ "input",
+ shape=input_shape,
+ values=[
+ float(val) for val in range(functools.reduce(operator.mul, input_shape))
+ ],
+ )
+
+ workspace.CreateBlob("output")
+ workspace.FeedBlob("next", np.array(0, dtype=np.int32))
+ collect_net.LastNWindowCollector(
+ ["output", "next", "input"],
+ ["output", "next"],
+ num_to_collect=7,
+ )
+ (shapes, types) = workspace.InferShapesAndTypes([collect_net])
+ workspace.RunNetOnce(collect_net)
+
+ self.assertTrue(
+ np.array_equal(
+ shapes["output"], np.array([7, *list(workspace.blobs["output"].shape[1:])])
+ )
+ )
def test_collect_tensor_ops(self):
- init_net = core.Net('init_net')
- blobs = ['blob_1', 'blob_2', 'blob_3']
+ init_net = core.Net("init_net")
+ blobs = ["blob_1", "blob_2", "blob_3"]
bvec_map = {}
- ONE = init_net.ConstantFill([], 'ONE', shape=[1, 2], value=1)
+ ONE = init_net.ConstantFill([], "ONE", shape=[1, 2], value=1)
for b in blobs:
init_net.ConstantFill([], [b], shape=[1, 2], value=0)
- bvec_map[b] = b + '_vec'
+ bvec_map[b] = b + "_vec"
init_net.CreateTensorVector([], [bvec_map[b]])
- reader_net = core.Net('reader_net')
+ reader_net = core.Net("reader_net")
for b in blobs:
reader_net.Add([b, ONE], [b])
- collect_net = core.Net('collect_net')
+ collect_net = core.Net("collect_net")
num_to_collect = 1000
max_example_to_cover = 100000
bvec = [bvec_map[b] for b in blobs]
@@ -610,25 +650,24 @@
num_to_collect=num_to_collect,
)
- print('Collect Net Proto: {}'.format(collect_net.Proto()))
+ print("Collect Net Proto: {}".format(collect_net.Proto()))
- plan = core.Plan('collect_data')
- plan.AddStep(core.execution_step('collect_init', init_net))
+ plan = core.Plan("collect_data")
+ plan.AddStep(core.execution_step("collect_init", init_net))
plan.AddStep(
core.execution_step(
- 'collect_data', [reader_net, collect_net],
- num_iter=max_example_to_cover
+ "collect_data", [reader_net, collect_net], num_iter=max_example_to_cover
)
)
workspace.RunPlan(plan)
# concat the collected tensors
- concat_net = core.Net('concat_net')
+ concat_net = core.Net("concat_net")
bconcated_map = {}
bsize_map = {}
for b in blobs:
- bconcated_map[b] = b + '_concated'
- bsize_map[b] = b + '_size'
+ bconcated_map[b] = b + "_concated"
+ bsize_map[b] = b + "_size"
concat_net.ConcatTensorVector([bvec_map[b]], [bconcated_map[b]])
concat_net.TensorVectorSize([bvec_map[b]], [bsize_map[b]])
@@ -637,19 +676,16 @@
# check data
reference_result = workspace.FetchBlob(bconcated_map[blobs[0]])
self.assertEqual(
- reference_result.shape,
- (min(num_to_collect, max_example_to_cover), 2)
+ reference_result.shape, (min(num_to_collect, max_example_to_cover), 2)
)
size = workspace.FetchBlob(bsize_map[blobs[0]])
self.assertEqual(tuple(), size.shape)
self.assertEqual(min(num_to_collect, max_example_to_cover), size.item())
hist, _ = np.histogram(
- reference_result[:, 0],
- bins=10,
- range=(1, max_example_to_cover)
+ reference_result[:, 0], bins=10, range=(1, max_example_to_cover)
)
- print('Sample histogram: {}'.format(hist))
+ print("Sample histogram: {}".format(hist))
self.assertTrue(all(hist > 0.6 * (num_to_collect / 10)))
for i in range(1, len(blobs)):
@@ -659,4 +695,5 @@
if __name__ == "__main__":
import unittest
+
unittest.main()