| #!/usr/bin/python3 |
| |
| # Copyright 2018, The Android Open Source Project |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| """Example generator |
| |
| Compiles spec files and generates the corresponding C++ TestModel definitions. |
Invoked by ml/nn/runtime/test/specs/generate_all_tests.sh;
see that script for details on how this script is used.
| |
| """ |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| import os |
| import sys |
| import traceback |
| |
| import test_generator as tg |
| |
# Special dictionary key recognized by ToCpp(): rather than being emitted as a
# ".key = value" field, its value is rendered as a trailing "// ..." comment
# right after the opening brace of the aggregate initializer.
COMMENT_KEY = "__COMMENT__"
| |
# Parse command-line arguments and initialize the spec/example file lists from them.
def ParseCmdLine():
    """Parse the command line and register the spec/example files with tg.FileNames.

    Extends the parser returned by tg.ArgumentParser() with an -e/--example
    option. `spec` is read from the parsed arguments as well; presumably that
    argument is defined by tg.ArgumentParser() itself (not visible here).
    """
    argParser = tg.ArgumentParser()
    argParser.add_argument("-e", "--example",
                           help="the output example file or directory")
    parsed = tg.ParseArgs(argParser)
    tg.FileNames.InitializeFileLists(parsed.spec, parsed.example)
| |
# Write the header of a generated file: boilerplate code that depends only on the spec file name.
def InitializeFiles(example_fd):
    """Print the boilerplate C++ header at the top of the generated example file.

    Args:
        example_fd: Writable file object for the generated C++ source, or None,
            in which case nothing is written.
    """
    # The spec file's base name is resolved before the None check, exactly as
    # before, so a missing spec file still fails here regardless of example_fd.
    specBaseName = os.path.basename(tg.FileNames.specFile)
    if example_fd is None:
        return
    headerTemplate = """\
// Generated from {spec_file}
// DO NOT EDIT
// clang-format off
#include "TestHarness.h"
using namespace test_helper;
"""
    print(headerTemplate.format(spec_file=specBaseName), file=example_fd)
| |
def IndentedStr(s, indent):
    """Return `s` with every line except the first shifted right by `indent` spaces.

    The first line is left untouched because the caller places it after text
    that already occupies the current column.
    """
    padding = " " * indent
    return s.replace("\n", "\n" + padding)
| |
def ToCpp(var, indent=0):
    """Get the C++-style representation of a Python object.

    For a Python dictionary, it is mapped to C++ struct aggregate initialization
    (designated initializers), with keys emitted in sorted order:
        {
            .key0 = value0,
            .key1 = value1,
            ...
        }
    If the dictionary contains COMMENT_KEY, its value is emitted as a
    " // <value>" comment after the opening brace instead of as a field.
    An empty dictionary becomes "{}".

    For a Python list or tuple, it is mapped to C++ list initialization:
        {value0, value1, ...}

    In both cases, value0, value1, ... are stringified by invoking this method
    recursively. A bool becomes true/false, a float (exact type) is formatted by
    tg.PrettyPrintAsFloat, and anything else is converted with str().

    Args:
        var: The Python object to convert.
        indent: Number of spaces prepended to every line after the first of a
            multi-line dictionary initializer, so that the initializer lines up
            with the column it is embedded at.

    Returns:
        The C++ source text as a string.
    """
    if isinstance(var, dict):
        if not var:
            return "{}"
        comment = var.get(COMMENT_KEY)
        comment = "" if comment is None else " // %s" % comment
        # Fields sit 4 columns inside the braces, and nested values are rendered
        # at indent + 4, so a nested initializer's closing brace lines up with the
        # field it belongs to.
        fields = ",\n".join(
            "    .%s = %s" % (key, ToCpp(var[key], indent + 4))
            for key in sorted(var.keys())
            if key != COMMENT_KEY)
        return IndentedStr("{%s\n%s\n}" % (comment, fields), indent)
    if isinstance(var, (list, tuple)):
        return "{%s}" % ", ".join(ToCpp(item, indent) for item in var)
    if type(var) is bool:
        return "true" if var else "false"
    # Exact type check: float subclasses fall through to str() below.
    if type(var) is float:
        return tg.PrettyPrintAsFloat(var)
    return str(var)
| |
def GetSymmPerChannelQuantParams(extraParams):
    """Build the dict for test_helper::TestSymmPerChannelQuantParams.

    Args:
        extraParams: The operand type's extra parameters, or None.

    Returns:
        {"scales": ..., "channelDim": ...} when extraParams is present and not
        marked hidden; otherwise an empty dict, which ToCpp() renders as "{}".
    """
    if extraParams is not None and not extraParams.hide:
        return {"scales": extraParams.scales, "channelDim": extraParams.channelDim}
    return {}
| |
def GetOperandStruct(operand):
    """Build the dict for test_helper::TestOperand from one operand.

    The operand's name is stored under COMMENT_KEY so that ToCpp() prints it as
    a comment. Every other key is a TestOperand field.

    Args:
        operand: A test_generator operand. Its name, type (type, dimensions,
            scale, zeroPoint, extraParams), consumers (outs), lifetime and data
            are read.

    Returns:
        A dict keyed by TestOperand field name.
    """
    struct = {COMMENT_KEY: operand.name}
    operandType = operand.type
    struct["type"] = "TestOperandType::" + operandType.type
    struct["dimensions"] = operandType.dimensions
    struct["scale"] = operandType.scale
    struct["zeroPoint"] = operandType.zeroPoint
    struct["numberOfConsumers"] = len(operand.outs)
    struct["lifetime"] = "TestOperandLifeTime::" + operand.lifetime
    struct["channelQuant"] = GetSymmPerChannelQuantParams(operandType.extraParams)
    struct["isIgnored"] = isinstance(operand, tg.IgnoredOutput)
    # A C++ expression that builds the operand's TestBuffer from a typed vector.
    struct["data"] = "TestBuffer::createFromVector<{cpp_type}>({data})".format(
        cpp_type=operandType.GetCppTypeString(),
        data=operand.GetListInitialization(),
    )
    return struct
| |
def GetOperationStruct(operation):
    """Build the dict for test_helper::TestOperation.

    Args:
        operation: An operation exposing `optype`, plus `ins`/`outs` operand
            lists whose elements carry a `model_index`.

    Returns:
        {"type": "TestOperationType::<optype>", "inputs": [...], "outputs": [...]}
    """
    def indexesOf(operands):
        return [operand.model_index for operand in operands]

    return {
        "type": "TestOperationType::" + operation.optype,
        "inputs": indexesOf(operation.ins),
        "outputs": indexesOf(operation.outs),
    }
| |
def GetSubgraphStruct(subgraph):
    """Build the dict for test_helper::TestSubgraph.

    The subgraph's name is stored under COMMENT_KEY so that ToCpp() prints it as
    a comment. Operands and operations are converted with GetOperandStruct and
    GetOperationStruct. Inputs and outputs are given by their model indexes.
    """
    def modelIndexes(operands):
        return [operand.model_index for operand in operands]

    return {
        COMMENT_KEY: subgraph.name,
        "operands": list(map(GetOperandStruct, subgraph.operands)),
        "operations": list(map(GetOperationStruct, subgraph.operations)),
        "inputIndexes": modelIndexes(subgraph.GetInputs()),
        "outputIndexes": modelIndexes(subgraph.GetOutputs()),
    }
| |
def GetModelStruct(example):
    """Build the dict for test_helper::TestModel from a compiled example.

    "main" is the example's own model. "referenced" holds every model returned
    by GetReferencedModels(). minSupportedVersion falls back to
    TestHalVersion::UNKNOWN when the model has no version.
    """
    model = example.model
    mainGraph = GetSubgraphStruct(model)
    referencedGraphs = [GetSubgraphStruct(ref) for ref in model.GetReferencedModels()]
    return {
        "main": mainGraph,
        "referenced": referencedGraphs,
        "isRelaxed": model.isRelaxed,
        "expectedMultinomialDistributionTolerance":
            example.expectedMultinomialDistributionTolerance,
        "expectFailure": example.expectFailure,
        "minSupportedVersion": "TestHalVersion::%s" % (
            "UNKNOWN" if model.version is None else model.version),
    }
| |
def DumpExample(example, example_fd):
    """Print the C++ TestModel definition for one example to example_fd.

    The generated C++ lives in namespace generated_tests::<spec name>. It defines
    get_<examplesName>(), which returns a function-local static TestModel built
    from ToCpp(GetModelStruct(example)). It then registers that model under
    example.testName via TestModelManager::get().add(...).

    Args:
        example: A compiled test_generator example.
        example_fd: Writable file object for the generated C++ source.

    Raises:
        AssertionError: if example.model has not been compiled.
    """
    # Explicit check instead of `assert`, so it is not stripped under `python -O`.
    # AssertionError is kept so callers see the same exception type as before.
    if not example.model.compiled:
        raise AssertionError("example model must be compiled before it is dumped")
    # The aggregate initializer is rendered with indent=4, so its closing "}" ends
    # up at column 4. The statement it belongs to is indented to match.
    template = """\
namespace generated_tests::{spec_name} {{

const TestModel& get_{example_name}() {{
    static TestModel model = {aggregate_init};
    return model;
}}

const auto dummy_{example_name} = TestModelManager::get().add("{test_name}", get_{example_name}());

}} // namespace generated_tests::{spec_name}
"""
    print(template.format(
        spec_name=tg.FileNames.specName,
        test_name=str(example.testName),
        example_name=str(example.examplesName),
        aggregate_init=ToCpp(GetModelStruct(example), indent=4),
    ), file=example_fd)
| |
| |
if __name__ == "__main__":
    # Parse arguments first: this initializes tg.FileNames, which the
    # InitializeFiles and DumpExample callbacks read.
    ParseCmdLine()
    tg.Run(DumpExample=DumpExample, InitializeFiles=InitializeFiles)