|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  |  | 
|  | from caffe2.python import control, core, test_util, workspace | 
|  |  | 
|  | import logging | 
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
|  |  | 
|  |  | 
class TestControl(test_util.TestCase):
    """Tests for the step builders in ``caffe2.python.control``.

    Each test runs ``init_net_`` followed by the step under test in a fresh
    ``core.Plan`` and then inspects the resulting blobs in the workspace.
    """

    def setUp(self):
        """Build the nets shared by the individual tests.

        * ``init_net_`` creates a counter (initial count 0), INT64 constants
          ``N_`` and 0, and zero-fills the count blobs published by
          ``cnt_net_`` and ``cnt_2_net_``.
        * ``cnt_net_`` counts up once and exposes the current count.
        * ``cnt_2_net_`` counts up twice and exposes the current count.
        * ``cond_net_`` / ``not_cond_net_`` compare that count against ``N_``
          (``LT`` and ``GE`` respectively).
        * ``true_cond_net_`` / ``false_cond_net_`` yield constant conditions
          (``0 < N_`` and ``0 > N_``).
        * ``idle_net_`` only fills an unused constant; it pads multi-net
          loop bodies.
        """
        super(TestControl, self).setUp()
        self.N_ = 10

        self.init_net_ = core.Net("init-net")
        cnt = self.init_net_.CreateCounter([], init_count=0)
        const_n = self.init_net_.ConstantFill(
            [], shape=[], value=self.N_, dtype=core.DataType.INT64)
        const_0 = self.init_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

        self.cnt_net_ = core.Net("cnt-net")
        self.cnt_net_.CountUp([cnt])
        curr_cnt = self.cnt_net_.RetrieveCount([cnt])
        # Zero-fill the count blob in init_net_ so it exists even when
        # cnt_net_ never runs (e.g. the false-branch If tests expect 0), and
        # so the condition nets reading it always find a value.
        self.init_net_.ConstantFill(
            [], [curr_cnt], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_net_.AddExternalOutput(curr_cnt)

        self.cnt_2_net_ = core.Net("cnt-2-net")
        self.cnt_2_net_.CountUp([cnt])
        self.cnt_2_net_.CountUp([cnt])
        curr_cnt_2 = self.cnt_2_net_.RetrieveCount([cnt])
        # Same reasoning as above: testSwitchNot expects 0 here when
        # cnt_2_net_ is skipped.
        self.init_net_.ConstantFill(
            [], [curr_cnt_2], shape=[], value=0, dtype=core.DataType.INT64)
        self.cnt_2_net_.AddExternalOutput(curr_cnt_2)

        self.cond_net_ = core.Net("cond-net")
        cond_blob = self.cond_net_.LT([curr_cnt, const_n])
        self.cond_net_.AddExternalOutput(cond_blob)

        self.not_cond_net_ = core.Net("not-cond-net")
        cond_blob = self.not_cond_net_.GE([curr_cnt, const_n])
        self.not_cond_net_.AddExternalOutput(cond_blob)

        self.true_cond_net_ = core.Net("true-cond-net")
        true_blob = self.true_cond_net_.LT([const_0, const_n])
        self.true_cond_net_.AddExternalOutput(true_blob)

        self.false_cond_net_ = core.Net("false-cond-net")
        false_blob = self.false_cond_net_.GT([const_0, const_n])
        self.false_cond_net_.AddExternalOutput(false_blob)

        self.idle_net_ = core.Net("idle-net")
        self.idle_net_.ConstantFill(
            [], shape=[], value=0, dtype=core.DataType.INT64)

    def CheckNetOutput(self, nets_and_expects):
        """Check the last external output of each net.

        Args:
            nets_and_expects: iterable of ``(net, expect)`` tuples. For each
                pair, the blob named by ``net``'s last external output is
                fetched from the workspace and compared to ``expect``.
        """
        for net, expect in nets_and_expects:
            output = workspace.FetchBlob(
                net.Proto().external_output[-1])
            self.assertEqual(output, expect)

    def CheckNetAllOutput(self, net, expects):
        """Check every external output of ``net``, in order.

        Args:
            net: a net whose external outputs have already been computed.
            expects: expected values, one per external output of ``net``.
                The number of outputs must equal ``len(expects)``.
        """
        output_names = net.Proto().external_output
        self.assertEqual(len(output_names), len(expects))
        for blob_name, expect in zip(output_names, expects):
            self.assertEqual(workspace.FetchBlob(blob_name), expect)

    def BuildAndRunPlan(self, step):
        """Run ``init_net_`` then ``step`` in a new plan; assert it succeeds."""
        plan = core.Plan("test")
        plan.AddStep(control.Do('init', self.init_net_))
        plan.AddStep(step)
        self.assertEqual(workspace.RunPlan(plan), True)

    def _RunLoopAndCheckCount(self, step):
        """Run loop ``step`` and assert ``cnt_net_`` reported ``N_``."""
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, self.N_)])

    def ForLoopTest(self, nets_or_steps):
        self._RunLoopAndCheckCount(
            control.For('myFor', nets_or_steps, self.N_))

    def testForLoopWithNets(self):
        self.ForLoopTest(self.cnt_net_)
        self.ForLoopTest([self.cnt_net_, self.idle_net_])

    def testForLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.ForLoopTest(step)
        self.ForLoopTest([step, self.idle_net_])

    def WhileLoopTest(self, nets_or_steps):
        self._RunLoopAndCheckCount(
            control.While('myWhile', self.cond_net_, nets_or_steps))

    def testWhileLoopWithNet(self):
        self.WhileLoopTest(self.cnt_net_)
        self.WhileLoopTest([self.cnt_net_, self.idle_net_])

    def testWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.WhileLoopTest(step)
        self.WhileLoopTest([step, self.idle_net_])

    def UntilLoopTest(self, nets_or_steps):
        self._RunLoopAndCheckCount(
            control.Until('myUntil', self.not_cond_net_, nets_or_steps))

    def testUntilLoopWithNet(self):
        self.UntilLoopTest(self.cnt_net_)
        self.UntilLoopTest([self.cnt_net_, self.idle_net_])

    def testUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.UntilLoopTest(step)
        self.UntilLoopTest([step, self.idle_net_])

    def DoWhileLoopTest(self, nets_or_steps):
        self._RunLoopAndCheckCount(
            control.DoWhile('myDoWhile', self.cond_net_, nets_or_steps))

    def testDoWhileLoopWithNet(self):
        self.DoWhileLoopTest(self.cnt_net_)
        self.DoWhileLoopTest([self.idle_net_, self.cnt_net_])

    def testDoWhileLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoWhileLoopTest(step)
        self.DoWhileLoopTest([self.idle_net_, step])

    def DoUntilLoopTest(self, nets_or_steps):
        self._RunLoopAndCheckCount(
            control.DoUntil('myDoUntil', self.not_cond_net_, nets_or_steps))

    def testDoUntilLoopWithNet(self):
        self.DoUntilLoopTest(self.cnt_net_)
        self.DoUntilLoopTest([self.cnt_net_, self.idle_net_])

    def testDoUntilLoopWithStep(self):
        step = control.Do('count', self.cnt_net_)
        self.DoUntilLoopTest(step)
        self.DoUntilLoopTest([self.idle_net_, step])

    def IfCondTest(self, cond_net, expect, cond_on_blob):
        """Run ``cnt_net_`` under ``control.If`` and check the count.

        Args:
            cond_net: net producing the condition as its last output.
            expect: expected value of ``cnt_net_``'s output afterwards.
            cond_on_blob: if true, run ``cond_net`` first and pass the name
                of its last output blob to ``If``; otherwise pass the net.
        """
        if cond_on_blob:
            step = control.Do(
                'if-all',
                control.Do('count', cond_net),
                control.If('myIf', cond_net.Proto().external_output[-1],
                           self.cnt_net_))
        else:
            step = control.If('myIf', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfCondTrueOnNet(self):
        self.IfCondTest(self.true_cond_net_, 1, False)

    def testIfCondTrueOnBlob(self):
        self.IfCondTest(self.true_cond_net_, 1, True)

    def testIfCondFalseOnNet(self):
        self.IfCondTest(self.false_cond_net_, 0, False)

    def testIfCondFalseOnBlob(self):
        self.IfCondTest(self.false_cond_net_, 0, True)

    def IfElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """Run ``If(cond, cnt_net_, cnt_2_net_)`` and check the taken branch.

        Args:
            cond_net: net producing the condition as its last output.
            cond_value: the value ``cond_net`` is known to produce; selects
                which branch net's output is checked.
            expect: expected output of the branch that ran.
            cond_on_blob: see ``IfCondTest``.
        """
        # A true condition runs the "then" net, a false one the "else" net.
        run_net = self.cnt_net_ if cond_value else self.cnt_2_net_
        if cond_on_blob:
            step = control.Do(
                'if-else-all',
                control.Do('count', cond_net),
                control.If('myIfElse', cond_net.Proto().external_output[-1],
                           self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.If('myIfElse', cond_net,
                              self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfElseCondTrueOnNet(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, False)

    def testIfElseCondTrueOnBlob(self):
        self.IfElseCondTest(self.true_cond_net_, True, 1, True)

    def testIfElseCondFalseOnNet(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, False)

    def testIfElseCondFalseOnBlob(self):
        self.IfElseCondTest(self.false_cond_net_, False, 2, True)

    def IfNotCondTest(self, cond_net, expect, cond_on_blob):
        """Run ``cnt_net_`` under ``control.IfNot`` and check the count.

        Arguments have the same meaning as in ``IfCondTest``.
        """
        if cond_on_blob:
            step = control.Do(
                'if-not',
                control.Do('count', cond_net),
                control.IfNot('myIfNot', cond_net.Proto().external_output[-1],
                              self.cnt_net_))
        else:
            step = control.IfNot('myIfNot', cond_net, self.cnt_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, expect)])

    def testIfNotCondTrueOnNet(self):
        self.IfNotCondTest(self.true_cond_net_, 0, False)

    def testIfNotCondTrueOnBlob(self):
        self.IfNotCondTest(self.true_cond_net_, 0, True)

    def testIfNotCondFalseOnNet(self):
        self.IfNotCondTest(self.false_cond_net_, 1, False)

    def testIfNotCondFalseOnBlob(self):
        self.IfNotCondTest(self.false_cond_net_, 1, True)

    def IfNotElseCondTest(self, cond_net, cond_value, expect, cond_on_blob):
        """Run ``IfNot(cond, cnt_net_, cnt_2_net_)`` and check the taken branch.

        Arguments have the same meaning as in ``IfElseCondTest``.
        """
        # IfNot inverts the branch choice: a true condition runs the
        # second net.
        run_net = self.cnt_2_net_ if cond_value else self.cnt_net_
        if cond_on_blob:
            step = control.Do(
                'if-not-else',
                control.Do('count', cond_net),
                control.IfNot('myIfNotElse',
                              cond_net.Proto().external_output[-1],
                              self.cnt_net_, self.cnt_2_net_))
        else:
            step = control.IfNot('myIfNotElse', cond_net,
                                 self.cnt_net_, self.cnt_2_net_)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(run_net, expect)])

    def testIfNotElseCondTrueOnNet(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, False)

    def testIfNotElseCondTrueOnBlob(self):
        self.IfNotElseCondTest(self.true_cond_net_, True, 2, True)

    def testIfNotElseCondFalseOnNet(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, False)

    def testIfNotElseCondFalseOnBlob(self):
        self.IfNotElseCondTest(self.false_cond_net_, False, 1, True)

    def testSwitch(self):
        step = control.Switch(
            'mySwitch',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 0), (self.cnt_2_net_, 2)])

    def testSwitchNot(self):
        step = control.SwitchNot(
            'mySwitchNot',
            (self.false_cond_net_, self.cnt_net_),
            (self.true_cond_net_, self.cnt_2_net_)
        )
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(self.cnt_net_, 1), (self.cnt_2_net_, 0)])

    def testBoolNet(self):
        bool_net = control.BoolNet(('a', True))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True])

        bool_net = control.BoolNet(('a', True), ('b', False))
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

        bool_net = control.BoolNet([('a', True), ('b', False)])
        step = control.Do('bool', bool_net)
        self.BuildAndRunPlan(step)
        self.CheckNetAllOutput(bool_net, [True, False])

    def testCombineConditions(self):
        # combined by 'Or'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, True)])

        # combined by 'And'
        combine_net = control.CombineConditions(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('combine',
                          self.true_cond_net_,
                          self.false_cond_net_,
                          combine_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(combine_net, False)])

    def testMergeConditionNets(self):
        # merged by 'Or'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'Or')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, True)])

        # merged by 'And'
        merge_net = control.MergeConditionNets(
            'test', [self.true_cond_net_, self.false_cond_net_], 'And')
        step = control.Do('merge', merge_net)
        self.BuildAndRunPlan(step)
        self.CheckNetOutput([(merge_net, False)])