blob: 801e17ff9ba1e2c9711543e3a7aef8c6af940cd9 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import workspace
from caffe2.python.core import Plan, to_execution_step
from caffe2.python.net_builder import ops, NetBuilder
import unittest
def test_loop():
x = ops.Const(5)
y = ops.Const(0)
with ops.loop():
ops.stop_if(ops.EQ([x, ops.Const(0)]))
ops.Add([x, ops.Const(-1)], [x])
ops.Add([y, ops.Const(1)], [y])
return y
def test_inner_stop(x):
ops.stop_if(ops.LT([x, ops.Const(5)]))
def test_outer():
x = ops.Const(10)
# test stop_if(False)
with ops.stop_guard() as g1:
test_inner_stop(x)
# test stop_if(True)
y = ops.Const(3)
with ops.stop_guard() as g2:
test_inner_stop(y)
# test no stop
with ops.stop_guard() as g4:
ops.Const(0)
# test empty clause
with ops.stop_guard() as g3:
pass
return (
g1.has_stopped(), g2.has_stopped(), g3.has_stopped(), g4.has_stopped())
def test_if(x):
y = ops.Const(1)
with ops.If(ops.GT([x, ops.Const(50)])):
ops.Const(2, blob_out=y)
with ops.If(ops.LT([x, ops.Const(50)])):
ops.Const(3, blob_out=y)
ops.stop()
ops.Const(4, blob_out=y)
return y
class TestNetBuilder(unittest.TestCase):
def test_ops(self):
with NetBuilder() as nb:
y = test_loop()
z, w, a, b = test_outer()
p = test_if(ops.Const(75))
q = test_if(ops.Const(25))
plan = Plan('name')
plan.AddStep(to_execution_step(nb))
ws = workspace.C.Workspace()
ws.run(plan)
expected = [
(y, 5),
(z, False),
(w, True),
(a, False),
(b, False),
(p, 3),
(q, 2),
]
for b, expected in expected:
actual = ws.blobs[str(b)].fetch()
self.assertEquals(actual, expected)