| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| |
| from caffe2.python import workspace |
| from caffe2.python.core import Plan, to_execution_step |
| from caffe2.python.net_builder import ops, NetBuilder |
| import unittest |
| |
| |
def test_loop():
    """Emit a countdown loop and return the blob counting its iterations.

    A counter blob starts at 5 and a tally blob starts at 0. Each pass of
    the loop body first checks whether the counter equals 0 (requesting a
    stop if so), then adds -1 to the counter and +1 to the tally.

    Must be called while a NetBuilder context is active, since it only
    emits ops through ``ops``.

    Returns:
        The tally blob (the caller in this file expects it to hold 5 once
        the plan has run).
    """
    remaining = ops.Const(5)
    iterations = ops.Const(0)
    with ops.loop():
        # The exit check is emitted before the updates, so it is evaluated
        # at the top of every pass.
        zero = ops.Const(0)
        reached_zero = ops.EQ([remaining, zero])
        ops.stop_if(reached_zero)

        # In-place updates: the output blob is the same as the first input.
        minus_one = ops.Const(-1)
        ops.Add([remaining, minus_one], [remaining])
        plus_one = ops.Const(1)
        ops.Add([iterations, plus_one], [iterations])
    return iterations
| |
| |
def test_inner_stop(x):
    """Emit a conditional stop that triggers when ``x`` is less than 5.

    Args:
        x: blob to compare against the constant 5.

    The stop is emitted with ``ops.stop_if``; which enclosing construct it
    terminates depends on the context the caller has opened (in this file
    it is always invoked inside an ``ops.stop_guard()`` block).
    """
    threshold = ops.Const(5)
    below_threshold = ops.LT([x, threshold])
    ops.stop_if(below_threshold)
| |
| |
def test_outer():
    """Exercise ``ops.stop_guard()`` in four situations.

    The guards are built in this order:
      1. a conditional stop whose condition is false (10 < 5),
      2. a conditional stop whose condition is true (3 < 5),
      3. a body that emits an op but never requests a stop,
      4. an empty body.

    Returns:
        A 4-tuple of the guards' ``has_stopped()`` results, ordered as
        (false-condition, true-condition, empty-body, no-stop-body).
        Note that the last two are returned in the opposite order from
        the one in which they were built.
    """
    # Case: stop_if(False) -- 10 is not below the threshold of 5.
    large_value = ops.Const(10)
    with ops.stop_guard() as guard_cond_false:
        test_inner_stop(large_value)

    # Case: stop_if(True) -- 3 is below the threshold of 5.
    small_value = ops.Const(3)
    with ops.stop_guard() as guard_cond_true:
        test_inner_stop(small_value)

    # Case: the body does some work but never asks to stop.
    with ops.stop_guard() as guard_no_stop:
        ops.Const(0)

    # Case: the body is empty.
    with ops.stop_guard() as guard_empty:
        pass

    stopped_flags = (
        guard_cond_false.has_stopped(),
        guard_cond_true.has_stopped(),
        guard_empty.has_stopped(),
        guard_no_stop.has_stopped(),
    )
    return stopped_flags
| |
| |
def test_if(x):
    """Emit two conditional branches on ``x`` and return the result blob.

    The result blob starts at 1. If ``x > 50`` it is overwritten with 2.
    If ``x < 50`` it is overwritten with 3, after which ``ops.stop()`` is
    emitted, followed by a write of 4. That final write is presumably
    never reached once the stop takes effect -- confirm against the
    net_builder semantics of ``ops.stop()``.

    Args:
        x: blob compared against the constant 50.

    Returns:
        The result blob written by the branches above.
    """
    result = ops.Const(1)

    upper_limit = ops.Const(50)
    is_above = ops.GT([x, upper_limit])
    with ops.If(is_above):
        ops.Const(2, blob_out=result)

    # A separate constant is emitted for the second comparison, as in the
    # first branch, so the sequence of generated ops is unchanged.
    lower_limit = ops.Const(50)
    is_below = ops.LT([x, lower_limit])
    with ops.If(is_below):
        ops.Const(3, blob_out=result)
        ops.stop()
        ops.Const(4, blob_out=result)

    return result
| |
| |
class TestNetBuilder(unittest.TestCase):
    """End-to-end check of the NetBuilder control-flow helpers in this file."""

    def test_ops(self):
        """Build one plan from the helper functions, run it, verify blobs.

        The helpers are all invoked inside a single NetBuilder context, the
        result is converted to an execution step, wrapped in a Plan and run
        in a fresh workspace. Each returned blob is then fetched by name and
        compared with its expected value.
        """
        with NetBuilder() as nb:
            loop_count = test_loop()
            (stopped_cond_false, stopped_cond_true,
             stopped_empty, stopped_no_stop) = test_outer()
            if_above = test_if(ops.Const(75))
            if_below = test_if(ops.Const(25))

        plan = Plan('name')
        plan.AddStep(to_execution_step(nb))
        ws = workspace.C.Workspace()
        ws.run(plan)

        expected_values = [
            # The loop counts down from 5, incrementing the tally each pass.
            (loop_count, 5),
            # Only the guard whose stop condition was true reports a stop.
            (stopped_cond_false, False),
            (stopped_cond_true, True),
            (stopped_empty, False),
            (stopped_no_stop, False),
            # 75 > 50: only the first branch of test_if runs and writes 2.
            (if_above, 2),
            # 25 < 50: only the second branch runs; it writes 3 and then
            # stops before the trailing write of 4.
            (if_below, 3),
        ]
        # Distinct loop names so neither the list being iterated nor any of
        # the blobs above is rebound mid-loop.
        for blob, expected in expected_values:
            actual = ws.blobs[str(blob)].fetch()
            # assertEqual rather than the deprecated assertEquals alias,
            # which was removed in Python 3.12.
            self.assertEqual(actual, expected)