blob: a934bb3ecc4d80d2c6937ad3e3616168e40f54f4 [file] [log] [blame]
#!/usr/bin/env python3
import os
from pathlib import Path
from torch.jit._decompositions import decomposition_table
# from tools.codegen.code_template import CodeTemplate
# Text fragments of the generated C++ file.  write_decomposition_util_file
# concatenates them as:
#   DECOMP_HEADER + <serialized TorchScript> + DECOMP_CENTER
#   + <mapping entries> + DECOMP_END
# Raw Python strings (r"""...""") are used so backslashes in the C++ text are
# not treated as Python escape sequences.

# Preamble: @generated banner, #includes, namespace openings, and the start of
# a C++ raw string literal `R"(` that will contain the serialized TorchScript
# source of every decomposition.
DECOMP_HEADER = r"""
/**
* @generated
* This is an auto-generated file. Please do not modify it by hand.
* To re-generate, please run:
* cd ~/pytorch && python tools/codegen/decompositions/gen_jit_decompositions.py
*/
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
namespace torch {
namespace jit {
const std::string decomp_funcs =
R"("""
# Closes the raw string literal opened at the end of DECOMP_HEADER, defines
# GetSerializedDecompositions(), and opens the brace-initializer of the static
# OperatorMap that receives the `{"schema", "function name"},` entries.
DECOMP_CENTER = r"""
)";
const std::string& GetSerializedDecompositions() {
return decomp_funcs;
}
const OperatorMap<std::string>& GetDecompositionMapping() {
// clang-format off
static const OperatorMap<std::string> decomposition_mapping {
"""
# Closes the OperatorMap initializer, GetDecompositionMapping(), and both
# namespaces.
DECOMP_END = r"""
};
// clang-format on
return decomposition_mapping;
}
} // namespace jit
} // namespace torch
"""
# Basename of the generated C++ file; it is written into the directory passed
# to write_decomposition_util_file.
DECOMPOSITION_UTIL_FILE_NAME = "decomposition_registry_util.cpp"
def gen_serialized_decompisitions(table=None) -> str:
    """Concatenate the TorchScript source of every registered decomposition.

    The returned text is placed between DECOMP_HEADER and DECOMP_CENTER, i.e.
    inside the C++ raw string literal ``R"( ... )"``.  (The function name keeps
    its historical misspelling so existing callers keep working.)

    Args:
        table: Mapping whose values expose a ``.code`` attribute holding
            TorchScript source.  Defaults to ``decomposition_table``.

    Returns:
        The ``.code`` of every value, in the mapping's iteration order, joined
        with newlines.  An empty mapping yields ``""``.

    Raises:
        ValueError: if the joined source contains ``)"``, which would end the
            C++ raw string literal early and make the generated file fail to
            compile.
    """
    if table is None:
        table = decomposition_table
    serialized = "\n".join(scripted_func.code for scripted_func in table.values())
    # The C++ side embeds this text in R"( ... )" with an empty delimiter, so
    # the two-character sequence )" must never appear inside it.
    if ')"' in serialized:
        raise ValueError(
            "serialized decompositions contain ')\"', which would terminate "
            'the C++ raw string literal R"(...)" early'
        )
    return serialized
def _escape_cpp_string(text: str) -> str:
    """Escape ``text`` so it can sit inside an ordinary C++ ``"..."`` literal."""
    # Backslashes first, so the backslashes introduced below are not doubled.
    return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")


def gen_decomposition_mappings(table=None) -> str:
    """Render the C++ initializer entries mapping schemas to function names.

    Each entry has the form `` {"<schema>", "<function name>"},`` and is placed
    between DECOMP_CENTER and DECOMP_END.  Both strings are escaped, so schemas
    with string defaults (``str mode="none"``) or backslashes still produce
    valid C++.

    Args:
        table: Mapping of schema string -> object exposing ``.name``.
            Defaults to ``decomposition_table``.

    Returns:
        One entry per item, in the mapping's iteration order, joined with
        newlines.  An empty mapping yields ``""``.
    """
    if table is None:
        table = decomposition_table
    entries = [
        ' {"'
        + _escape_cpp_string(schema)
        + '", "'
        + _escape_cpp_string(scripted_func.name)
        + '"},'
        for schema, scripted_func in table.items()
    ]
    return "\n".join(entries)
def write_decomposition_util_file(path: str) -> None:
    """Generate DECOMPOSITION_UTIL_FILE_NAME inside the directory ``path``.

    The file content is DECOMP_HEADER, the serialized TorchScript source,
    DECOMP_CENTER, the schema -> function-name entries, and DECOMP_END,
    concatenated in that order.  An existing file is overwritten.

    Args:
        path: Directory to write the generated C++ file into.
    """
    # Compute the destination once so the logged path is exactly the path
    # that gets opened (previously the log used `path + "/"` while the open
    # used os.path.join, which can differ).
    output_file = os.path.join(path, DECOMPOSITION_UTIL_FILE_NAME)
    # Generate everything before opening the file, so a generation error does
    # not leave a truncated output behind.
    file_components = [
        DECOMP_HEADER,
        gen_serialized_decompisitions(),
        DECOMP_CENTER,
        gen_decomposition_mappings(),
        DECOMP_END,
    ]
    final_output = "".join(file_components)
    print("writing file to : ", output_file)
    # Binary mode plus explicit UTF-8 encoding: the bytes written do not depend
    # on the platform's default encoding or newline translation.
    with open(output_file, "wb") as out_file:
        out_file.write(final_output.encode("utf-8"))
def main() -> None:
    """Regenerate the decomposition registry C++ source under torch/csrc/jit/runtime."""
    # Per the regeneration command in the generated banner, this script lives at
    # tools/codegen/decompositions/, so three levels up is the repository root.
    repo_root = Path(__file__).resolve().parents[3]
    runtime_dir = repo_root.joinpath("torch", "csrc", "jit", "runtime")
    write_decomposition_util_file(str(runtime_dir))
if __name__ == "__main__":
    # Generate the file only when run as a script, not when imported.
    main()