|  | #!/usr/bin/env python | 
|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  | from hypothesis import given, settings | 
|  | import hypothesis.strategies as st | 
|  | from multiprocessing import Process | 
|  |  | 
|  | import numpy as np | 
|  | import tempfile | 
|  | import shutil | 
|  |  | 
|  | import caffe2.python.hypothesis_test_util as hu | 
|  | import unittest | 
|  |  | 
# Engine name placed under the "engine" key of the rendezvous dict that
# allcompare_process attaches to its model.
op_engine = 'GLOO'
|  |  | 
class TemporaryDirectory:
    """Context manager that provides a scratch directory for a ``with`` body.

    The directory is created on entry and its whole tree is removed on exit.
    """

    def __enter__(self):
        """Create a new temporary directory, remember it, and return its path."""
        created = tempfile.mkdtemp()
        self.tmpdir = created
        return created

    def __exit__(self, type, value, traceback):
        """Remove the directory created by ``__enter__`` and everything in it.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        shutil.rmtree(self.tmpdir)
|  |  | 
|  |  | 
def allcompare_process(filestore_dir, process_id, data, num_procs):
    """Run one shard of the cross-process comparison of ``data``.

    Used as the target of a ``multiprocessing.Process``; the caffe2 imports
    are done inside the function so they happen in the process that runs it.

    Args:
        filestore_dir: Directory passed as ``path`` to the
            ``FileStoreHandlerCreate`` operator.
        process_id: Value stored as ``shard_id`` in the rendezvous settings.
        data: Array fed into the workspace as the blob ``"test_data"``.
        num_procs: Value stored as ``num_shards`` in the rendezvous settings.
    """
    from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
    from caffe2.python.model_helper import ModelHelper
    from caffe2.proto import caffe2_pb2

    # Register the library that supplies the file store handler operators.
    lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    # Create the "store_handler" blob backed by the given directory.
    create_store_op = core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
    )
    workspace.RunOperatorOnce(create_store_op)

    rendezvous_settings = {
        "kv_handler": "store_handler",
        "shard_id": process_id,
        "num_shards": num_procs,
        "engine": op_engine,
        "exit_nets": None,
    }

    model = ModelHelper()
    model._rendezvous = rendezvous_settings

    workspace.FeedBlob("test_data", data)

    # NOTE(review): _RunComparison presumably checks that "test_data" matches
    # across all shards -- confirm against data_parallel_model.
    cpu_device = core.DeviceOption(caffe2_pb2.CPU, 0)
    data_parallel_model._RunComparison(model, "test_data", cpu_device)
|  |  | 
|  |  | 
class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
    """Runs ``allcompare_process`` in several processes on the same data."""

    @given(
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
    )
    @settings(deadline=None)
    def test_allcompare(self, d, n, num_procs):
        """Spawn ``num_procs`` workers on one random array and require success.

        Args:
            d: Number of dimensions of the random test array.
            n: Exclusive upper bound for the size of each dimension.
            num_procs: Number of worker processes (shards) to launch.
        """
        dims = [np.random.randint(1, high=n) for _ in range(d)]
        test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            processes = []
            for idx in range(num_procs):
                process = Process(
                    target=allcompare_process,
                    args=(tempdir, idx, test_data, num_procs),
                )
                processes.append(process)
                process.start()

            # Wait for every worker before leaving the block, so the shared
            # directory is not removed while a worker may still be using it.
            for process in processes:
                process.join()

        # An exception in a worker never reaches this process; it only shows
        # up as a non-zero exit code. Without this check the test would pass
        # even if every worker failed.
        for idx, process in enumerate(processes):
            self.assertEqual(
                process.exitcode,
                0,
                "allcompare worker {} of {} exited with code {}".format(
                    idx, num_procs, process.exitcode
                ),
            )
|  |  | 
class TestLazyDynDepError(unittest.TestCase):
    """Tests for the error handler installed via ``lazy_dyndep.SetErrorHandler``.

    Each test registers the path of an empty temporary file as an ops
    library. The file is deleted when its ``with`` block exits, before the
    call that is expected to raise.
    """

    def test_errorhandler(self):
        """An error raised by the handler surfaces from RefreshRegisteredOperators."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

        def handler(e):
            raise ValueError("test")
        lazy_dyndep.SetErrorHandler(handler)
        # assertRaises(..., msg="test") only sets the failure message; use
        # assertRaisesRegex so the exception text is actually checked.
        with self.assertRaisesRegex(ValueError, "test"):
            core.RefreshRegisteredOperators()

    def test_importaftererror(self):
        """A valid library can still be registered and loaded after a failure."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

        def handler(e):
            raise ValueError("test")
        lazy_dyndep.SetErrorHandler(handler)
        with self.assertRaises(ValueError):
            core.RefreshRegisteredOperators()

        def handlernoop(e):
            # A bare ``raise`` re-raises the exception being handled.
            # NOTE(review): presumably lazy_dyndep calls the handler from
            # inside its ``except`` block -- confirm.
            raise
        lazy_dyndep.SetErrorHandler(handlernoop)
        lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
        core.RefreshRegisteredOperators()

    def test_workspacecreatenet(self):
        """An error raised by the handler surfaces from workspace.CreateNet."""
        from caffe2.python import workspace, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

        def handler(e):
            raise ValueError("test")
        lazy_dyndep.SetErrorHandler(handler)
        with self.assertRaisesRegex(ValueError, "test"):
            workspace.CreateNet("fake")
|  |  | 
|  |  | 
# Run this module's tests when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()