blob: 6fc592c8e759016c21fba5157f6e805b76fe0bed [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import difflib
import json
import unittest
from pprint import pformat
from typing import List
import executorch.sdk.etdump.schema_flatcc as flatcc
from executorch.exir._serialize._dataclass import _DataclassEncoder
from executorch.sdk.etdump.schema import (
AllocationEvent,
Allocator,
DebugBlock,
DebugEvent,
ETDump,
ProfileBlock,
ProfileEvent,
RunData,
ScalarType,
Tensor,
Value,
)
from executorch.sdk.etdump.serialize import (
deserialize_from_etdump,
deserialize_from_etdump_flatcc,
serialize_to_etdump,
serialize_to_etdump_flatcc,
)
def diff_jsons(a: str, b: str) -> List[str]:
    """Return a unified diff between two JSON documents.

    Both strings are parsed with ``json.loads`` and pretty-printed with
    ``pprint.pformat``; the diff is then computed line-by-line over those
    pretty-printed forms rather than over the raw input text.

    Args:
        a: The first (expected) JSON document.
        b: The second (actual) JSON document.

    Returns:
        The lines of the unified diff, none of which end in a newline.
        The list is empty when the pretty-printed forms are identical.
    """
    data_a = json.loads(a)
    data_b = json.loads(b)
    # lineterm="" keeps the "---"/"+++"/"@@" header lines free of a trailing
    # newline, matching the content lines produced by splitlines(). Callers
    # join the result with "\n", so a newline here would yield blank lines.
    return list(
        difflib.unified_diff(
            pformat(data_a).splitlines(),
            pformat(data_b).splitlines(),
            lineterm="",
        )
    )
def get_sample_etdump() -> ETDump:
    """Build a small, fully populated ``ETDump`` for round-trip tests.

    The dump holds a single ``RunData`` with:
      * one debug block containing one debug event whose only entry is a
        ``ScalarType.FLOAT`` tensor of sizes ``[1, 1]`` and raw bytes
        ``b"datadump"``;
      * one profile block with one allocator, one profile event and one
        allocation event.
    """
    tensor_entry = Value(
        val=Tensor(
            scalar_type=ScalarType.FLOAT,
            sizes=[1, 1],
            strides=[1, 1],
            data=b"datadump",
        ),
    )
    debug_block = DebugBlock(
        name="test_debug_block",
        debug_events=[DebugEvent(debug_handle=0, debug_entries=[tensor_entry])],
    )

    profile_event = ProfileEvent(
        name="test_profile_event",
        debug_handle=1,
        start_time=1001,
        end_time=2002,
    )
    allocation_event = AllocationEvent(allocator_id=1, allocation_size=8)
    profile_block = ProfileBlock(
        name="test_profile_block",
        allocators=[Allocator("test_allocator")],
        profile_events=[profile_event],
        allocation_events=[allocation_event],
    )

    run = RunData(debug_blocks=[debug_block], profile_blocks=[profile_block])
    return ETDump(version=0, run_data=[run])
def get_sample_etdump_flatcc() -> flatcc.ETDumpFlatCC:
    """Build a small ``ETDumpFlatCC`` that exercises every event kind.

    The single ``RunData`` block (named ``"test_block"``, one allocator)
    holds four events, in this order:
      1. a profile event with ``delegate_debug_id_int=-1``;
      2. a profile event named ``"test_profile_event_delegated"`` with
         ``delegate_debug_id_int=13``;
      3. an allocation event;
      4. a debug event whose only entry is a tensor-typed value at
         offset 12345.
    Each ``flatcc.Event`` sets exactly one of its three payload fields and
    leaves the other two as ``None``.
    """

    def make_profile_event(name: str, delegate_id: int) -> flatcc.Event:
        # The two profile events differ only in name and delegate id.
        return flatcc.Event(
            profile_event=flatcc.ProfileEvent(
                name=name,
                chain_id=1,
                instruction_id=1,
                delegate_debug_id_str="",
                delegate_debug_id_int=delegate_id,
                delegate_debug_metadata="",
                start_time=1001,
                end_time=2002,
            ),
            allocation_event=None,
            debug_event=None,
        )

    allocation = flatcc.Event(
        profile_event=None,
        allocation_event=flatcc.AllocationEvent(
            allocator_id=1,
            allocation_size=8,
        ),
        debug_event=None,
    )

    tensor_value = flatcc.Value(
        val=flatcc.ValueType.TENSOR.value,
        offset=12345,
    )
    debug = flatcc.Event(
        profile_event=None,
        allocation_event=None,
        debug_event=flatcc.DebugEvent(
            chain_idx=1,
            debug_handle=0,
            debug_entries=[tensor_value],
        ),
    )

    events = [
        make_profile_event("test_profile_event", -1),
        make_profile_event("test_profile_event_delegated", 13),
        allocation,
        debug,
    ]
    run = flatcc.RunData(
        name="test_block",
        allocators=[flatcc.Allocator(name="test_allocator")],
        events=events,
    )
    return flatcc.ETDumpFlatCC(version=0, run_data=[run])
class TestSerialize(unittest.TestCase):
    """Round-trip test for the ``ETDump`` serializer."""

    def test_serialize(self) -> None:
        """Serialize a sample ETDump, deserialize it, and compare.

        On mismatch, the assertion message is a unified diff of the
        ``_DataclassEncoder`` JSON renderings of the two objects.
        """
        expected = get_sample_etdump()
        buffer = serialize_to_etdump(expected)
        actual = deserialize_from_etdump(buffer)

        expected_json = json.dumps(expected, cls=_DataclassEncoder, indent=4)
        actual_json = json.dumps(actual, cls=_DataclassEncoder, indent=4)
        diff_message = "\n".join(diff_jsons(expected_json, actual_json))

        self.assertEqual(expected, actual, msg=diff_message)
class TestSerializeFlatCC(unittest.TestCase):
    """Round-trip test for the FlatCC ``ETDumpFlatCC`` serializer."""

    def test_serialize(self) -> None:
        """Serialize a sample ETDumpFlatCC, deserialize it, and compare.

        On mismatch, the assertion message is a unified diff of the
        ``_DataclassEncoder`` JSON renderings of the two objects.
        """
        # ``json`` comes from the module-level import; the former
        # function-local ``import json`` was redundant and shadowed it.
        program = get_sample_etdump_flatcc()
        flatcc_from_py = serialize_to_etdump_flatcc(program)
        # NOTE(review): size_prefixed=False presumably matches the buffer
        # layout produced by serialize_to_etdump_flatcc — confirm there.
        deserialized_obj = deserialize_from_etdump_flatcc(
            flatcc_from_py, size_prefixed=False
        )

        expected_json = json.dumps(program, cls=_DataclassEncoder, indent=4)
        actual_json = json.dumps(deserialized_obj, cls=_DataclassEncoder, indent=4)
        self.assertEqual(
            program,
            deserialized_obj,
            msg="\n".join(diff_jsons(expected_json, actual_json)),
        )