| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| import numpy as np |
| import unittest |
| |
| from caffe2.python import workspace, cnn |
| from caffe2.python import timeout_guard |
| import caffe2.python.data_workers as data_workers |
| |
| |
def dummy_fetcher(fetcher_id, batch_size):
    """Produce one randomly sized chunk of synthetic (data, label) rows.

    Row ``j`` of the data array holds the value ``j + fetcher_id`` in all
    three columns, and ``labels[j]`` holds that same value, so a consumer
    can verify that data rows and labels stayed aligned.

    The previous implementation multiplied an all-zeros array by
    ``j + fetcher_id``, which left every value (and every label) at zero and
    made any label/data equality check pass trivially.

    Args:
        fetcher_id: Number added to the row index to form each row's value.
        batch_size: Accepted as part of the fetcher signature; this dummy
            fetcher does not use it and returns a random number of rows.

    Returns:
        A two-element list ``[data, labels]`` where ``data`` is a float64
        array of shape ``(n, 3)`` and ``labels`` is a float64 array of shape
        ``(n,)``, with ``n`` drawn uniformly from 1..64.
    """
    # Random number of rows in [1, 64]; randint(64) yields 0..63.
    n = np.random.randint(64) + 1

    # One distinct value per row: j + fetcher_id.
    row_values = np.arange(n, dtype=np.float64) + fetcher_id

    # Replicate each row value across the three data columns.
    data = np.repeat(row_values[:, np.newaxis], 3, axis=1)
    labels = row_values.copy()

    return [data, labels]
| |
| |
class DataWorkersTest(unittest.TestCase):
    """Exercises ``data_workers`` input feeding on a single model."""

    def testNonParallelModel(self):
        """Run the net repeatedly and check data/label batches stay aligned.

        Sets up input workers on a ``CNNModelHelper`` using ``dummy_fetcher``,
        runs the net 500 times (each run wrapped in
        ``timeout_guard.CompleteInTimeOrDie(5)``), and after every run checks
        that the fetched "data" and "label" blobs both have 32 rows and that
        each label equals all three columns of its data row.
        """
        # Row count requested from init_data_input_workers and expected in
        # every fetched blob below.
        batch_size = 32

        model = cnn.CNNModelHelper(name="test")
        coordinator = data_workers.init_data_input_workers(
            model,
            ["data", "label"],
            dummy_fetcher,
            batch_size,
            2,  # presumably the number of worker threads -- confirm
        )
        self.assertEqual(coordinator._fetcher_id_seq, 2)
        coordinator.start()

        # Stop the coordinator even if an assertion or the net run raises,
        # so a started coordinator is never left running after a failure.
        try:
            workspace.RunNetOnce(model.param_init_net)
            workspace.CreateNet(model.net)

            for _ in range(500):
                with timeout_guard.CompleteInTimeOrDie(5):
                    workspace.RunNet(model.net.Proto().name)

                data = workspace.FetchBlob("data")
                labels = workspace.FetchBlob("label")

                self.assertEqual(data.shape[0], labels.shape[0])
                self.assertEqual(data.shape[0], batch_size)

                # Every column of a data row must equal that row's label.
                for j in range(batch_size):
                    for col in range(3):
                        self.assertEqual(labels[j], data[j, col])
        finally:
            coordinator.stop()