| #!/usr/bin/env python3 |
| # -*- coding: utf-8 -*- |
| |
| # Copyright 2019 The MLIR Authors. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| # Script for updating SPIR-V dialect by scraping information from SPIR-V |
| # HTML and JSON specs from the Internet. |
| # |
| # For example, to define the enum attribute for SPIR-V memory model: |
| # |
# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
#                        --new-enum MemoryModel
| # |
# The 'operand_kinds' list of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.
| |
| import re |
| import requests |
| import textwrap |
| |
# Human-readable SPIR-V specification; instruction docs are scraped from it.
SPIRV_HTML_SPEC_URL = (
    'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html')
# Machine-readable SPIR-V grammar (operand kinds and instructions).
SPIRV_JSON_SPEC_URL = (
    'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/'
    'include/spirv/unified1/spirv.core.grammar.json')

# Separator placed between consecutive op definitions in SPIRVOps.td.
AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'

# Markers delimiting the auto-generated sections of SPIRVBase.td. Each marker
# is expected to appear exactly twice: once opening its section and once,
# after "// End ", closing it.
AUTOGEN_ENUM_SECTION_MARKER = (
    'enum section. Generated from SPIR-V spec; DO NOT MODIFY!')
AUTOGEN_OPCODE_SECTION_MARKER = (
    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')
| |
| |
def get_spirv_doc_from_html_spec():
    """Extracts instruction documentation from SPIR-V HTML spec.

    Downloads the SPIR-V HTML spec and, for every table found under the
    "Instructions" section, records the text of the table's first paragraph
    with its leading line (the opname) removed.

    Returns:
      - A dict mapping from instruction opname to documentation.

    Raises:
      - requests.HTTPError if the spec cannot be downloaded.
      - ValueError if the "Instructions" section cannot be found.
    """
    # Fail loudly on HTTP errors instead of parsing an error page, and don't
    # hang forever on an unresponsive server.
    response = requests.get(SPIRV_HTML_SPEC_URL, timeout=60)
    response.raise_for_status()
    spec = response.content

    # Imported here so bs4 is only required when this function actually runs.
    from bs4 import BeautifulSoup
    spirv = BeautifulSoup(spec, 'html.parser')

    section_anchor = spirv.find('h3', {'id': '_a_id_instructions_a_instructions'})
    if section_anchor is None:
        raise ValueError('cannot find the "Instructions" section in {}'.format(
            SPIRV_HTML_SPEC_URL))

    doc = {}
    for section in section_anchor.parent.find_all('div', {'class': 'sect3'}):
        for table in section.find_all('table'):
            inst_html = table.tbody.tr.td.p
            opname = inst_html.a['id']
            # Ignore the first line, which is just the opname. An entry with
            # nothing after the opname gets an empty doc instead of crashing.
            parts = inst_html.text.split('\n', 1)
            doc[opname] = parts[1].strip() if len(parts) > 1 else ''

    return doc
| |
| |
def get_spirv_grammar_from_json_spec():
    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar

    Raises:
      - requests.HTTPError if the grammar cannot be downloaded.
    """
    response = requests.get(SPIRV_JSON_SPEC_URL, timeout=60)
    # Without this check an HTTP error page would surface below as a
    # confusing JSON decoding error or KeyError.
    response.raise_for_status()
    spec = response.content

    import json
    spirv = json.loads(spec)

    return spirv['operand_kinds'], spirv['instructions']
| |
| |
def split_list_into_sublists(items, offset, width=80):
    """Split the list of items into multiple sublists.

    Each sublist is meant to be rendered as one line: `offset` characters of
    indentation, the items joined by ', ', and a trailing ',' on every line
    but the last. Items are packed greedily so that such a line does not
    exceed `width` characters; an item too long to fit even on its own line
    still gets a line of its own (never an empty sublist).

    Arguments:
      - items: a list of strings
      - offset: the indentation width that will prefix each rendered line
      - width: the maximum rendered line length (defaults to 80)

    Returns:
      - A list of non-empty sublists whose concatenation equals `items`.
    """
    budget = width - offset
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        # Each item costs its length plus 2 for the ', ' separator (or the
        # trailing ',' with one column of slack for the last item on a line).
        item_len = len(item) + 2
        # Only flush a non-empty chunk: flushing an empty one would emit an
        # empty line, i.e. a dangling ',' in the generated output.
        if chunk and chunk_len + item_len > budget:
            chunks.append(chunk)
            chunk = []
            chunk_len = 0
        chunk.append(item)
        chunk_len += item_len

    if chunk:
        chunks.append(chunk)

    return chunks
| |
| |
def uniquify(lst, equality_fn):
    """Returns a copy of `lst` with duplicate elements dropped.

    Two elements are duplicates when `equality_fn` maps them to equal keys;
    those keys are stored in a set, so they must be hashable. Surviving
    elements keep their original relative order, and from each group of
    duplicates only the earliest element is kept.

    Arguments:
      - lst: List whose elements are to be uniqued.
      - equality_fn: Key function applied once to each element; its results
        are compared to detect duplicates.

    Returns:
      - A new list without duplicates.
    """
    seen = set()
    result = []
    for item in lst:
        k = equality_fn(item)
        if k in seen:
            continue
        seen.add(k)
        result.append(item)
    return result
| |
| |
def gen_operand_kind_enum_attr(operand_kind):
    """Generates the TableGen I32EnumAttr definition for the given operand kind.

    Operand kinds lacking an 'enumerants' entry are not enums; for those both
    returned strings are empty.

    Arguments:
      - operand_kind: one entry of the JSON grammar's operand kinds

    Returns:
      - The operand kind's name
      - A string containing the I32EnumAttrCase definitions followed by the
        I32EnumAttr definition
    """
    if 'enumerants' not in operand_kind:
        return '', ''

    name = operand_kind['kind']
    # Case prefix made of the kind's capital letters ('MemoryModel' -> 'MM').
    acronym = ''.join(ch for ch in name if 'A' <= ch <= 'Z')

    # Enumerants sharing a value are collapsed; only the first one is kept.
    cases = uniquify(
        [(e['enumerant'], e['value']) for e in operand_kind['enumerants']],
        lambda case: case[1])
    widest = max(len(sym) for sym, _ in cases)

    # One I32EnumAttrCase def per case, with the ':' column-aligned.
    case_fmt = ('def SPV_{acronym}_{symbol} {colon:>{offset}} '
                'I32EnumAttrCase<"{symbol}", {value}>;')
    case_lines = []
    for sym, val in cases:
        case_lines.append(case_fmt.format(
            acronym=acronym, symbol=sym, value=val, colon=':',
            offset=widest + 1 - len(sym)))
    case_defs = '\n'.join(case_lines)

    # The same cases referenced by def name, wrapped into indented lines.
    refs = ['SPV_{}_{}'.format(acronym, sym) for sym, _ in cases]
    ref_lines = [
        '{:6}'.format('') + ', '.join(group)
        for group in split_list_into_sublists(refs, 6)
    ]
    case_list = ',\n'.join(ref_lines)

    attr_fmt = ('def SPV_{name}Attr :\n '
                'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'
                ' let returnType = "::mlir::spirv::{name}";\n'
                ' let convertFromStorage = '
                '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'
                ' let cppNamespace = "::mlir::spirv";\n}}')
    enum_attr = attr_fmt.format(name=name, cases=case_list)
    return name, case_defs + '\n\n' + enum_attr
| |
| |
def gen_opcode(instructions):
    """Generates the TableGen definition to map opname to opcode.

    Arguments:
      - instructions: a list of instruction grammar dicts, each with at least
        'opname' and 'opcode'; the output follows the list's order

    Returns:
      - A string containing one SPV_OC_* I32EnumAttrCase def per instruction,
        followed by the SPV_OpcodeAttr I32EnumAttr definition
    """
    names = [inst['opname'] for inst in instructions]
    widest = max(len(n) for n in names)

    # One I32EnumAttrCase def per opcode, with the ':' column-aligned.
    case_fmt = ('def SPV_OC_{name} {colon:>{offset}} '
                'I32EnumAttrCase<"{name}", {value}>;')
    case_lines = []
    for inst in instructions:
        opname = inst['opname']
        case_lines.append(case_fmt.format(
            name=opname, value=inst['opcode'], colon=':',
            offset=widest + 1 - len(opname)))

    # The same cases referenced by def name, wrapped into indented lines.
    refs = ['SPV_OC_{}'.format(n) for n in names]
    ref_lines = [' ' * 6 + ', '.join(group)
                 for group in split_list_into_sublists(refs, 6)]

    attr_fmt = ('def SPV_OpcodeAttr :\n'
                ' I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'
                '{lst}\n'
                ' ]> {{\n'
                ' let returnType = "::mlir::spirv::{name}";\n'
                ' let convertFromStorage = '
                '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'
                ' let cppNamespace = "::mlir::spirv";\n}}')
    enum_attr = attr_fmt.format(name='Opcode', lst=',\n'.join(ref_lines))
    return '\n'.join(case_lines) + '\n\n' + enum_attr
| |
| |
def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRVBase.td with new generated opcode cases.

    The opcode section (delimited by two occurrences of
    AUTOGEN_OPCODE_SECTION_MARKER) is regenerated so that it contains every
    opcode already defined there plus the requested ones, sorted by opcode.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add; it is not modified

    Raises:
      - ValueError if the file does not contain exactly one opcode section.
    """
    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    if len(content) != 3:
        raise ValueError('{}: expected exactly two "{}" markers, found {}'.format(
            path, AUTOGEN_OPCODE_SECTION_MARKER, len(content) - 1))

    # Keep every opcode already defined plus the new ones. A fresh set is
    # built so the caller's list is not mutated, and lookups are O(1).
    existing_opcodes = re.findall(r'def SPV_OC_(\w+)', content[1])
    wanted = set(filter_list) | set(existing_opcodes)

    # Generate the opcodes for the selected instructions, sorted by opcode.
    filter_instrs = sorted(
        (inst for inst in instructions if inst['opname'] in wanted),
        key=lambda inst: inst['opcode'])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode section
    content = (content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + opcode +
               '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER + content[2])

    with open(path, 'w') as f:
        f.write(content)
| |
| |
def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRVBase.td with new generated enum definitions.

    The enum section (delimited by two occurrences of
    AUTOGEN_ENUM_SECTION_MARKER) is regenerated so that it contains every enum
    already defined there plus the requested ones, sorted by enum name.

    Arguments:
      - path: the path to SPIRVBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add; it is not modified

    Raises:
      - ValueError if the file does not contain exactly one enum section.
    """
    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    if len(content) != 3:
        raise ValueError('{}: expected exactly two "{}" markers, found {}'.format(
            path, AUTOGEN_ENUM_SECTION_MARKER, len(content) - 1))

    # Keep every enum already defined plus the new ones. A fresh set is built
    # so the caller's list is not mutated, and lookups are O(1).
    existing_kinds = re.findall(r'def SPV_(\w+)Attr', content[1])
    wanted = set(filter_list) | set(existing_kinds)

    # Generate definitions for the selected operand kinds. Kinds that are not
    # enums come back with an empty name; drop them rather than emitting an
    # empty entry into the section.
    defs = [gen_operand_kind_enum_attr(kind)
            for kind in operand_kinds if kind['kind'] in wanted]
    defs = [d for d in defs if d[0]]
    # Sort alphabetically according to enum name, then keep only the text.
    defs.sort(key=lambda d: d[0])
    defs = [d[1] for d in defs]

    # Substitute the old section
    content = (content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' +
               '\n\n'.join(defs) + '\n\n// End ' +
               AUTOGEN_ENUM_SECTION_MARKER + content[2])

    with open(path, 'w') as f:
        f.write(content)
| |
| |
def snake_casify(name):
    """Turns the given name to follow snake_case convention.

    Every run of non-word characters (spaces, quotes, punctuation) acts as a
    word boundary; the remaining words are lowercased and joined with '_',
    e.g. "'Memory Access'" becomes 'memory_access'.
    """
    # Replace separators with a space (not with '') so split() can actually
    # separate the words.
    words = re.sub(r'\W+', ' ', name).split()
    return '_'.join(word.lower() for word in words)
| |
| |
def map_spec_operand_to_ods_argument(operand):
    """Maps a operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
      - operand: a dict containing the operand's 'kind' and, optionally, its
        'quantifier' ('?' optional, '*' variadic) and 'name'

    Returns:
      - A string '<type>:$<name>' containing both the type and name for the
        argument; the name comes from snake_casify(operand['name']), or the
        lowercased kind when the operand has no name

    Raises:
      - ValueError for quantifiers this mapping does not support.
      - NotImplementedError for operand kinds that have no mapping yet.
    """
    kind = operand['kind']
    quantifier = operand.get('quantifier', '')

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
    assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'

    if kind == 'IdRef':
        if quantifier == '':
            arg_type = 'SPV_Type'
        elif quantifier == '?':
            arg_type = 'SPV_Optional<SPV_Type>'
        else:
            arg_type = 'Variadic<SPV_Type>'
    elif kind in ('IdMemorySemantics', 'IdScope'):
        # TODO(antiagainst): Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        if quantifier != '':
            raise ValueError('unexpected to have optional/variadic memory '
                             'semantics or scope <id>')
        arg_type = 'I32'
    elif kind == 'LiteralInteger':
        if quantifier == '':
            arg_type = 'I32Attr'
        elif quantifier == '?':
            arg_type = 'OptionalAttr<I32Attr>'
        else:
            arg_type = 'OptionalAttr<I32ArrayAttr>'
    elif kind in ('LiteralString', 'LiteralContextDependentNumber',
                  'LiteralExtInstInteger', 'LiteralSpecConstantOpInteger',
                  'PairLiteralIntegerIdRef', 'PairIdRefLiteralInteger',
                  'PairIdRefIdRef'):
        # Raise instead of `assert False`: under `python -O` the assert would
        # vanish and `arg_type` would be left unbound.
        raise NotImplementedError('"{}" kind unimplemented'.format(kind))
    else:
        # The rest are all enum operands that we represent with op attributes.
        if quantifier == '*':
            raise ValueError('unexpected to have variadic enum attribute')
        arg_type = 'SPV_{}Attr'.format(kind)
        if quantifier == '?':
            arg_type = 'OptionalAttr<{}>'.format(arg_type)

    name = operand.get('name', '')
    name = snake_casify(name) if name else kind.lower()

    return '{}:${}'.format(arg_type, name)
| |
| |
def get_op_definition(instruction, doc, existing_info):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - doc: the instruction's SPIR-V HTML doc; its first line becomes the
        summary and the remaining lines (possibly none) the description
      - existing_info: a dict containing potential manually specified sections
        for this instruction ('traits', 'assembly', 'arguments', 'results',
        'extras'); when present they are used instead of generated defaults

    Returns:
      - A string containing the TableGen op definition
    """
    fmt_str = ('def SPV_{opname}Op : SPV_Op<"{opname}", [{traits}]> {{\n'
               ' let summary = {summary};\n\n'
               ' let description = [{{\n'
               '{description}\n\n'
               ' ### Custom assembly form\n'
               '{assembly}'
               '}}];\n\n'
               ' let arguments = (ins{args});\n\n'
               ' let results = (outs{results});\n'
               '{extras}'
               '}}\n')

    # Drop the 'Op' prefix, e.g. 'OpLoad' -> 'Load'.
    opname = instruction['opname'][2:]

    # partition (unlike split + tuple unpacking) also handles single-line docs,
    # which yield an empty description instead of raising ValueError.
    summary, _, description = doc.partition('\n')
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=' ', subsequent_indent=' ')

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary))

    # Wrap each non-empty description line into its own paragraph.
    description = '\n\n'.join(
        wrapper.fill(line) for line in description.split('\n') if line)

    operands = instruction.get('operands', [])

    # Set op's result
    results = ''
    if operands and operands[0]['kind'] == 'IdResultType':
        results = '\n SPV_Type:$result\n '
        operands = operands[1:]
    if 'results' in existing_info:
        results = existing_info['results']

    # Ignore the operand standing for the result <id>
    if operands and operands[0]['kind'] == 'IdResult':
        operands = operands[1:]

    # Set op's arguments; they are only generated when not written by hand.
    arguments = existing_info.get('arguments', None)
    if arguments is None:
        arguments = ',\n '.join(
            map_spec_operand_to_ods_argument(o) for o in operands)
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = '\n {}\n '.format(arguments)

    assembly = existing_info.get('assembly', None)
    if assembly is None:
        assembly = ('\n ``` {.ebnf}\n'
                    ' [TODO]\n'
                    ' ```\n\n'
                    ' For example:\n\n'
                    ' ```\n'
                    ' [TODO]\n'
                    ' ```\n ')

    return fmt_str.format(
        opname=opname,
        traits=existing_info.get('traits', ''),
        summary=summary,
        description=description,
        assembly=assembly,
        args=arguments,
        results=results,
        extras=existing_info.get('extras', ''))
| |
| |
def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    The definition is expected to follow the layout emitted by
    get_op_definition; the sections a developer may have edited by hand are
    pulled out so they survive regeneration.

    Arguments:
      - op_def: a string containing the op's TableGen definition

    Returns:
      - A dict with keys 'opname' (prefixed with 'Op' as in the SPIR-V spec),
        'traits', 'assembly', 'arguments', 'results' and 'extras'
    """
    # Get opname
    opnames = re.findall(r'def SPV_(\w+)Op', op_def)
    assert len(opnames) == 1, \
        'expected exactly one op definition per section, found {}: {}'.format(
            len(opnames), opnames)
    opname = opnames[0]

    # Get traits: the optional second template argument of SPV_Op<...>.
    op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0].split(', ', 1)
    traits = op_tmpl_params[1].strip('[]') if len(op_tmpl_params) > 1 else ''

    # Get custom assembly form
    rest = op_def.split('### Custom assembly form\n')
    assert len(rest) == 2, \
        '{}: cannot find "### Custom assembly form"'.format(opname)
    rest = rest[1].split(' let arguments = (ins')
    assert len(rest) == 2, '{}: cannot find arguments'.format(opname)
    # rstrip takes a character set: this drops the trailing run of '}', ']',
    # ';' and newlines, i.e. the '}];' closing the description block.
    assembly = rest[0].rstrip('}];\n')

    # Get arguments
    rest = rest[1].split(' let results = (outs')
    assert len(rest) == 2, '{}: cannot find results'.format(opname)
    args = rest[0].rstrip(');\n')

    # Get results
    rest = rest[1].split(');', 1)
    assert len(rest) == 2, \
        '{}: cannot find ");" ending results'.format(opname)
    results = rest[0]

    # Anything after the results (minus the closing '}') is hand-written.
    extras = rest[1].strip(' }\n')
    if extras:
        extras = '\n {}\n'.format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        'opname': 'Op{}'.format(opname),
        'traits': traits,
        'assembly': assembly,
        'arguments': args,
        'results': results,
        'extras': extras
    }
| |
| |
def update_td_op_definitions(path, instructions, docs, filter_list):
    """Updates SPIRVOps.td with newly generated op definition.

    The file is treated as a header, a sequence of op definitions and a
    footer, separated by AUTOGEN_OP_DEF_SEPARATOR. All existing ops (keeping
    their manually written sections) plus the requested new ones are
    regenerated and written back sorted by opname.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include; it is not
        modified

    Raises:
      - ValueError if the file has no separator, or an opname is missing from
        the grammar.
    """
    with open(path, 'r') as f:
        content = f.read()

    # Split the file into chunks, each containing one op.
    ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    if len(ops) < 2:
        # With no separator the header and footer would be the same chunk and
        # the whole file would be written back twice.
        raise ValueError('{}: cannot find the op separator {!r}'.format(
            path, AUTOGEN_OP_DEF_SEPARATOR))
    header = ops[0]
    footer = ops[-1]
    ops = ops[1:-1]

    # For each existing op, extract the manually-written sections out to
    # retain them when re-generating the ops.
    op_info_dict = {}
    for op in ops:
        info_dict = extract_td_op_info(op)
        op_info_dict[info_dict['opname']] = info_dict
    # Existing ops plus the new ones; the caller's list is left untouched.
    opnames = sorted(set(filter_list) | set(op_info_dict))

    # Index the grammar once; setdefault keeps the first entry per opname.
    grammar = {}
    for inst in instructions:
        grammar.setdefault(inst['opname'], inst)

    op_defs = []
    for opname in opnames:
        instruction = grammar.get(opname)
        if instruction is None:
            raise ValueError(
                'cannot find "{}" in the SPIR-V JSON grammar'.format(opname))
        op_defs.append(
            get_op_definition(instruction, docs[opname],
                              op_info_dict.get(opname, {})))

    # Substitute the old op definitions
    content = AUTOGEN_OP_DEF_SEPARATOR.join([header] + op_defs + [footer])

    with open(path, 'w') as f:
        f.write(content)
| |
| |
if __name__ == '__main__':
    import argparse

    cli_parser = argparse.ArgumentParser(
        description='Update SPIR-V dialect definitions using SPIR-V spec')

    cli_parser.add_argument(
        '--base-td-path',
        dest='base_td_path',
        type=str,
        default=None,
        help='Path to SPIRVBase.td')
    cli_parser.add_argument(
        '--op-td-path',
        dest='op_td_path',
        type=str,
        default=None,
        help='Path to SPIRVOps.td')

    cli_parser.add_argument(
        '--new-enum',
        dest='new_enum',
        type=str,
        default=None,
        help='SPIR-V enum to be added to SPIRVBase.td')
    cli_parser.add_argument(
        '--new-opcodes',
        dest='new_opcodes',
        type=str,
        default=None,
        nargs='*',
        help='update SPIR-V opcodes in SPIRVBase.td')
    cli_parser.add_argument(
        '--new-inst',
        dest='new_inst',
        type=str,
        default=None,
        help='SPIR-V instruction to be added to SPIRVOps.td')

    args = cli_parser.parse_args()

    # Validate option combinations up front, before any network access, with
    # a proper usage error (asserts are skipped under `python -O`).
    if args.base_td_path is None and (args.new_enum is not None or
                                      args.new_opcodes is not None):
        cli_parser.error('--new-enum and --new-opcodes require --base-td-path')
    if args.op_td_path is None and args.new_inst is not None:
        cli_parser.error('--new-inst requires --op-td-path')

    operand_kinds, instructions = get_spirv_grammar_from_json_spec()

    # Define new enum attr
    if args.new_enum is not None:
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        filter_list = [args.new_inst] if args.new_inst else []
        docs = get_spirv_doc_from_html_spec()
        update_td_op_definitions(args.op_td_path, instructions, docs,
                                 filter_list)
        print('Done. Note that this script just generates a template; '
              'please read the spec and update traits, arguments, and '
              'results accordingly.')