[mlf][efficiency] add tensor inference function to last-n collector op (#46693)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46693

Add a TensorInferenceFunction to the LastNWindowCollector operator so that Caffe2 shape/type inference can statically determine the collector's output shapes. The collected window is inferred as [num_to_collect, ...] with the trailing dimensions taken from the data input, the cursor output keeps the shape of its input, and the optional num-visited output is a one-element int64 tensor.
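
As a minimal sketch of what this enables (it mirrors the new unit test added below), shape inference can now report the collector's output shape before the net is run:

```python
import numpy as np
from caffe2.python import core, workspace

# Collect the last 7 rows of a [3, 2] input tensor.
collect_net = core.Net("collect_net")
collect_net.GivenTensorFill(
    [], "input", shape=[3, 2], values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
)
workspace.CreateBlob("output")
workspace.FeedBlob("next", np.array(0, dtype=np.int32))
collect_net.LastNWindowCollector(
    ["output", "next", "input"], ["output", "next"], num_to_collect=7
)

# With the inference function registered, the shape of "output" is known to be
# [num_to_collect, 2] = [7, 2] without running the net.
shapes, types = workspace.InferShapesAndTypes([collect_net])
print(shapes["output"])  # [7, 2]
```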

Test Plan: unit tests (test_last_n_window_ops_shape_inference and test_last_n_window_ops_shape_inference_4d_input in caffe2/python/operator_test/dataset_ops_test.py)

Reviewed By: hx89

Differential Revision: D23946770

fbshipit-source-id: f7c3d4a1b4ef3b0e5f56e5a9a30f5003ce9f40b0
diff --git a/caffe2/operators/last_n_window_collector.cc b/caffe2/operators/last_n_window_collector.cc
index 1b141b6..8b14c83 100644
--- a/caffe2/operators/last_n_window_collector.cc
+++ b/caffe2/operators/last_n_window_collector.cc
@@ -142,6 +142,34 @@
     .NumInputs({3, 4, 5})
     .NumOutputs(2, 3)
     .EnforceInplace({{0, 0}, {1, 1}, {4, 2}})
+    .TensorInferenceFunction([](const OperatorDef& def,
+                                const vector<TensorShape>& in) {
+      auto output_size = def.output_size();
+      vector<TensorShape> out(output_size);
+      const ArgumentHelper helper(def);
+      const auto num_to_collect =
+          helper.GetSingleArgument<int>("num_to_collect", -1);
+
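+      // Output 0 is the collected window: its leading dimension is
+      // num_to_collect and the remaining dimensions are copied from the data input.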
+      const auto data_dims = GetDimsVector(in[2]);
+      vector<int64_t> last_n_shape(data_dims.size());
+      last_n_shape[0] = num_to_collect;
+      std::copy(data_dims.begin() + 1, data_dims.end(), last_n_shape.begin() + 1);
+      out[0] = CreateTensorShape(last_n_shape, in[2].data_type());
+
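+      // Output 1 (the cursor) keeps the shape of its input.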
+      out[1] = in[1];
+
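+      // The optional third output is the num-visited counter: a single int64.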
+      if (output_size > 2) {
+        vector<int64_t> num_visited_shape(1);
+        num_visited_shape[0] = 1;
+        out[2] = CreateTensorShape(num_visited_shape, TensorProto::INT64);
+      }
+
+      return out;
+    })
     .SetDoc(R"DOC(
 Collect the last N rows from input data. The purpose is to keep track of data
 across batches, so for example suppose the LastNWindowCollector is called
diff --git a/caffe2/python/operator_test/dataset_ops_test.py b/caffe2/python/operator_test/dataset_ops_test.py
index 96d93dc..a7e0157 100644
--- a/caffe2/python/operator_test/dataset_ops_test.py
+++ b/caffe2/python/operator_test/dataset_ops_test.py
@@ -1,32 +1,32 @@
-
-
-
-
-import numpy as np
-from caffe2.python import core, workspace, dataset
-from caffe2.python.dataset import Const
-from caffe2.python.schema import (
-    List, Field, Struct, Scalar, Map, from_blob_list, FetchRecord, NewRecord,
-    FeedRecord
-)
-from caffe2.python.test_util import TestCase
-
-import numpy.testing as npt
-
+import functools
+import operator
 import string
 
-from hypothesis import given
 import hypothesis.strategies as st
+import numpy as np
+import numpy.testing as npt
+from caffe2.python import core, dataset, workspace
+from caffe2.python.dataset import Const
+from caffe2.python.schema import (
+    FeedRecord,
+    FetchRecord,
+    Field,
+    List,
+    Map,
+    NewRecord,
+    Scalar,
+    Struct,
+    from_blob_list,
+)
+from caffe2.python.test_util import TestCase
+from hypothesis import given
 
 
 def _assert_arrays_equal(actual, ref, err_msg):
-    if ref.dtype.kind in ('S', 'O', 'U'):
+    if ref.dtype.kind in ("S", "O", "U"):
         np.testing.assert_array_equal(actual, ref, err_msg=err_msg)
     else:
-        np.testing.assert_allclose(
-            actual, ref, atol=1e-4,
-            rtol=1e-4, err_msg=err_msg
-        )
+        np.testing.assert_allclose(actual, ref, atol=1e-4, rtol=1e-4, err_msg=err_msg)
 
 
 def _assert_records_equal(actual, ref):
@@ -34,11 +34,12 @@
     assert isinstance(ref, Field)
     b1 = actual.field_blobs()
     b2 = ref.field_blobs()
-    assert (len(b1) == len(b2)), 'Records have different lengths: %d vs. %d' % (
-        len(b1), len(b2)
+    assert len(b1) == len(b2), "Records have different lengths: %d vs. %d" % (
+        len(b1),
+        len(b2),
     )
     for name, d1, d2 in zip(ref.field_names(), b1, b2):
-        _assert_arrays_equal(d1, d2, err_msg='Mismatch in field %s.' % name)
+        _assert_arrays_equal(d1, d2, err_msg="Mismatch in field %s." % name)
 
 
 @st.composite
@@ -47,7 +48,7 @@
         st.lists(
             st.integers(min_value=1, max_value=10),
             min_size=num_records,
-            max_size=num_records
+            max_size=num_records,
         )
     )
 
@@ -58,7 +59,7 @@
             st.integers(min_value=1, max_value=100),
             min_size=sparse_maps_total_length,
             max_size=sparse_maps_total_length,
-            unique=True
+            unique=True,
         )
     )
 
@@ -66,7 +67,7 @@
         st.lists(
             st.integers(min_value=1, max_value=10),
             min_size=sparse_maps_total_length,
-            max_size=sparse_maps_total_length
+            max_size=sparse_maps_total_length,
         )
     )
 
@@ -77,7 +78,7 @@
         st.lists(
             st.integers(min_value=1, max_value=9223372036854775807),
             min_size=total_sparse_values_lengths,
-            max_size=total_sparse_values_lengths
+            max_size=total_sparse_values_lengths,
         )
     )
 
@@ -95,7 +96,7 @@
         st.lists(
             st.integers(min_value=1, max_value=10),
             min_size=num_records,
-            max_size=num_records
+            max_size=num_records,
         )
     )
 
@@ -106,14 +107,12 @@
             st.integers(min_value=1, max_value=100),
             min_size=total_length,
             max_size=total_length,
-            unique=True
+            unique=True,
         )
     )
 
     float_values = draw(
-        st.lists(st.floats(),
-                 min_size=total_length,
-                 max_size=total_length)
+        st.lists(st.floats(), min_size=total_length, max_size=total_length)
     )
 
     return [float_lengths, float_keys, float_values]
@@ -123,22 +122,20 @@
 def _dataset(draw, min_elements=3, max_elements=10, **kwargs):
     schema = Struct(
         # Dense Features Map
-        ('floats', Map(
-            Scalar(np.int32), Scalar(np.float32)
-        )),
+        ("floats", Map(Scalar(np.int32), Scalar(np.float32))),
         # Sparse Features Map
-        ('int_lists', Map(
-            Scalar(np.int32),
-            List(Scalar(np.int64)),
-        )),
+        (
+            "int_lists",
+            Map(
+                Scalar(np.int32),
+                List(Scalar(np.int64)),
+            ),
+        ),
         # Complex Type
-        ('text', Scalar(str)),
+        ("text", Scalar(str)),
     )
 
-    num_records = draw(
-        st.integers(min_value=min_elements,
-                    max_value=max_elements)
-    )
+    num_records = draw(st.integers(min_value=min_elements, max_value=max_elements))
 
     raw_dense_features_map_contents = draw(_dense_features_map(num_records))
 
@@ -149,13 +146,17 @@
             st.lists(
                 st.text(alphabet=string.ascii_lowercase),
                 min_size=num_records,
-                max_size=num_records
+                max_size=num_records,
             )
         )
     ]
 
     # Concatenate all raw contents to a single one
-    contents_raw = raw_dense_features_map_contents + raw_sparse_features_map_contents + raw_text_contents
+    contents_raw = (
+        raw_dense_features_map_contents
+        + raw_sparse_features_map_contents
+        + raw_text_contents
+    )
 
     contents = from_blob_list(schema, contents_raw)
 
@@ -172,31 +173,28 @@
 
         dataset_fields = schema.field_names()
 
-
         for pack_to_single_shared_ptr in (True, False):
-            net = core.Net('pack_unpack_net')
+            net = core.Net("pack_unpack_net")
             batch = NewRecord(net, contents)
             FeedRecord(batch, contents)
 
             packed = net.PackRecords(
-                batch.field_blobs(), 1,
+                batch.field_blobs(),
+                1,
                 fields=dataset_fields,
-                pack_to_single_shared_ptr=pack_to_single_shared_ptr
+                pack_to_single_shared_ptr=pack_to_single_shared_ptr,
             )
 
             unpacked = packed.UnPackRecords(
-                [], len(dataset_fields),
-                fields=dataset_fields
+                [], len(dataset_fields), fields=dataset_fields
             )
 
             workspace.RunNetOnce(net)
 
-            for initial_tensor, unpacked_tensor in zip(
-                batch.field_blobs(), unpacked
-            ):
+            for initial_tensor, unpacked_tensor in zip(batch.field_blobs(), unpacked):
                 npt.assert_array_equal(
                     workspace.FetchBlob(initial_tensor),
-                    workspace.FetchBlob(unpacked_tensor)
+                    workspace.FetchBlob(unpacked_tensor),
                 )
 
     def test_dataset_ops(self):
@@ -207,35 +205,38 @@
         """
         schema = Struct(
             # fixed size vector, which will be stored as a matrix when batched
-            ('dense', Scalar((np.float32, 3))),
+            ("dense", Scalar((np.float32, 3))),
             # could represent a feature map from feature ID to float value
-            ('floats', Map(
-                Scalar(np.int32), Scalar(np.float32)
-            )),
+            ("floats", Map(Scalar(np.int32), Scalar(np.float32))),
             # could represent a multi-valued categorical feature map
-            ('int_lists', Map(
-                Scalar(np.int32),
-                List(Scalar(np.int64)),
-            )),
+            (
+                "int_lists",
+                Map(
+                    Scalar(np.int32),
+                    List(Scalar(np.int64)),
+                ),
+            ),
             # could represent a multi-valued, weighted categorical feature map
             (
-                'id_score_pairs', Map(
+                "id_score_pairs",
+                Map(
                     Scalar(np.int32),
                     Map(
                         Scalar(np.int64),
                         Scalar(np.float32),
-                        keys_name='ids',
-                        values_name='scores'
+                        keys_name="ids",
+                        values_name="scores",
                     ),
-                )
+                ),
             ),
             # additional scalar information
             (
-                'metadata', Struct(
-                    ('user_id', Scalar(np.int64)),
-                    ('user_embed', Scalar((np.float32, 2))),
-                    ('query', Scalar(str)),
-                )
+                "metadata",
+                Struct(
+                    ("user_id", Scalar(np.int64)),
+                    ("user_embed", Scalar((np.float32, 2))),
+                    ("query", Scalar(str)),
+                ),
             ),
         )
         """
@@ -244,26 +245,24 @@
         written as a tensor.
         """
         expected_fields = [
-            ('dense', (np.float32, 3)),
-            ('floats:lengths', np.int32),
-            ('floats:values:keys', np.int32),
-            ('floats:values:values', np.float32),
-            ('int_lists:lengths', np.int32),
-            ('int_lists:values:keys', np.int32),
-            ('int_lists:values:values:lengths', np.int32),
-            ('int_lists:values:values:values', np.int64),
-            ('id_score_pairs:lengths', np.int32),
-            ('id_score_pairs:values:keys', np.int32),
-            ('id_score_pairs:values:values:lengths', np.int32),
-            ('id_score_pairs:values:values:values:ids', np.int64),
-            ('id_score_pairs:values:values:values:scores', np.float32),
-            ('metadata:user_id', np.int64),
-            ('metadata:user_embed', (np.float32, 2)),
-            ('metadata:query', str),
+            ("dense", (np.float32, 3)),
+            ("floats:lengths", np.int32),
+            ("floats:values:keys", np.int32),
+            ("floats:values:values", np.float32),
+            ("int_lists:lengths", np.int32),
+            ("int_lists:values:keys", np.int32),
+            ("int_lists:values:values:lengths", np.int32),
+            ("int_lists:values:values:values", np.int64),
+            ("id_score_pairs:lengths", np.int32),
+            ("id_score_pairs:values:keys", np.int32),
+            ("id_score_pairs:values:values:lengths", np.int32),
+            ("id_score_pairs:values:values:values:ids", np.int64),
+            ("id_score_pairs:values:values:values:scores", np.float32),
+            ("metadata:user_id", np.int64),
+            ("metadata:user_embed", (np.float32, 2)),
+            ("metadata:query", str),
         ]
-        zipped = zip(
-            expected_fields, schema.field_names(), schema.field_types()
-        )
+        zipped = zip(expected_fields, schema.field_names(), schema.field_types())
         for (ref_name, ref_type), name, dtype in zipped:
             self.assertEquals(ref_name, name)
             self.assertEquals(np.dtype(ref_type), dtype)
@@ -295,7 +294,7 @@
             # metadata
             [123, 234, 456],  # user_id
             [[0.2, 0.8], [0.5, 0.5], [0.7, 0.3]],  # user_embed
-            ['dog posts', 'friends who like to', 'posts about ca'],  # query
+            ["dog posts", "friends who like to", "posts about ca"],  # query
         ]
         # convert the above content to ndarrays, checking against the schema
         contents = from_blob_list(schema, contents_raw)
@@ -305,8 +304,8 @@
         Then, a Writer is used to append these entries to the dataset.
         """
         ds = dataset.Dataset(schema)
-        net = core.Net('init')
-        with core.NameScope('init'):
+        net = core.Net("init")
+        with core.NameScope("init"):
             ds.init_empty(net)
 
             content_blobs = NewRecord(net, contents)
@@ -337,7 +336,7 @@
                 [11.1],  # id score pairs
                 [123],
                 [[0.2, 0.8]],
-                ['dog posts'],  # metadata
+                ["dog posts"],  # metadata
             ),
             (
                 [[2.1, 2.2, 2.3]],  # dense
@@ -355,7 +354,7 @@
                 [21.1, 22.1, 22.2],
                 [234],
                 [[0.5, 0.5]],
-                ['friends who like to'],  # metadata
+                ["friends who like to"],  # metadata
             ),
             (
                 [[3.1, 3.2, 3.3]],  # dense
@@ -373,11 +372,11 @@
                 [31.1, 31.2, 32.1, 32.2, 32.3],  # id score list
                 [456],
                 [[0.7, 0.3]],
-                ['posts about ca'],  # metadata
+                ["posts about ca"],  # metadata
             ),
             # after the end of the dataset, we will keep getting empty vectors
-            ([], ) * 16,
-            ([], ) * 16,
+            ([],) * 16,
+            ([],) * 16,
         ]
         entries = [from_blob_list(schema, e) for e in entries_raw]
         """
@@ -385,8 +384,8 @@
         We will run `read` net multiple times and assert that we are reading the
         entries the way we stated above.
         """
-        read_init_net = core.Net('read_init')
-        read_next_net = core.Net('read_next')
+        read_init_net = core.Net("read_init")
+        read_next_net = core.Net("read_next")
         reader = ds.reader(read_init_net)
         should_continue, batch = reader.read_record(read_next_net)
 
@@ -407,11 +406,11 @@
         Where we will process the dataset a little and store it in a second
         dataset. We can reuse the same Reader since it supports reset.
         """
-        reset_net = core.Net('reset_net')
+        reset_net = core.Net("reset_net")
         reader.reset(reset_net)
         read_step, batch = reader.execution_step()
         """ We will add the line number * 1000 to the feature ids. """
-        process_net = core.Net('process')
+        process_net = core.Net("process")
         line_no = Const(process_net, 0, dtype=np.int32)
         const_one = Const(process_net, 1000, dtype=np.int32)
         process_net.Add([line_no, const_one], [line_no])
@@ -419,19 +418,19 @@
         process_net.Print(field, [])
         process_net.Add([field, line_no], field, broadcast=1, axis=0)
         """ Lets create a second dataset and append to it. """
-        ds2 = dataset.Dataset(schema, name='dataset2')
+        ds2 = dataset.Dataset(schema, name="dataset2")
         ds2.init_empty(reset_net)
         writer = ds2.writer(reset_net)
         writer.write_record(process_net, batch)
         # commit is not necessary for DatasetWriter but will add it for
         # generality of the example
-        commit_net = core.Net('commit')
+        commit_net = core.Net("commit")
         writer.commit(commit_net)
         """ Time to create and run a plan which will do the processing """
-        plan = core.Plan('process')
-        plan.AddStep(core.execution_step('reset', reset_net))
+        plan = core.Plan("process")
+        plan.AddStep(core.execution_step("reset", reset_net))
         plan.AddStep(read_step.AddNet(process_net))
-        plan.AddStep(core.execution_step('commit', commit_net))
+        plan.AddStep(core.execution_step("commit", commit_net))
         workspace.RunPlan(plan)
         """
         Now we should have dataset2 populated.
@@ -446,18 +445,18 @@
         You can create a new schema from pieces of another schema and reuse
         the same data.
         """
-        subschema = Struct(('top_level', schema.int_lists.values))
+        subschema = Struct(("top_level", schema.int_lists.values))
         int_list_contents = contents.int_lists.values.field_names()
         self.assertEquals(len(subschema.field_names()), len(int_list_contents))
         """
         7. Random Access a dataset
 
         """
-        read_init_net = core.Net('read_init')
-        read_next_net = core.Net('read_next')
+        read_init_net = core.Net("read_init")
+        read_next_net = core.Net("read_next")
 
         idx = np.array([2, 1, 0])
-        indices_blob = Const(read_init_net, idx, name='indices')
+        indices_blob = Const(read_init_net, idx, name="indices")
         reader = ds.random_reader(read_init_net, indices_blob)
         reader.computeoffset(read_init_net)
 
@@ -480,11 +479,11 @@
         8. Random Access a dataset with loop_over = true
 
         """
-        read_init_net = core.Net('read_init')
-        read_next_net = core.Net('read_next')
+        read_init_net = core.Net("read_init")
+        read_next_net = core.Net("read_next")
 
         idx = np.array([2, 1, 0])
-        indices_blob = Const(read_init_net, idx, name='indices')
+        indices_blob = Const(read_init_net, idx, name="indices")
         reader = ds.random_reader(read_init_net, indices_blob, loop_over=True)
         reader.computeoffset(read_init_net)
 
@@ -506,11 +505,11 @@
         before shuffling the chunks.
 
         """
-        read_init_net = core.Net('read_init')
-        read_next_net = core.Net('read_next')
+        read_init_net = core.Net("read_init")
+        read_next_net = core.Net("read_next")
 
         reader = ds.random_reader(read_init_net)
-        reader.sort_and_shuffle(read_init_net, 'int_lists:lengths', 1, 2)
+        reader.sort_and_shuffle(read_init_net, "int_lists:lengths", 1, 2)
         reader.computeoffset(read_init_net)
 
         should_continue, batch = reader.read_record(read_next_net)
@@ -531,7 +530,7 @@
         """
         Trim a dataset
         """
-        trim_net = core.Net('trim_ds')
+        trim_net = core.Net("trim_ds")
         ds.trim(trim_net, multiple_of=2)
         workspace.RunNetOnce(trim_net)
         trimmed = FetchRecord(ds.content())
@@ -540,67 +539,108 @@
         self.assertEquals(EXPECTED_SIZES, actual_sizes)
 
     def test_last_n_window_ops(self):
-        collect_net = core.Net('collect_net')
+        collect_net = core.Net("collect_net")
         collect_net.GivenTensorFill(
             [],
-            'input',
+            "input",
             shape=[3, 2],
             values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
         )
-        input_array =\
-            np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2)
+        input_array = np.array(list(range(1, 7)), dtype=np.float32).reshape(3, 2)
 
-        workspace.CreateBlob('output')
-        workspace.FeedBlob('next', np.array(0, dtype=np.int32))
+        workspace.CreateBlob("output")
+        workspace.FeedBlob("next", np.array(0, dtype=np.int32))
         collect_net.LastNWindowCollector(
-            ['output', 'next', 'input'],
-            ['output', 'next'],
+            ["output", "next", "input"],
+            ["output", "next"],
             num_to_collect=7,
         )
-        plan = core.Plan('collect_data')
-        plan.AddStep(
-            core.execution_step('collect_data', [collect_net],
-                                num_iter=1)
-        )
+        plan = core.Plan("collect_data")
+        plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=1))
         workspace.RunPlan(plan)
-        reference_result = workspace.FetchBlob('output')
+        reference_result = workspace.FetchBlob("output")
         npt.assert_array_equal(input_array, reference_result)
 
-        plan = core.Plan('collect_data')
-        plan.AddStep(
-            core.execution_step('collect_data', [collect_net],
-                                num_iter=2)
-        )
+        plan = core.Plan("collect_data")
+        plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=2))
         workspace.RunPlan(plan)
-        reference_result = workspace.FetchBlob('output')
-        npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]],
-                               reference_result)
+        reference_result = workspace.FetchBlob("output")
+        npt.assert_array_equal(input_array[[1, 2, 2, 0, 1, 2, 0]], reference_result)
 
-        plan = core.Plan('collect_data')
-        plan.AddStep(
-            core.execution_step('collect_data', [collect_net],
-                                num_iter=3)
-        )
+        plan = core.Plan("collect_data")
+        plan.AddStep(core.execution_step("collect_data", [collect_net], num_iter=3))
         workspace.RunPlan(plan)
-        reference_result = workspace.FetchBlob('output')
-        npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]],
-                               reference_result)
+        reference_result = workspace.FetchBlob("output")
+        npt.assert_array_equal(input_array[[2, 0, 1, 2, 2, 0, 1]], reference_result)
+
+    def test_last_n_window_ops_shape_inference(self):
+        collect_net = core.Net("collect_net")
+        collect_net.GivenTensorFill(
+            [],
+            "input",
+            shape=[3, 2],
+            values=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
+        )
+
+        workspace.CreateBlob("output")
+        workspace.FeedBlob("next", np.array(0, dtype=np.int32))
+        collect_net.LastNWindowCollector(
+            ["output", "next", "input"],
+            ["output", "next"],
+            num_to_collect=7,
+        )
+        (shapes, types) = workspace.InferShapesAndTypes([collect_net])
+        workspace.RunNetOnce(collect_net)
+
+        self.assertTrue(
+            np.array_equal(
+                shapes["output"], np.array([7, workspace.blobs["output"].shape[1]])
+            )
+        )
+
+    def test_last_n_window_ops_shape_inference_4d_input(self):
+        input_shape = [3, 2, 4, 5]
+        collect_net = core.Net("collect_net")
+        collect_net.GivenTensorFill(
+            [],
+            "input",
+            shape=input_shape,
+            values=[
+                float(val) for val in range(functools.reduce(operator.mul, input_shape))
+            ],
+        )
+
+        workspace.CreateBlob("output")
+        workspace.FeedBlob("next", np.array(0, dtype=np.int32))
+        collect_net.LastNWindowCollector(
+            ["output", "next", "input"],
+            ["output", "next"],
+            num_to_collect=7,
+        )
+        (shapes, types) = workspace.InferShapesAndTypes([collect_net])
+        workspace.RunNetOnce(collect_net)
+
+        self.assertTrue(
+            np.array_equal(
+                shapes["output"], np.array([7, *list(workspace.blobs["output"].shape[1:])])
+            )
+        )
 
     def test_collect_tensor_ops(self):
-        init_net = core.Net('init_net')
-        blobs = ['blob_1', 'blob_2', 'blob_3']
+        init_net = core.Net("init_net")
+        blobs = ["blob_1", "blob_2", "blob_3"]
         bvec_map = {}
-        ONE = init_net.ConstantFill([], 'ONE', shape=[1, 2], value=1)
+        ONE = init_net.ConstantFill([], "ONE", shape=[1, 2], value=1)
         for b in blobs:
             init_net.ConstantFill([], [b], shape=[1, 2], value=0)
-            bvec_map[b] = b + '_vec'
+            bvec_map[b] = b + "_vec"
             init_net.CreateTensorVector([], [bvec_map[b]])
 
-        reader_net = core.Net('reader_net')
+        reader_net = core.Net("reader_net")
         for b in blobs:
             reader_net.Add([b, ONE], [b])
 
-        collect_net = core.Net('collect_net')
+        collect_net = core.Net("collect_net")
         num_to_collect = 1000
         max_example_to_cover = 100000
         bvec = [bvec_map[b] for b in blobs]
@@ -610,25 +650,24 @@
             num_to_collect=num_to_collect,
         )
 
-        print('Collect Net Proto: {}'.format(collect_net.Proto()))
+        print("Collect Net Proto: {}".format(collect_net.Proto()))
 
-        plan = core.Plan('collect_data')
-        plan.AddStep(core.execution_step('collect_init', init_net))
+        plan = core.Plan("collect_data")
+        plan.AddStep(core.execution_step("collect_init", init_net))
         plan.AddStep(
             core.execution_step(
-                'collect_data', [reader_net, collect_net],
-                num_iter=max_example_to_cover
+                "collect_data", [reader_net, collect_net], num_iter=max_example_to_cover
             )
         )
         workspace.RunPlan(plan)
 
         # concat the collected tensors
-        concat_net = core.Net('concat_net')
+        concat_net = core.Net("concat_net")
         bconcated_map = {}
         bsize_map = {}
         for b in blobs:
-            bconcated_map[b] = b + '_concated'
-            bsize_map[b] = b + '_size'
+            bconcated_map[b] = b + "_concated"
+            bsize_map[b] = b + "_size"
             concat_net.ConcatTensorVector([bvec_map[b]], [bconcated_map[b]])
             concat_net.TensorVectorSize([bvec_map[b]], [bsize_map[b]])
 
@@ -637,19 +676,16 @@
         # check data
         reference_result = workspace.FetchBlob(bconcated_map[blobs[0]])
         self.assertEqual(
-            reference_result.shape,
-            (min(num_to_collect, max_example_to_cover), 2)
+            reference_result.shape, (min(num_to_collect, max_example_to_cover), 2)
         )
         size = workspace.FetchBlob(bsize_map[blobs[0]])
         self.assertEqual(tuple(), size.shape)
         self.assertEqual(min(num_to_collect, max_example_to_cover), size.item())
 
         hist, _ = np.histogram(
-            reference_result[:, 0],
-            bins=10,
-            range=(1, max_example_to_cover)
+            reference_result[:, 0], bins=10, range=(1, max_example_to_cover)
         )
-        print('Sample histogram: {}'.format(hist))
+        print("Sample histogram: {}".format(hist))
 
         self.assertTrue(all(hist > 0.6 * (num_to_collect / 10)))
         for i in range(1, len(blobs)):
@@ -659,4 +695,5 @@
 
 if __name__ == "__main__":
     import unittest
+
     unittest.main()