blob: 3c1dc5c55ecca1692629acb1cc7ab7ae43cd68c7 [file] [log] [blame]
"""BUILD extension for TF composition project."""
load("//tensorflow:tensorflow.bzl", "py_binary", "tf_custom_op_library", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
def gen_op_libraries(
        name,
        src,
        deps = [],
        tags = [],
        test = False):
    """Generates all cc and py libraries for a composite op source file.

    Targets created in the calling package ("<stem>" is `src` without ".py"):
      <stem>: py_binary built from `src`.
      registed_<name>: genrule that runs <stem> with --gen_register_op=true
        and writes <name>.inc.cc.
      <name>_cc: alwayslink cc_library compiled from the generated source.
      <name>.so: tf_custom_op_library compiled from the generated source.
      gen_<name>: tf_gen_op_wrapper_py that emits gen_<name>.py.
      <name>: tf_custom_op_py_library combining the .so, the cc library and
        the generated Python wrapper.
      <stem>_with_op_library: py_binary built from `src` that additionally
        depends on <name>.
      <name>_mlir: genrule that runs <stem>_with_op_library with
        --gen_register_op=false and writes <name>.mlir.
      <name>_py: py_library wrapping `src`.

    Args:
      name: used as the name component of all the generated libraries.
      src: File containing the composite ops. Must end in ".py", and its stem
        must differ from `name`.
      deps: Libraries the 'src' depends on.
      tags: Tags applied to the two generated genrules.
      test: Passed as `testonly` to the generated `<name>_cc` cc_library.
    """
    if not src.endswith(".py"):
        fail("'src' %s must be a Python source file ending in '.py'." % src)

    gen_op_lib_exec = src[:-3]  # Strip off the .py

    # The py_binary below is named after the stem of `src`, while the op
    # Python wrapper library is named `name`; the two must not collide.
    if name == gen_op_lib_exec:
        fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src)

    # Python dependencies shared by every target that is built from `src`.
    tfr_py_deps = [
        "//tensorflow/compiler/mlir/tfr:op_reg_gen",
        "//tensorflow/compiler/mlir/tfr:tfr_gen",
        "//tensorflow/compiler/mlir/tfr:composite",
    ] + deps

    py_binary(
        name = gen_op_lib_exec,
        srcs = [src],
        srcs_version = "PY2AND3",
        python_version = "PY3",
        deps = tfr_py_deps,
    )

    registed_op = "registed_" + name
    native.genrule(
        name = registed_op,
        srcs = [],
        outs = [name + ".inc.cc"],
        cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec,
        tools = [":" + gen_op_lib_exec],
        tags = tags,
    )

    native.cc_library(
        name = name + "_cc",
        testonly = test,
        srcs = [":" + registed_op],
        deps = [
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
            "//tensorflow/core:protos_all_cc",
        ],
        alwayslink = 1,
    )

    tf_custom_op_library(
        name = name + ".so",
        srcs = [":" + registed_op],
    )

    tf_gen_op_wrapper_py(
        name = "gen_" + name,
        out = "gen_" + name + ".py",
        deps = [
            ":%s_cc" % name,
        ],
    )

    tf_custom_op_py_library(
        name = name,
        dso = [":%s.so" % name],
        kernels = [":%s_cc" % name],
        srcs_version = "PY2AND3",
        deps = [
            ":gen_%s" % name,
        ],
    )

    # Link the register op and rebuild the binary
    gen_tfr_lib_exec = gen_op_lib_exec + "_with_op_library"
    py_binary(
        name = gen_tfr_lib_exec,
        # The target name no longer matches the source file, so the entry
        # point has to be named explicitly.
        main = src,
        srcs = [src],
        srcs_version = "PY2AND3",
        python_version = "PY3",
        deps = [
            "//tensorflow/compiler/mlir/tfr:op_reg_gen",
            "//tensorflow/compiler/mlir/tfr:tfr_gen",
            "//tensorflow/compiler/mlir/tfr:composite",
            ":%s" % name,
        ] + deps,
    )

    native.genrule(
        name = name + "_mlir",
        srcs = [],
        outs = [name + ".mlir"],
        cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec,
        tools = [":" + gen_tfr_lib_exec],
        tags = tags,
    )

    native.py_library(
        name = name + "_py",
        srcs = [src],
        srcs_version = "PY2AND3",
        deps = tfr_py_deps,
    )

def gen_op_bindings(name):
    """Generates the cc and py binding targets for `<name>_ops.cc`.

    Targets created in the calling package:
      <name>_ops_cc: alwayslink cc_library compiled from <name>_ops.cc.
      <name>_ops.so: tf_custom_op_library compiled from <name>_ops.cc.
      gen_<name>_ops: tf_gen_op_wrapper_py that emits gen_<name>_ops.py.
      <name>_ops: publicly visible tf_custom_op_py_library combining the
        three targets above.

    Args:
      name: prefix shared by the source file and all generated targets.
    """
    ops = name + "_ops"
    ops_source = ops + ".cc"
    cc_target = ops + "_cc"
    dso_target = ops + ".so"
    wrapper_target = "gen_" + ops

    native.cc_library(
        name = cc_target,
        srcs = [ops_source],
        deps = [
            "//tensorflow/core:framework",
            "//tensorflow/core:lib",
            "//tensorflow/core:protos_all_cc",
        ],
        alwayslink = 1,
    )

    tf_custom_op_library(
        name = dso_target,
        srcs = [ops_source],
    )

    tf_gen_op_wrapper_py(
        name = wrapper_target,
        out = wrapper_target + ".py",
        deps = [":" + cc_target],
    )

    tf_custom_op_py_library(
        name = ops,
        dso = [":" + dso_target],
        kernels = [":" + cc_target],
        visibility = ["//visibility:public"],
        deps = [":" + wrapper_target],
    )