| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| from __future__ import unicode_literals |
| from caffe2.python import workspace, brew |
| from caffe2.python.model_helper import ModelHelper |
| |
| import unittest |
| import numpy as np |
| |
| |
class BrewTest(unittest.TestCase):
    """Tests for ``brew`` helpers, helper registration and ``arg_scope``."""

    def setUp(self):
        """Register two trivial echo helpers used by the arg_scope tests."""

        # Both helpers just return ``val``, so a test can observe exactly
        # which value an enclosing ``arg_scope`` injected.
        def myhelper(model, val=-1):
            return val

        # Re-registering an already-registered helper raises AttributeError
        # (exercised by test_double_register), so only register when the
        # helper is not yet known to brew.
        if not brew.has_helper(myhelper):
            brew.Register(myhelper)
        self.myhelper = myhelper

        def myhelper2(model, val=-1):
            return val

        if not brew.has_helper(myhelper2):
            brew.Register(myhelper2)
        self.myhelper2 = myhelper2

    def test_dropout(self):
        """Training-mode dropout on a constant input keeps its mean near 1 - p."""
        # NOTE: p is only the constant offset of the input here; the dropout
        # ratio itself is left at the operator's default.
        p = 0.2
        X = np.ones((100, 100)).astype(np.float32) - p
        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        # Pass is_test explicitly: newer brew.dropout versions require it, and
        # the Dropout operator accepts it in all versions.
        brew.dropout(model, "x", "out", is_test=False)
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        # Presumably kept units are rescaled so the expectation is preserved
        # (TODO confirm against Dropout op docs); 0.05 tolerates sampling noise.
        self.assertLess(abs(out.mean() - (1 - p)), 0.05)

    def test_fc(self):
        """brew.fc builds a runnable FC layer producing an (m, n) output."""
        m, n, k = (15, 15, 15)
        X = np.random.rand(m, k).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        # dim_in=k, dim_out=n
        brew.fc(model, "x", "out_1", k, n)

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        # Previously this test asserted nothing; check the output shape.
        out = workspace.FetchBlob("out_1")
        self.assertEqual(out.shape, (m, n))

    def test_arg_scope(self):
        """arg_scope injects ``val`` into one or several registered helpers."""
        myhelper = self.myhelper
        myhelper2 = self.myhelper2
        n = 15
        with brew.arg_scope([myhelper], val=n):
            res = brew.myhelper(None)
        self.assertEqual(n, res)

        with brew.arg_scope([myhelper, myhelper2], val=n):
            res1 = brew.myhelper(None)
            res2 = brew.myhelper2(None)
        self.assertEqual([n, n], [res1, res2])

    def test_arg_scope_single(self):
        """arg_scope also accepts a single (non-list) helper, here brew.conv."""
        X = np.random.rand(64, 3, 32, 32).astype(np.float32) - 0.5

        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        with brew.arg_scope(
            brew.conv,
            stride=2,
            pad=2,
            weight_init=('XavierFill', {}),
            bias_init=('ConstantFill', {})
        ):
            brew.conv(
                model=model,
                blob_in="x",
                blob_out="out",
                dim_in=3,
                dim_out=64,
                kernel=3,
            )

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        # Spatial size: (32 + 2 * 2 - 3) // 2 + 1 == 17, which only holds if
        # the scoped stride/pad were actually applied.
        self.assertEqual(out.shape, (64, 64, 17, 17))

    def test_arg_scope_nested(self):
        """Inner scopes override outer ones; explicit kwargs override scopes."""
        myhelper = self.myhelper
        n = 16
        with brew.arg_scope([myhelper], val=-3), \
                brew.arg_scope([myhelper], val=-2):
            with brew.arg_scope([myhelper], val=n):
                res = brew.myhelper(None)
                self.assertEqual(n, res)
            # After leaving the innermost scope, the next-outer value applies.
            res = brew.myhelper(None)
            self.assertEqual(res, -2)

        # Outside every scope, an explicit keyword argument is used as-is.
        res = brew.myhelper(None, val=15)
        self.assertEqual(res, 15)

    def test_double_register(self):
        """Registering a helper that is already registered raises AttributeError."""
        myhelper = self.myhelper
        with self.assertRaises(AttributeError):
            brew.Register(myhelper)

    def test_has_helper(self):
        """has_helper accepts a function or a name, and rejects unknown helpers."""
        self.assertTrue(brew.has_helper(brew.conv))
        self.assertTrue(brew.has_helper("conv"))

        def myhelper3():
            pass

        self.assertFalse(brew.has_helper(myhelper3))