blob: 51702883b0642bd66c740142fe3acce07785ea81 [file] [log] [blame]
#!/usr/bin/python3
# Copyright 2018, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example generator
Compiles spec files and generates the corresponding C++ TestModel definitions.
Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh;
see that script for details on how this script is used.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import traceback
import test_generator as tg
# Reserved dictionary key understood by ToCpp(): an entry under this key is
# not emitted as a struct member; its value is rendered as a " // ..." comment
# right after the opening brace of the C++ aggregate initializer.
COMMENT_KEY = "__COMMENT__"
# Parse the command-line arguments (spec and output example paths).
def ParseCmdLine():
    """Parse command-line arguments and register the spec/example file lists.

    Extends the shared test_generator argument parser with an -e/--example
    option (the output example file or directory), then passes the parsed
    spec and example paths to tg.FileNames.InitializeFileLists().
    """
    argParser = tg.ArgumentParser()
    argParser.add_argument(
        "-e", "--example",
        help="the output example file or directory",
    )
    parsed = tg.ParseArgs(argParser)
    # NOTE(review): `spec` is presumably declared by tg.ArgumentParser itself.
    tg.FileNames.InitializeFileLists(parsed.spec, parsed.example)
# Write the header of each generated file: boilerplate that depends only on the spec file name.
def InitializeFiles(example_fd):
    """Write the boilerplate header at the top of the generated example file.

    The header only mentions the base name of tg.FileNames.specFile.

    Args:
      example_fd: writable text stream for the generated C++ file, or None,
        in which case nothing is written.
    """
    # Resolved before the None check, matching the original evaluation order.
    specBaseName = os.path.basename(tg.FileNames.specFile)
    if example_fd is None:
        return
    headerTemplate = """\
// Generated from {spec_file}
// DO NOT EDIT
// clang-format off
#include "TestHarness.h"
using namespace test_helper;
"""
    print(headerTemplate.format(spec_file=specBaseName), file=example_fd)
def IndentedStr(s, indent):
    """Return `s` with `indent` spaces inserted after every newline.

    The first line is left as-is; only continuation lines are shifted.
    """
    return s.replace("\n", "\n" + " " * indent)
def ToCpp(var, indent=0):
    """Render a Python value as C++ initializer source text.

    Mapping rules (applied recursively to dict values and list elements):

      dict        -> C++ aggregate initialization with designated members,
                     one ".key = value" line per entry, keys in sorted order:
                         {
                          .key0 = value0,
                          .key1 = value1,
                          ...
                         }
                     An entry under COMMENT_KEY is not emitted as a member;
                     its value becomes a " // ..." comment right after the
                     opening brace. An empty dict renders as "{}".
      list, tuple -> C++ list initialization: {value0, value1, ...}
      bool        -> "true" or "false"
      float       -> tg.PrettyPrintAsFloat(var)
      other       -> str(var)

    The bool and float checks compare exact types, so subclasses of those
    types fall through to str(var).

    Args:
      var: the value to render.
      indent: spaces inserted after every newline of a rendered dict (via
        IndentedStr); nested dict values are rendered with indent + 4.
    """
    if isinstance(var, dict):
        if not var:
            return "{}"
        label = var.get(COMMENT_KEY)
        trailer = "" if label is None else " // %s" % label
        members = []
        for key in sorted(var.keys()):
            if key == COMMENT_KEY:
                continue
            members.append(" .%s = %s" % (key, ToCpp(var[key], indent + 4)))
        body = "{%s\n%s\n}" % (trailer, ",\n".join(members))
        return IndentedStr(body, indent)
    if isinstance(var, (list, tuple)):
        return "{%s}" % ", ".join([ToCpp(item, indent) for item in var])
    if type(var) is bool:
        return "true" if var else "false"
    if type(var) is float:
        return tg.PrettyPrintAsFloat(var)
    return str(var)
def GetSymmPerChannelQuantParams(extraParams):
    """Return the dict form of test_helper::TestSymmPerChannelQuantParams.

    An empty dict is returned when `extraParams` is None or its `hide`
    attribute is truthy; otherwise the dict carries its `scales` and
    `channelDim` attributes.
    """
    visible = extraParams is not None and not extraParams.hide
    if not visible:
        return {}
    return {
        "scales": extraParams.scales,
        "channelDim": extraParams.channelDim,
    }
def GetOperandStruct(operand):
    """Return the dict form of test_helper::TestOperand for `operand`.

    The operand's name is attached under COMMENT_KEY, so ToCpp() emits it as
    a C++ comment. "numberOfConsumers" is len(operand.outs), and "isIgnored"
    is True exactly when `operand` is a tg.IgnoredOutput. "data" is a
    TestBuffer::createFromVector<T>(...) expression, where T comes from
    operand.type.GetCppTypeString() and the initializer from
    operand.GetListInitialization().
    """
    operandType = operand.type
    struct = {
        COMMENT_KEY: operand.name,
        "type": "TestOperandType::" + operandType.type,
        "dimensions": operandType.dimensions,
        "scale": operandType.scale,
        "zeroPoint": operandType.zeroPoint,
        "numberOfConsumers": len(operand.outs),
        "lifetime": "TestOperandLifeTime::" + operand.lifetime,
        "channelQuant": GetSymmPerChannelQuantParams(operandType.extraParams),
        "isIgnored": isinstance(operand, tg.IgnoredOutput),
    }
    struct["data"] = "TestBuffer::createFromVector<{cpp_type}>({data})".format(
        cpp_type=operandType.GetCppTypeString(),
        data=operand.GetListInitialization(),
    )
    return struct
def GetOperationStruct(operation):
    """Return the dict form of test_helper::TestOperation for `operation`.

    Inputs and outputs are listed by each operand's `model_index`.
    """
    def modelIndexes(operands):
        return [operand.model_index for operand in operands]

    return {
        "type": "TestOperationType::" + operation.optype,
        "inputs": modelIndexes(operation.ins),
        "outputs": modelIndexes(operation.outs),
    }
def GetSubgraphStruct(subgraph):
    """Return the dict form of test_helper::TestSubgraph for `subgraph`.

    The subgraph's name is attached under COMMENT_KEY. Operands and
    operations are converted element-wise; the input/output index lists hold
    the `model_index` of each operand returned by subgraph.GetInputs() and
    subgraph.GetOutputs().
    """
    return {
        COMMENT_KEY: subgraph.name,
        "operands": list(map(GetOperandStruct, subgraph.operands)),
        "operations": list(map(GetOperationStruct, subgraph.operations)),
        "inputIndexes": [operand.model_index for operand in subgraph.GetInputs()],
        "outputIndexes": [operand.model_index for operand in subgraph.GetOutputs()],
    }
def GetModelStruct(example):
    """Return the dict form of test_helper::TestModel for `example`.

    "main" is built from example.model, and "referenced" holds one subgraph
    dict per entry of example.model.GetReferencedModels(). A model version of
    None is emitted as TestHalVersion::UNKNOWN.
    """
    model = example.model
    return {
        "main": GetSubgraphStruct(model),
        "referenced": list(map(GetSubgraphStruct, model.GetReferencedModels())),
        "isRelaxed": model.isRelaxed,
        "expectedMultinomialDistributionTolerance":
            example.expectedMultinomialDistributionTolerance,
        "expectFailure": example.expectFailure,
        "minSupportedVersion": "TestHalVersion::%s" % (
            "UNKNOWN" if model.version is None else model.version),
    }
def DumpExample(example, example_fd):
    """Write the generated C++ TestModel definition for one example.

    Inside namespace generated_tests::<spec name>, this emits a
    get_<examplesName>() accessor that returns a function-local static
    TestModel initialized from GetModelStruct(example). It also emits a
    dummy_<examplesName> constant that passes that model to
    TestModelManager::get().add() under example.testName.

    Raises:
      AssertionError: if example.model.compiled is falsy (only when
        assertions are enabled).
    """
    assert example.model.compiled
    substitutions = {
        "spec_name": tg.FileNames.specName,
        "test_name": str(example.testName),
        "example_name": str(example.examplesName),
        # indent=4 matches the nesting of the initializer inside get_...().
        "aggregate_init": ToCpp(GetModelStruct(example), indent=4),
    }
    # Doubled braces are literal C++ braces under str.format.
    cppTemplate = """\
namespace generated_tests::{spec_name} {{
const TestModel& get_{example_name}() {{
static TestModel model = {aggregate_init};
return model;
}}
const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}());
}} // namespace generated_tests::{spec_name}
"""
    print(cppTemplate.format(**substitutions), file=example_fd)
if __name__ == '__main__':
    # ParseCmdLine() initializes the tg.FileNames file lists, which both
    # callbacks read (specFile in InitializeFiles, specName in DumpExample),
    # so it runs before handing control to tg.Run.
    ParseCmdLine()
    tg.Run(
        InitializeFiles=InitializeFiles,
        DumpExample=DumpExample,
    )