blob: 2e234d46f3408898417e6e3dbab4cadb8ff0ae67 [file] [log] [blame]
from caffe2.proto import caffe2_pb2
from google.protobuf.message import DecodeError, Message
from google.protobuf import text_format
import collections
import numpy as np
import sys
# Python 3 no longer has the ``basestring`` and ``long`` builtins. Alias them
# to their Python 3 equivalents so the rest of this module can keep using the
# Python 2 names.
if sys.version_info[0] >= 3:
    basestring = str
    long = int
def CaffeBlobToNumpyArray(blob):
    """Converts a Caffe blob into a float32 numpy array.

    Inputs:
      blob: an object exposing `data` plus either the old style
          `num`/`channels`/`height`/`width` fields or the new style
          `shape.dim` field.
    Outputs:
      arr: `blob.data` as a float32 array reshaped to the blob's dimensions.
    """
    if blob.num != 0:
        # Old style caffe blob: the shape is spelled out by four fields.
        flat = np.asarray(blob.data, dtype=np.float32)
        dims = (blob.num, blob.channels, blob.height, blob.width)
    else:
        # New style caffe blob: the shape is carried by blob.shape.dim.
        flat = np.asarray(blob.data, dtype=np.float32)
        dims = blob.shape.dim
    return flat.reshape(dims)
def Caffe2TensorToNumpyArray(tensor):
    """Converts a Caffe2 tensor's `float_data` into a float32 numpy array.

    Inputs:
      tensor: an object exposing `float_data` and `dims`.
    Outputs:
      arr: the float data as a float32 array of shape `tensor.dims`.
    """
    flat = np.asarray(tensor.float_data, dtype=np.float32)
    return flat.reshape(tensor.dims)
def NumpyArrayToCaffe2Tensor(arr, name):
    """Packs a numpy array into a FLOAT caffe2 TensorProto.

    Inputs:
      arr: the numpy array to store; it is flattened and cast to float.
      name: the name to give the tensor.
    Outputs:
      tensor: a caffe2_pb2.TensorProto holding arr's shape in `dims` and its
          values in `float_data`.
    """
    proto = caffe2_pb2.TensorProto()
    proto.name = name
    proto.data_type = caffe2_pb2.TensorProto.FLOAT
    proto.dims.extend(arr.shape)
    flat_values = arr.flatten().astype(float)
    proto.float_data.extend(list(flat_values))
    return proto
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Inputs:
      key: the argument name.
      value: the argument value. Supported values are a float, an
          int/bool/long, a text or byte string, a protobuf Message (stored in
          serialized form), a numpy scalar, a numpy array (flattened), or an
          iterable whose elements are all floats, all ints/bools, all strings
          or all Messages.
    Outputs:
      argument: a caffe2_pb2.Argument with the field matching the value type
          filled in.
    Throws:
      ValueError: if the value type is not one of the supported ones.
    """
    # collections.Iterable was removed in Python 3.10. The ABC lives in
    # collections.abc on Python 3 and directly in collections on Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    argument = caffe2_pb2.Argument()
    argument.name = key

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # Convert numpy scalar to native python type. np.asscalar() has been
        # removed from numpy; .item() is its documented replacement.
        value = value.item()

    iterable = isinstance(value, Iterable)
    if iterable and not isinstance(
            value, (bytes, basestring, Message, list, tuple)):
        # The checks below walk the value once with all() and then a second
        # time to fill the field. Materialize it so that one-shot iterators
        # (e.g. generators) are not silently emptied by the first pass.
        value = list(value)

    if type(value) is float:
        argument.f = value
    elif type(value) in (int, bool, long):
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, bytes):
        # Must come before the basestring check: on Python 3 bytes is not a
        # basestring and would otherwise be treated as an iterable of ints.
        argument.s = value
    elif isinstance(value, basestring):
        argument.s = value.encode('utf-8')
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif iterable and all(type(v) in (float, np.float64) for v in value):
        argument.floats.extend([float(v) for v in value])
    elif iterable and all(
            type(v) in (int, bool, long, np.int_) for v in value):
        argument.ints.extend([int(v) for v in value])
    elif iterable and all(
            isinstance(v, (bytes, basestring)) for v in value):
        argument.strings.extend([
            (v if isinstance(v, bytes) else v.encode('utf-8'))
            for v in value])
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend([v.SerializeToString() for v in value])
    else:
        raise ValueError(
            "Unknown argument type: key=%s value=%s, value type=%s" %
            (key, str(value), str(type(value)))
        )
    return argument
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    The content is first parsed as text format; if that fails, it is parsed
    as binary content instead.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.
    Outputs:
      proto: the protobuffer of cls
    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
    except (text_format.ParseError, UnicodeDecodeError):
        # Binary content handed to the text parser does not always fail with
        # ParseError: byte input that is not valid UTF-8 raises
        # UnicodeDecodeError instead. Both mean "not text format", so fall
        # back to the binary parser.
        obj.ParseFromString(s)
    return obj
def GetContentFromProto(obj, function_map):
    """Applies the function registered for obj's exact class to obj.

    Inputs:
      obj: the object (typically a protocol buffer) to extract content from.
      function_map: a mapping from class to a function taking an instance of
          that class.
    Outputs:
      The result of the function whose class is exactly type(obj), or None
      when no class in function_map matches.
    """
    obj_type = type(obj)
    for klass, extractor in function_map.items():
        if klass is not obj_type:
            continue
        return extractor(obj)
    return None
def GetContentFromProtoString(s, function_map):
    """Decodes s with each class in function_map until one of them works.

    Inputs:
      s: a string of either binary or text protobuffer content.
      function_map: a mapping from protobuffer class to a function taking an
          instance of that class.
    Outputs:
      The result of the function paired with the first class for which
      reading s and applying the function does not raise DecodeError.
    Throws:
      google.protobuf.message.DecodeError: if no class in function_map
          produced a result.
    """
    for proto_cls, extractor in function_map.items():
        try:
            return extractor(TryReadProtoWithClass(proto_cls, s))
        except DecodeError:
            # This class does not fit; try the next candidate.
            pass
    raise DecodeError("Cannot find a fit protobuffer class.")
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Inputs:
      proto_class: the protobuffer class the file content is parsed into.
      filename: path of the input file to read.
      out_filename: path of the file the serialized protobuffer is written to.
    """
    # Read inside a with-block so the input handle is closed promptly.
    with open(filename) as fin:
        proto = TryReadProtoWithClass(proto_class, fin.read())
    # SerializeToString() yields raw bytes, so the output must be opened in
    # binary mode: text mode rejects bytes on Python 3 and may rewrite
    # newline bytes on Windows.
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())