blob: 09c689db0c546b629d3af9bc29b098d5a3be8bc9 [file] [log] [blame]
#!/usr/bin/python3
# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VTS testcase generator
Implements VTS test backend. Shares most logic with the CTS test
generator. Invoked by ml/nn/runtime/test/specs/generate_vts_tests.sh;
See that script for details on how this script is used.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import reduce
import math
import numpy as np
import os
import re
import struct
import sys
import contextlib
import pprint
# Stuff from test generator
import test_generator as tg
from test_generator import ActivationConverter
from test_generator import BoolScalar
from test_generator import Configuration
from test_generator import DataTypeConverter
from test_generator import DataLayoutConverter
from test_generator import Example
from test_generator import Float16Scalar
from test_generator import Float32Scalar
from test_generator import Float32Vector
from test_generator import IgnoredOutput
from test_generator import Input
from test_generator import Int32Scalar
from test_generator import Int32Vector
from test_generator import Internal
from test_generator import Model
from test_generator import Operand
from test_generator import Output
from test_generator import Parameter
from test_generator import ParameterAsInputConverter
from test_generator import RelaxedModeConverter
from test_generator import SmartOpen
from test_generator import SymmPerChannelQuantParams
# Dumping methods that shared with CTS generator
from cts_generator import DumpCtsExample
from cts_generator import DumpCtsIsIgnored
# TODO: Make this part of tg.Configuration?
target_hal_version = None
# Take a model from command line
def ParseCmdLine():
parser = argparse.ArgumentParser()
parser.add_argument("spec", help="the spec file")
parser.add_argument(
"-m", "--model", help="the output model file", default="-")
parser.add_argument(
"-e", "--example", help="the output example file", default="-")
parser.add_argument(
"-t", "--test", help="the output test file", default="-")
parser.add_argument(
"--target_hal_version",
help="the HAL version of the output",
required=True,
choices=["V1_0", "V1_1", "V1_2"])
args = parser.parse_args()
tg.FileNames.InitializeFileLists(
args.spec, args.model, args.example, args.test)
global target_hal_version
target_hal_version = args.target_hal_version
# Generate operands in VTS format
def generate_vts_operands(model):
# Dump operand definitions
op_def = """\
{{
.type = OperandType::{operand_type},
.dimensions = {shape},
.numberOfConsumers = {no_consumers},
.scale = {scale},
.zeroPoint = {zero_point},
.lifetime = OperandLifeTime::{lifetime},
.location = {{.poolIndex = 0, .offset = {offset}, .length = {length}}},{extraParams}
}}"""
offset = 0
op_definitions = []
extra_params_definitions = []
for index, o in enumerate(model.operands):
length = o.type.GetByteSize() if isinstance(o, Parameter) else 0
add_extra_params = o.type.extraParams is not None and not o.type.extraParams.hide
op = {
"operand_type": o.type.type,
"shape": o.type.GetDimensionsString(),
"no_consumers": len(o.outs),
"scale": tg.PrettyPrintAsFloat(o.type.scale),
"zero_point": str(int(o.type.zeroPoint)),
"lifetime": o.lifetime,
"offset": offset if isinstance(o, Parameter) else 0,
"length": length,
"extraParams": "" if not add_extra_params else "\n .extraParams = std::move(extraParams%d)," % (index,),
}
offset += length
op_definitions.append(op_def.format(**op))
extra_params_def = """\
Operand::ExtraParams extraParams{index};
extraParams{index}.{setMethodName}({param});
"""
if add_extra_params:
ep = o.type.extraParams
op = {
"index": index,
"setMethodName": ep.GetVtsSetter(),
"param": ep.GetVtsConstructor(),
}
extra_params_definitions.append(extra_params_def.format(**op))
op_vec = """{0}\
const std::vector<Operand> operands = {{
{1}
}};""".format(",\n".join(extra_params_definitions), ",\n".join(op_definitions))
return op_vec
# Generate VTS operand values
def generate_vts_operand_values(operands):
weights = [o for o in operands if isinstance(o, Parameter)]
binit = []
for w in weights:
ty = w.type.type
if ty == "TENSOR_QUANT8_ASYMM":
binit += w.value
elif ty == "TENSOR_QUANT8_SYMM_PER_CHANNEL" or ty == "TENSOR_QUANT8_SYMM":
binit += [struct.pack("b", value)[0] for value in w.value]
elif ty == "BOOL" or ty == "TENSOR_BOOL8":
binit += [1 if x else 0 for x in w.value]
elif ty == "TENSOR_FLOAT16" or ty == "FLOAT16":
for f in w.value:
# The pack format for float16 is not available until Python 3.6.
binit += [int(x) for x in np.float16(f).tostring()]
elif ty in {"TENSOR_FLOAT32", "FLOAT32", "TENSOR_INT32", "INT32", "TENSOR_QUANT16_ASYMM"}:
if ty in ["TENSOR_FLOAT32", "FLOAT32"]:
fmt = "f"
elif ty in ["TENSOR_INT32", "INT32"]:
fmt = "i"
elif ty == "TENSOR_QUANT16_ASYMM":
fmt = "H"
for f in w.value:
binit += [int(x) for x in struct.pack(fmt, f)]
else:
assert 0 and "Unsupported VTS operand type"
init_defs = ", ".join([str(x) for x in binit])
if (init_defs != ""):
init_defs = "\n %s\n " % init_defs
byte_vec_fmt = """{%s}""" % init_defs
return byte_vec_fmt
# Generate VTS operations
def generate_vts_operation(op, model):
op_fmt = """\
{{
.type = OperationType::{op_code},
.inputs = {{{ins}}},
.outputs = {{{outs}}},
}}"""
op_content = {
'op_code': op.optype,
'ins': tg.GetJointStr(model.GetIndexOfOperands(op.ins)),
'outs': tg.GetJointStr(model.GetIndexOfOperands(op.outs))
}
return op_fmt.format(**op_content)
def generate_vts_operations(model):
vts_ops = [generate_vts_operation(op, model) for op in model.operations]
return ",\n".join(vts_ops)
def generate_vts_model(model, model_file):
operand_values_fmt = ""
if Configuration.useSHM():
# Boilerplate code for passing weights in shared memory
operand_values_fmt = """\
std::vector<uint8_t> operandValues = {{}};
const uint8_t data[] = {operand_values};
// Allocate segment of android shared memory, wrapped in hidl_memory.
// This object will be automatically freed when sharedMemory is destroyed.
hidl_memory sharedMemory = allocateSharedMemory(sizeof(data));
// Mmap ashmem into usable address and hold it within the mappedMemory object.
// MappedMemory will automatically munmap the memory when it is destroyed.
sp<IMemory> mappedMemory = mapMemory(sharedMemory);
if (mappedMemory != nullptr) {{
// Retrieve the mmapped pointer.
uint8_t* mappedPointer =
static_cast<uint8_t*>(static_cast<void*>(mappedMemory->getPointer()));
if (mappedPointer != nullptr) {{
// Acquire the write lock for the shared memory segment, upload the data,
// and release the lock.
mappedMemory->update();
std::copy(data, data + sizeof(data), mappedPointer);
mappedMemory->commit();
}}
}}
const std::vector<hidl_memory> pools = {{sharedMemory}};
"""
else:
# Passing weights via operandValues
operand_values_fmt = """\
std::vector<uint8_t> operandValues = {operand_values};
const std::vector<hidl_memory> pools = {{}};
"""
operand_values_val = {
'operand_values': generate_vts_operand_values(model.operands)
}
operand_values = operand_values_fmt.format(**operand_values_val)
# operand_values = operand_values_fmt
model_fmt = """\
// Create the model
Model {create_test_model_name}() {{
{operand_decls}
const std::vector<Operation> operations = {{
{operations}
}};
const std::vector<uint32_t> inputIndexes = {{{input_indices}}};
const std::vector<uint32_t> outputIndexes = {{{output_indices}}};
{operand_values}
return {{
.operands = operands,
.operations = operations,
.inputIndexes = inputIndexes,
.outputIndexes = outputIndexes,
.operandValues = operandValues,
.pools = pools,{relaxed_field}
}};
}}
"""
model_dict = {
"hal_version": target_hal_version,
"create_test_model_name": str(model.createTestFunctionName),
"operations": generate_vts_operations(model),
"operand_decls": generate_vts_operands(model),
"operand_values": operand_values,
"output_indices": tg.GetJointStr(model.GetOutputsIndex()),
"input_indices": tg.GetJointStr(model.GetInputsIndex()),
"relaxed_field":
"\n .relaxComputationFloat32toFloat16 = true," if (model.isRelaxed) else ""
}
print(model_fmt.format(**model_dict), file = model_file)
def generate_vts(model, model_file):
assert model.compiled
generate_vts_model(model, model_file)
DumpCtsIsIgnored(model, model_file)
def generate_vts_test(example, test_file):
testTemplate = """\
TEST_F({test_case_name}, {test_name}) {{
generated_tests::Execute(device,
{namespace}::{create_model_name},
{namespace}::{is_ignored_name},
{namespace}::get_{examples_name}(){test_dynamic_output_shape});\n}}
TEST_F(ValidationTest, {test_name}) {{
const Model model = {namespace}::{create_model_name}();
const std::vector<Request> requests = createRequests({namespace}::get_{examples_name}());
validateEverything(model, requests);
}}\n
"""
if example.model.hasDynamicOutputShape:
print("#ifdef NN_TEST_DYNAMIC_OUTPUT_SHAPE", file=test_fd)
print(testTemplate.format(
test_case_name="DynamicOutputShapeTest" if example.model.hasDynamicOutputShape \
else "NeuralnetworksHidlTest",
test_name=str(example.testName),
namespace=tg.FileNames.specName,
create_model_name=str(example.model.createTestFunctionName),
is_ignored_name=str(example.model.isIgnoredFunctionName),
examples_name=str(example.examplesName),
test_dynamic_output_shape=", true" if example.model.hasDynamicOutputShape else ""
), file=test_fd)
if example.model.hasDynamicOutputShape:
print("#endif", file=test_fd)
def InitializeFiles(model_fd, example_fd, test_fd):
fileHeader = "// clang-format off\n// Generated file (from: {spec_file}). Do not edit"
testFileHeader = """\
// Generated from: {spec_file}.
namespace {spec_name} {{
// Generated {spec_name} test
#include "{example_file}"
// Generated model constructor
#include "{model_file}"
}} // namespace {spec_name}\n"""
# This regex is to remove prefix and get relative path for #include
pathRegex = r".*frameworks/ml/nn/(runtime/test/generated/)?"
specFileBase = os.path.basename(tg.FileNames.specFile)
print(fileHeader.format(spec_file=specFileBase), file=model_fd)
print(fileHeader.format(spec_file=specFileBase), file=example_fd)
print(testFileHeader.format(
spec_file=specFileBase,
model_file=re.sub(pathRegex, "", tg.FileNames.modelFile),
example_file=re.sub(pathRegex, "", tg.FileNames.exampleFile),
spec_name=tg.FileNames.specName), file=test_fd)
if __name__ == "__main__":
ParseCmdLine()
while tg.FileNames.NextFile():
print("Generating test(s) from spec: %s" % tg.FileNames.specFile, file=sys.stderr)
exec (open(tg.FileNames.specFile, "r").read())
print("Output VTS model: %s" % tg.FileNames.modelFile, file=sys.stderr)
print("Output example: " + tg.FileNames.exampleFile, file=sys.stderr)
with SmartOpen(tg.FileNames.modelFile) as model_fd, \
SmartOpen(tg.FileNames.exampleFile) as example_fd, \
SmartOpen(tg.FileNames.testFile, mode="a") as test_fd:
InitializeFiles(model_fd, example_fd, test_fd)
Example.DumpAllExamples(
DumpModel=generate_vts, model_fd=model_fd,
DumpExample=DumpCtsExample, example_fd=example_fd,
DumpTest=generate_vts_test, test_fd=test_fd)