# blob: 53509bf85027c8d682f539b1f002057e33ab4fb6 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import workspace, brew
from caffe2.python.model_helper import ModelHelper
import unittest
import numpy as np
class BrewTest(unittest.TestCase):
    """Tests for the ``brew`` helper registry, ``arg_scope`` and built-in helpers."""

    def setUp(self):
        """Register two trivial helpers used by the arg-scope tests.

        Each helper simply returns its ``val`` argument (default -1), so a test
        can observe which value ``arg_scope`` injected. Registration is guarded
        by ``brew.has_helper`` because setUp runs before every test and
        re-registering an existing helper raises AttributeError (see
        ``test_double_register``).
        """
        def myhelper(model, val=-1):
            return val

        if not brew.has_helper(myhelper):
            brew.Register(myhelper)
        self.myhelper = myhelper

        def myhelper2(model, val=-1):
            return val

        if not brew.has_helper(myhelper2):
            brew.Register(myhelper2)
        self.myhelper2 = myhelper2

    def test_dropout(self):
        """Dropout in training mode keeps the mean output close to the input value."""
        # ``p`` is not the dropout ratio; it only offsets the constant input so
        # that every element of X equals 1 - p.
        p = 0.2
        X = np.ones((100, 100)).astype(np.float32) - p
        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        # Request training mode explicitly: the mean-preservation check below
        # is only meaningful when units are actually being dropped.
        brew.dropout(model, "x", "out", is_test=False)
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        self.assertLess(abs(out.mean() - (1 - p)), 0.05)

    def test_fc(self):
        """brew.fc maps an (m, k) input to an (m, n) output."""
        # Distinct sizes so that confusing dim_in and dim_out would be caught
        # by the shape check instead of passing silently.
        m, n, k = (15, 20, 10)
        X = np.random.rand(m, k).astype(np.float32) - 0.5
        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        brew.fc(model, "x", "out_1", k, n)
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out_1")
        self.assertEqual(out.shape, (m, n))

    def test_arg_scope(self):
        """arg_scope injects keyword arguments into one or several helpers."""
        myhelper = self.myhelper
        myhelper2 = self.myhelper2
        n = 15
        with brew.arg_scope([myhelper], val=n):
            res = brew.myhelper(None)
        self.assertEqual(n, res)

        with brew.arg_scope([myhelper, myhelper2], val=n):
            res1 = brew.myhelper(None)
            res2 = brew.myhelper2(None)
        self.assertEqual([n, n], [res1, res2])

    def test_arg_scope_single(self):
        """arg_scope also accepts a single helper instead of a list."""
        X = np.random.rand(64, 3, 32, 32).astype(np.float32) - 0.5
        workspace.FeedBlob("x", X)
        model = ModelHelper(name="test_model")
        with brew.arg_scope(
            brew.conv,
            stride=2,
            pad=2,
            weight_init=('XavierFill', {}),
            bias_init=('ConstantFill', {})
        ):
            brew.conv(
                model=model,
                blob_in="x",
                blob_out="out",
                dim_in=3,
                dim_out=64,
                kernel=3,
            )
        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)
        out = workspace.FetchBlob("out")
        # Spatial size: (32 + 2 * pad - kernel) // stride + 1 = (32 + 4 - 3) // 2 + 1 = 17.
        # This only holds if the scoped stride/pad were actually applied.
        self.assertEqual(out.shape, (64, 64, 17, 17))

    def test_arg_scope_nested(self):
        """Inner scopes override outer ones; explicit kwargs override all scopes."""
        myhelper = self.myhelper
        n = 16
        with brew.arg_scope([myhelper], val=-3), \
                brew.arg_scope([myhelper], val=-2):
            with brew.arg_scope([myhelper], val=n):
                res = brew.myhelper(None)
                self.assertEqual(n, res)
            # Back in the outer pair of scopes: the later (-2) one wins.
            res = brew.myhelper(None)
            self.assertEqual(res, -2)

        res = brew.myhelper(None, val=15)
        self.assertEqual(res, 15)

    def test_double_register(self):
        """Registering a helper that already exists raises AttributeError."""
        myhelper = self.myhelper
        with self.assertRaises(AttributeError):
            brew.Register(myhelper)

    def test_has_helper(self):
        """has_helper accepts a function or a name and is False for unknown helpers."""
        self.assertTrue(brew.has_helper(brew.conv))
        self.assertTrue(brew.has_helper("conv"))

        def myhelper3():
            pass

        self.assertFalse(brew.has_helper(myhelper3))