| # Copyright 2014 The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """A simple module for declaring C-like structures. |
| |
| Example usage: |
| |
| >>> # Declare a struct type by specifying name, field formats and field names. |
| ... # Field formats are the same as those used in the struct module. |
| ... import cstruct |
| >>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid") |
| >>> |
| >>> |
| >>> # Create instances from tuples or raw bytes. Data past the end is ignored. |
| ... n1 = NLMsgHdr((44, 32, 0x2, 0, 491)) |
| >>> print n1 |
| NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491) |
| >>> |
| >>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00" |
| ... "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end") |
| >>> print n2 |
| NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510) |
| >>> |
| >>> # Serialize to raw bytes. |
| ... print n1.Pack().encode("hex") |
| 2c0000002000020000000000eb010000 |
| >>> |
| >>> # Parse the beginning of a byte stream as a struct, and return the struct |
| ... # and the remainder of the stream for further reading. |
| ... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00" |
| ... "\x00\x00\x00\x00\xfe\x01\x00\x00" |
| ... "more data") |
| >>> cstruct.Read(data, NLMsgHdr) |
| (NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data') |
| >>> |
| """ |
| |
| import ctypes |
| import string |
| import struct |
| |
| |
def CalcNumElements(fmt):
    """Return the number of values a struct-module format string describes.

    The count is obtained by unpacking a zero-filled buffer of exactly the
    size the format requires and counting the resulting elements, so repeat
    counts, pad bytes and "Ns" strings are all handled by the struct module
    itself.

    Args:
      fmt: A format string accepted by the struct module.

    Returns:
      The length of the tuple that struct.unpack produces for fmt.

    Raises:
      struct.error: If fmt is not a valid struct format string.
    """
    size = struct.calcsize(fmt)
    # Use a bytes literal: it is a plain str on Python 2, and it is the
    # buffer type struct.unpack insists on under Python 3.
    elements = struct.unpack(fmt, b"\x00" * size)
    return len(elements)
| |
| |
def Struct(name, fmt, fieldnames, substructs=()):
    """Build and return a class representing a C-like packed structure.

    Args:
      name: Name of the struct. Used as the class name, in error messages
        and in the str() output of instances.
      fmt: Format string as used by the struct module, with one extension:
        every "S" character stands for a nested struct whose class is taken,
        in order of appearance, from substructs.
      fieldnames: The field names, either as a sequence or as a single string
        of names separated by single spaces.
      substructs: Struct classes (as returned by this function), one for each
        "S" in fmt. Indexed by position.

    Returns:
      A new class. len() of the class, or of an instance, is the packed size
      in bytes. Instances are created from a sequence holding one value per
      field, or from raw bytes at least as long as the packed size (data past
      the end is ignored). Fields are read and written as attributes.

    Raises:
      struct.error: If fmt is not a valid format once "S" is expanded.
      IndexError: If fmt has more "S" characters than substructs has entries.
    """
    if isinstance(fieldnames, str):
        fieldnames = fieldnames.split(" ")

    # Translate fmt into a pure struct-module format, turning each "S" into
    # "XXs" where XX is the packed length of the nested struct. This is done
    # here, not in the class body, so that the loop variables do not end up
    # as class attributes that would shadow fields of the same name.
    nested = {}
    packed_format = ""
    next_substruct = 0
    for char in fmt:
        if char == "S":
            # Record which element of the unpacked tuple the nested struct
            # occupies. Count over the translated format: the raw one may
            # contain earlier "S" characters, which struct cannot parse.
            index = CalcNumElements(packed_format)
            nested[index] = substructs[next_substruct]
            next_substruct += 1
            packed_format += "%ds" % len(nested[index])
        else:
            # Standard struct format character.
            packed_format += char
    packed_length = struct.calcsize(packed_format)

    class Meta(type):
        """Metaclass that makes len() work on the struct class itself."""

        def __len__(cls):
            return cls._length

    # Deriving from a base created by calling Meta directly attaches the
    # metaclass on both Python 2 and Python 3.
    class CStruct(Meta("CStructBase", (object,), {})):
        """Class representing a C-like structure."""

        # Python 2 spelling of the metaclass; a plain attribute on Python 3.
        # Kept so that code probing for this attribute keeps working.
        __metaclass__ = Meta

        # Name of the struct.
        _name = name
        # List of field names.
        _fieldnames = fieldnames
        # Dict mapping field indices to nested struct classes.
        _nested = nested
        # struct-module format string, with nested structs expanded.
        _format = packed_format
        # Size in bytes of the packed representation.
        _length = packed_length

        def _SetValues(self, values):
            """Store a copy of values as this instance's field values."""
            # Go around our own __setattr__, which only accepts field names.
            super(CStruct, self).__setattr__("_values", list(values))

        def _Parse(self, data):
            """Set the field values by unpacking the start of data."""
            values = list(struct.unpack(self._format, data[:self._length]))
            # Elements at nested indices come back as raw bytes; turn them
            # into instances of the corresponding struct class.
            for index, nested_type in self._nested.items():
                values[index] = nested_type(values[index])
            self._SetValues(values)

        def __init__(self, values):
            """Initialize from raw bytes or from a sequence of field values.

            Raises:
              TypeError: If raw bytes are shorter than the packed size, or a
                sequence does not hold exactly one value per field.
            """
            # bytes is str on Python 2, so this is the original check there.
            if isinstance(values, bytes):
                # Initializing from a string.
                if len(values) < self._length:
                    raise TypeError("%s requires string of length %d, got %d" %
                                    (self._name, self._length, len(values)))
                self._Parse(values)
            else:
                # Initializing from a tuple.
                if len(values) != len(self._fieldnames):
                    raise TypeError("%s has exactly %d fieldnames (%d given)" %
                                    (self._name, len(self._fieldnames),
                                     len(values)))
                self._SetValues(values)

        def _FieldIndex(self, attr):
            """Return the position of field attr, or raise AttributeError."""
            try:
                return self._fieldnames.index(attr)
            except ValueError:
                raise AttributeError("'%s' has no attribute '%s'" %
                                     (self._name, attr))

        def __getattr__(self, name):
            """Return the value of field name (only called for non-attributes)."""
            # Validate the name before touching _values. If _values itself is
            # missing (e.g. an instance made with __new__), looking it up
            # lands here again and must fail cleanly instead of recursing.
            index = self._FieldIndex(name)
            return self._values[index]

        def __setattr__(self, name, value):
            """Set field name to value; only field names may be assigned."""
            self._values[self._FieldIndex(name)] = value

        @classmethod
        def __len__(cls):
            return cls._length

        def __ne__(self, other):
            return not self.__eq__(other)

        def __eq__(self, other):
            return (isinstance(other, self.__class__) and
                    self._name == other._name and
                    self._fieldnames == other._fieldnames and
                    self._values == other._values)

        @staticmethod
        def _MaybePackStruct(value):
            """Return value serialized if it is a struct, else value itself."""
            # Nested structs are instances of a different class, created by a
            # different call to Struct, so recognize them by their Pack
            # method instead of by type.
            pack = getattr(value, "Pack", None)
            if callable(pack):
                return pack()
            return value

        def Pack(self):
            """Return the packed (raw bytes) representation of the struct."""
            values = [self._MaybePackStruct(v) for v in self._values]
            return struct.pack(self._format, *values)

        def __str__(self):
            def FieldDesc(field, value):
                if isinstance(value, bytes):
                    # bytearray yields integers on Python 2 and 3 alike.
                    octets = bytearray(value)
                    if any(chr(o) not in string.printable for o in octets):
                        # Show binary data as lowercase hex.
                        value = "".join("%02x" % o for o in octets)
                    elif not isinstance(value, str):
                        # Python 3: show printable bytes without the b'...'.
                        value = value.decode("ascii")
                return "%s=%s" % (field, value)

            descriptions = [
                FieldDesc(n, v) for n, v in zip(self._fieldnames, self._values)]

            return "%s(%s)" % (self._name, ", ".join(descriptions))

        def __repr__(self):
            return str(self)

        def CPointer(self):
            """Returns a C pointer to the serialized structure."""
            buf = ctypes.create_string_buffer(self.Pack())
            # Store the C buffer in the object so it doesn't get garbage
            # collected.
            super(CStruct, self).__setattr__("_buffer", buf)
            return ctypes.addressof(self._buffer)

    # Make the class object have the name that's passed in. Assigning
    # __name__ does this; passing the name to type.__init__ does not.
    CStruct.__name__ = str(name)

    return CStruct
| |
| |
def Read(data, struct_type):
    """Parse the start of data as struct_type and split off what follows.

    Args:
      data: The raw data to read from.
      struct_type: The struct class to parse; len(struct_type) gives the
        number of bytes it consumes.

    Returns:
      A tuple (parsed, remainder): the struct_type instance built from data,
      and the part of data that follows the struct's packed representation.
    """
    consumed = len(struct_type)
    parsed = struct_type(data)
    remainder = data[consumed:]
    return parsed, remainder