Fix a data_workers test
Summary:
This is a global variable which can be incremented by other tests.
Before:
```
$ pytest -v caffe2/python/data_workers_test.py
...
caffe2/python/data_workers_test.py::DataWorkersTest::testGracefulShutdown PASSED
caffe2/python/data_workers_test.py::DataWorkersTest::testNonParallelModel FAILED
============================================= FAILURES ==============================================
_______________________________ DataWorkersTest.testNonParallelModel ________________________________
self = <data_workers_test.DataWorkersTest testMethod=testNonParallelModel>
def testNonParallelModel(self):
model = cnn.CNNModelHelper(name="test")
coordinator = data_workers.init_data_input_workers(
model,
["data", "label"],
dummy_fetcher,
32,
2,
)
> self.assertEqual(coordinator._fetcher_id_seq, 2)
E AssertionError: 4 != 2
caffe2/python/data_workers_test.py:38: AssertionError
-----------------
Closes https://github.com/caffe2/caffe2/pull/211
Differential Revision: D4916591
Pulled By: Yangqing
fbshipit-source-id: 281f12d7f02dbd0ce0932024cf1f16cd12130112
diff --git a/caffe2/python/data_workers_test.py b/caffe2/python/data_workers_test.py
index 982a49a..25a552e 100644
--- a/caffe2/python/data_workers_test.py
+++ b/caffe2/python/data_workers_test.py
@@ -28,6 +28,7 @@
def testNonParallelModel(self):
model = cnn.CNNModelHelper(name="test")
+ old_seq_id = data_workers.global_coordinator._fetcher_id_seq
coordinator = data_workers.init_data_input_workers(
model,
["data", "label"],
@@ -35,7 +36,9 @@
32,
2,
)
- self.assertEqual(coordinator._fetcher_id_seq, 2)
+ new_seq_id = data_workers.global_coordinator._fetcher_id_seq
+ self.assertEqual(new_seq_id, old_seq_id + 2)
+
coordinator.start()
workspace.RunNetOnce(model.param_init_net)
@@ -60,6 +63,7 @@
def testGracefulShutdown(self):
model = cnn.CNNModelHelper(name="test")
+ old_seq_id = data_workers.global_coordinator._fetcher_id_seq
coordinator = data_workers.init_data_input_workers(
model,
["data", "label"],
@@ -67,7 +71,9 @@
32,
2,
)
- self.assertEqual(coordinator._fetcher_id_seq, 2)
+ new_seq_id = data_workers.global_coordinator._fetcher_id_seq
+ self.assertEqual(new_seq_id, old_seq_id + 2)
+
coordinator.start()
workspace.RunNetOnce(model.param_init_net)