# blob: e69d930213f3790bdb9201c265d3a38495e7c3f1 [file] [log] [blame]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import schema
import numpy as np
import unittest
import pickle
class TestDB(unittest.TestCase):
    """Unit tests for the field types exposed by ``caffe2.python.schema``.

    Exercises ``Struct``, ``Scalar``, ``List``, ``Tuple``, ``RawTuple``,
    ``Metadata`` and ``from_blob_list``.
    """

    def testPicklable(self):
        """A Struct keeps its typed fields after a pickle round trip."""
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.List(schema.Scalar(dtype=str))),
        )
        s2 = pickle.loads(pickle.dumps(s))
        # Check the original and the unpickled copy the same way.
        for r in (s, s2):
            self.assertIsInstance(r.field1, schema.Scalar)
            self.assertIsInstance(r.field2, schema.List)
            # Looking up an unknown attribute with a default must fall back
            # to that default rather than blow up.
            self.assertIsNone(getattr(r, 'non_existent', None))

    def testNormalizeField(self):
        """Bare dtypes passed to Struct compare equal to explicit Scalars."""
        s = schema.Struct(('field1', np.int32), ('field2', str))
        expected = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.Scalar(dtype=str)),
        )
        self.assertEqual(s, expected)

    def testTuple(self):
        """Tuple equals a Struct with ``field_<i>`` names and is indexable."""
        s = schema.Tuple(np.int32, str, np.float32)
        # ``str`` replaces the removed ``np.str`` alias; it was the same type.
        s2 = schema.Struct(
            ('field_0', schema.Scalar(dtype=np.int32)),
            ('field_1', schema.Scalar(dtype=str)),
            ('field_2', schema.Scalar(dtype=np.float32)),
        )
        self.assertEqual(s, s2)
        self.assertEqual(s[0], schema.Scalar(dtype=np.int32))
        self.assertEqual(s[1], schema.Scalar(dtype=str))
        self.assertEqual(s[2], schema.Scalar(dtype=np.float32))
        # Indexing with several positions yields a Struct of the selected
        # fields, in the order requested.
        self.assertEqual(
            s[2, 0],
            schema.Struct(
                ('field_2', schema.Scalar(dtype=np.float32)),
                ('field_0', schema.Scalar(dtype=np.int32)),
            ),
        )
        # test iterator behavior
        for i, (v1, v2) in enumerate(zip(s, s2)):
            self.assertEqual(v1, v2)
            self.assertEqual(s[i], v1)
            self.assertEqual(s2[i], v1)

    def testRawTuple(self):
        """RawTuple(2) equals a Struct of two default Scalars."""
        s = schema.RawTuple(2)
        self.assertEqual(
            s,
            schema.Struct(
                ('field_0', schema.Scalar()),
                ('field_1', schema.Scalar()),
            ),
        )
        self.assertEqual(s[0], schema.Scalar())
        self.assertEqual(s[1], schema.Scalar())

    def testStructIndexing(self):
        """Struct supports lookup by field name and by a tuple of names."""
        s = schema.Struct(
            ('field1', schema.Scalar(dtype=np.int32)),
            ('field2', schema.List(schema.Scalar(dtype=str))),
        )
        self.assertEqual(s['field2'], s.field2)
        self.assertEqual(s['field2'], schema.List(schema.Scalar(dtype=str)))
        # Several names select a sub-Struct in the order requested.
        self.assertEqual(
            s['field2', 'field1'],
            schema.Struct(
                ('field2', schema.List(schema.Scalar(dtype=str))),
                ('field1', schema.Scalar(dtype=np.int32)),
            ),
        )

    def testPreservesMetadata(self):
        """Metadata survives ``clone()`` and ``from_blob_list()``."""
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            (
                'b',
                schema.Scalar(
                    np.int32,
                    metadata=schema.Metadata(categorical_limit=5),
                ),
            ),
            (
                'c',
                schema.List(
                    schema.Scalar(
                        np.int32,
                        metadata=schema.Metadata(categorical_limit=6),
                    )
                ),
            ),
        )
        # attach metadata to lengths field
        s.c.lengths.set_metadata(schema.Metadata(categorical_limit=7))
        self.assertIsNone(s.a.metadata)
        self.assertEqual(5, s.b.metadata.categorical_limit)
        self.assertEqual(6, s.c.value.metadata.categorical_limit)
        self.assertEqual(7, s.c.lengths.metadata.categorical_limit)

        sc = s.clone()
        self.assertIsNone(sc.a.metadata)
        self.assertEqual(5, sc.b.metadata.categorical_limit)
        self.assertEqual(6, sc.c.value.metadata.categorical_limit)
        self.assertEqual(7, sc.c.lengths.metadata.categorical_limit)

        sv = schema.from_blob_list(
            s,
            [
                np.array([3.4]),
                np.array([2]),
                np.array([3]),
                np.array([1, 2, 3]),
            ],
        )
        self.assertIsNone(sv.a.metadata)
        self.assertEqual(5, sv.b.metadata.categorical_limit)
        self.assertEqual(6, sv.c.value.metadata.categorical_limit)
        self.assertEqual(7, sv.c.lengths.metadata.categorical_limit)

    def testDupField(self):
        """Building a Struct with a repeated field name raises ValueError."""
        with self.assertRaises(ValueError):
            schema.Struct(
                ('a', schema.Scalar()),
                ('a', schema.Scalar()),
            )

    def testPreservesEmptyFields(self):
        """An empty nested Struct is kept by ``clone`` and ``from_blob_list``."""
        s = schema.Struct(
            ('a', schema.Scalar(np.float32)),
            ('b', schema.Struct()),
        )
        sc = s.clone()
        self.assertIn("a", sc.fields)
        self.assertIn("b", sc.fields)

        # Only one blob is passed; presumably the empty struct 'b' takes
        # none -- confirm against from_blob_list.
        sv = schema.from_blob_list(s, [np.array([3.4])])
        self.assertIn("a", sv.fields)
        self.assertIn("b", sv.fields)
        self.assertEqual(0, len(sv.b.fields))

    def testStructAddition(self):
        """``+`` merges Structs and rejects duplicates and non-Struct operands."""
        s1 = schema.Struct(
            ('a', schema.Scalar())
        )
        s2 = schema.Struct(
            ('b', schema.Scalar())
        )
        s = s1 + s2
        self.assertIn("a", s.fields)
        self.assertIn("b", s.fields)
        # Adding a struct to itself would repeat field 'a'.
        with self.assertRaises(ValueError):
            s1 + s1
        # The right-hand operand here is a Scalar, not a Struct.
        with self.assertRaises(TypeError):
            s1 + schema.Scalar()