|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  | from caffe2.python import brew, model_helper, scope | 
|  | from caffe2.python.modeling.parameter_sharing import ( | 
|  | ParameterSharing, | 
|  | parameter_sharing_context, | 
|  | ) | 
|  | from caffe2.python.modeling.initializers import ( | 
|  | Initializer | 
|  | ) | 
|  | import unittest | 
|  |  | 
|  |  | 
class ParameterSharingTest(unittest.TestCase):
    """Tests for parameter-name resolution under ``ParameterSharing``.

    ``ParameterSharing`` takes a ``{source_scope: target_scope}`` mapping:
    parameter names requested under ``source_scope`` resolve to names under
    ``target_scope`` instead, and an empty target (``''``) folds the scope
    into its parent. These tests cover the raw name resolution
    (``parameter_sharing_context.get_parameter_name``) and its effect on
    parameters created via ``ModelHelper.create_param`` and ``brew.fc``.
    """

    def test_parameter_sharing_default_scopes(self):
        """Without any sharing, names are just prefixed by the NameScope."""
        param_1 = parameter_sharing_context.get_parameter_name('w')
        self.assertEqual(param_1, 'w')
        with scope.NameScope('scope'):
            param_2 = parameter_sharing_context.get_parameter_name('w')
            self.assertEqual(param_2, 'scope/w')
            with scope.NameScope('scope_2'):
                param_3 = parameter_sharing_context.get_parameter_name('w')
                self.assertEqual(param_3, 'scope/scope_2/w')

    def test_parameter_sharing_nested_scopes(self):
        """Sharing contexts stack on top of already-overridden scopes."""
        with scope.NameScope('global_scope'):
            with ParameterSharing({'model_b': 'model_a'}):
                # The enclosing scope itself is not affected by the mapping.
                param_global = (
                    parameter_sharing_context.get_parameter_name('w'))
                self.assertEqual(param_global, 'global_scope/w')
                # This scope is overridden to match 'model_a'.
                with scope.NameScope('model_b'):
                    with ParameterSharing({'shared_scope': ''}):
                        param_4 = parameter_sharing_context.get_parameter_name(
                            'w')
                        self.assertEqual(param_4, 'global_scope/model_a/w')
                        # '' folds 'shared_scope' into its (already
                        # overridden) parent scope.
                        with scope.NameScope('shared_scope'):
                            param_5 = parameter_sharing_context.\
                                get_parameter_name('w')
                            self.assertEqual(
                                param_5, 'global_scope/model_a/w')
                # This scope is not mapped to anything, so only the
                # 'shared_scope' -> '' folding applies.
                with scope.NameScope('model_c'):
                    with ParameterSharing({'shared_scope': ''}):
                        param_4 = parameter_sharing_context.get_parameter_name(
                            'w')
                        self.assertEqual(param_4, 'global_scope/model_c/w')
                        with scope.NameScope('shared_scope'):
                            param_5 = parameter_sharing_context.\
                                get_parameter_name('w')
                            self.assertEqual(
                                param_5, 'global_scope/model_c/w')

    def test_parameter_sharing_subscopes(self):
        """A full-path mapping affects only the named subscope."""
        with ParameterSharing({'global_scope/b': 'global_scope/a'}):
            with scope.NameScope('global_scope'):
                param_6 = parameter_sharing_context.get_parameter_name('w')
                self.assertEqual(param_6, 'global_scope/w')
                with scope.NameScope('a'):
                    param_7 = parameter_sharing_context.get_parameter_name(
                        'w')
                    self.assertEqual(param_7, 'global_scope/a/w')
                # 'b' is mapped onto its sibling 'a'.
                with scope.NameScope('b'):
                    param_8 = parameter_sharing_context.get_parameter_name(
                        'w')
                    self.assertEqual(param_8, 'global_scope/a/w')
                # 'c' is not mentioned in the mapping, so it keeps its name.
                with scope.NameScope('c'):
                    param_9 = parameter_sharing_context.get_parameter_name(
                        'w')
                    self.assertEqual(param_9, 'global_scope/c/w')

    def test_create_param(self):
        """Same short name in different scopes yields distinct params."""
        model = model_helper.ModelHelper(name="test")
        p1 = model.create_param(
            'w',
            shape=[2],
            initializer=Initializer("ConstantFill")
        )
        with scope.NameScope('some_global_scope'):
            p2 = model.create_param(
                'w',
                shape=[2],
                initializer=Initializer("ConstantFill")
            )
        self.assertIsNotNone(model.get_param_info(p1))
        self.assertIsNotNone(model.get_param_info(p2))
        self.assertNotEqual(model.get_param_info(p1), model.get_param_info(p2))
        model.Validate()

    def test_deep_hierarchy(self):
        """create_param works under several nested sharing contexts."""
        model = model_helper.ModelHelper(name="test")
        with ParameterSharing({'a': 'b'}):
            with scope.NameScope('a'):
                with ParameterSharing({'c': 'd'}):
                    with scope.NameScope('c'):
                        with ParameterSharing({'e': 'f'}):
                            with scope.NameScope('e'):
                                p = model.create_param(
                                    'w',
                                    shape=[2],
                                    initializer=Initializer("ConstantFill")
                                )
        self.assertIsNotNone(model.get_param_info(p))

    def test_parameter_sharing_brew(self):
        """Parameters created through brew helpers honor ParameterSharing."""
        model = model_helper.ModelHelper(name="test")
        data = model.net.AddExternalInput("data")
        fc1 = brew.fc(model, data, "fc1", dim_in=16, dim_out=16)
        # Re-creating "fc1" in the same scope resolves to the same fc1_w /
        # fc1_b parameter names; shared params must keep the same shape, so
        # a different dim_in/dim_out is expected to fail.
        with self.assertRaises(AssertionError):
            _ = brew.fc(model, data, "fc1", dim_in=2, dim_out=2)  # noqa

        output_blobs = set()
        with scope.NameScope('some_global_scope'):
            with scope.NameScope('model_a'):
                output_blobs.add(str(brew.fc(model, fc1, 'output', 16, 16)))
            with ParameterSharing({'model_b': 'model_a'}):
                with scope.NameScope('model_b'):
                    with ParameterSharing({'shared_1': '', 'shared_2': ''}):
                        # All params in the fc layers from shared_1,
                        # shared_2 and model_a are shared and point to:
                        # [some_global_scope/model_a/output_w,
                        #  some_global_scope/model_a/output_b]
                        with scope.NameScope('shared_1'):
                            output_blobs.add(
                                str(brew.fc(model, fc1, 'output', 16, 16)))
                        with scope.NameScope('shared_2'):
                            output_blobs.add(
                                str(brew.fc(model, fc1, 'output', 16, 16)))
                    # Params of this layer are not shared with anyone unless
                    # there is some explicit sharing with model_a/unshared
                    # (not in this example). Because model_b maps onto
                    # model_a, the param names are:
                    # [some_global_scope/model_a/unshared/output_w,
                    #  some_global_scope/model_a/unshared/output_b]
                    with scope.NameScope('unshared'):
                        output_blobs.add(
                            str(brew.fc(model, fc1, 'output', 16, 16)))

        # fc1 (2 params) + the model_a/shared_1/shared_2 group (2 params)
        # + unshared (2 params).
        self.assertEqual(len(model._parameters_info), 6)
        # Only parameters are shared: each of the four fc calls above still
        # produces its own distinct output blob.
        self.assertEqual(len(output_blobs), 4)
        self.assertEqual(sorted(model._parameters_info.keys()), [
            'fc1_b',
            'fc1_w',
            'some_global_scope/model_a/output_b',
            'some_global_scope/model_a/output_w',
            'some_global_scope/model_a/unshared/output_b',
            'some_global_scope/model_a/unshared/output_w',
        ])
        model.Validate()