# blob: a0f30b0ce9fa78dbbd3920b6ef655f5e1e0d02ca [file] [log] [blame]
# TensorFlow -> TOSA Compiler Bridge.
# See:
# https://developer.mlplatform.org/w/tosa/
# https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/TOSA.md
load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
# TODO: Tighten visibility once targets are at the right granularity.
package(
    # Targets that do not set `visibility` themselves are visible only to the
    # `:internal` package group declared in this file.
    default_visibility = [":internal"],
    licenses = ["notice"],
)
# Package group covering every package under //tensorflow/compiler/mlir.
# Used as this package's default visibility.
package_group(
    name = "internal",
    packages = ["//tensorflow/compiler/mlir/..."],
)
# Package group for the pass libraries that set `visibility = [":friends"]`:
# everything in `:internal` plus the packages under //third_party/iree.
package_group(
    name = "friends",
    includes = [":internal"],
    packages = ["//third_party/iree/..."],
)
# Local alias for the upstream MLIR TOSA dialect TableGen sources; the
# legalization pattern generators in this file list it in their `deps`.
filegroup(
    name = "tosa_ops_td_files",
    srcs = ["@llvm-project//mlir:TosaDialectTdFiles"],
    compatible_with = get_compatible_with_cloud(),
)
# Runs mlir-tblgen over transforms/passes.td to emit the pass declarations
# (group name "LegalizeTosa") into transforms/passes.h.inc.
gentbl_cc_library(
    name = "tosa_passes_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (
            ["-gen-pass-decls", "-name=LegalizeTosa"],
            "transforms/passes.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/passes.td",
    deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)
# Header-only library: the hand-written transforms/passes.h plus
# transforms/passes.h.inc, which is the output of `:tosa_passes_inc_gen`.
cc_library(
    name = "passes_header",
    hdrs = [
        "transforms/passes.h",
        "transforms/passes.h.inc",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        "@llvm-project//mlir:Pass",
    ],
)
# Legalization code (legalize_common / legalize_utils) that `:tf_passes`,
# `:tfl_passes` and `:tf_tfl_passes` all depend on.
cc_library(
    name = "legalize_common",
    srcs = [
        "transforms/legalize_common.cc",
        "transforms/legalize_utils.cc",
    ],
    hdrs = [
        "transforms/legalize_common.h",
        "transforms/legalize_utils.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    deps = [
        # TensorFlow targets.
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/core:framework",
        "//tensorflow/core/kernels:conv_grad_shape_utils",
        # LLVM / MLIR targets.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InferTypeOpInterface",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Rewrite",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
    ],
)
# Runs mlir-tblgen with -gen-rewriters over transforms/tf_legalize_patterns.td
# to emit transforms/tf_legalize_patterns.inc (listed in `:tf_passes` srcs).
gentbl_cc_library(
    name = "tosa_legalize_tf_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "transforms/tf_legalize_patterns.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/tf_legalize_patterns.td",
    deps = [
        ":tosa_ops_td_files",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
        "@llvm-project//mlir:StdOpsTdFiles",
    ],
)
# TensorFlow -> TOSA passes. Visible to the `:friends` package group.
cc_library(
    name = "tf_passes",
    srcs = [
        "tf_passes.cc",
        "transforms/fuse_bias_tf.cc",
        "transforms/legalize_tf.cc",
        # Output of `:tosa_legalize_tf_inc_gen`.
        "transforms/tf_legalize_patterns.inc",
    ],
    hdrs = [
        "tf_passes.h",
        "transforms/passes.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    visibility = [":friends"],
    deps = [
        # Targets from this package.
        ":legalize_common",
        ":passes_header",
        # TensorFlow targets.
        "//tensorflow/compiler/mlir/tensorflow",
        # LLVM / MLIR targets.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
    ],
)
# Runs mlir-tblgen with -gen-rewriters over transforms/tfl_legalize_patterns.td
# to emit transforms/tfl_legalize_patterns.inc (listed in `:tfl_passes` srcs).
gentbl_cc_library(
    name = "tosa_legalize_tfl_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        (["-gen-rewriters"], "transforms/tfl_legalize_patterns.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/tfl_legalize_patterns.td",
    deps = [
        ":tosa_ops_td_files",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files",
        "@llvm-project//mlir:StdOpsTdFiles",
    ],
)
# TensorFlow Lite -> TOSA passes. Visible to the `:friends` package group.
cc_library(
    name = "tfl_passes",
    srcs = [
        "tfl_passes.cc",
        "transforms/convert_tfl_uint8.cc",
        "transforms/legalize_tfl.cc",
        "transforms/strip_quant_types.cc",
        # Output of `:tosa_legalize_tfl_inc_gen`.
        "transforms/tfl_legalize_patterns.inc",
    ],
    hdrs = [
        "tfl_passes.h",
        "transforms/passes.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    visibility = [":friends"],
    deps = [
        # Targets from this package.
        ":legalize_common",
        ":passes_header",
        # TensorFlow Lite targets.
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
        # LLVM / MLIR targets.
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AffineTransforms",
        "@llvm-project//mlir:Dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
    ],
)
# Combined TF + TFL -> TOSA passes; depends on both `:tf_passes` and
# `:tfl_passes`. Visible to the `:friends` package group.
cc_library(
    name = "tf_tfl_passes",
    srcs = [
        "tf_tfl_passes.cc",
        "transforms/legalize_tf_tfl.cc",
    ],
    hdrs = [
        "tf_tfl_passes.h",
        "transforms/passes.h",
    ],
    compatible_with = get_compatible_with_cloud(),
    visibility = [":friends"],
    deps = [
        # Targets from this package.
        ":legalize_common",
        ":passes_header",
        ":tf_passes",
        ":tfl_passes",
        # TensorFlow Lite targets.
        "//tensorflow/compiler/mlir/lite:lift_tflite_flex_ops",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        # MLIR targets.
        "@llvm-project//mlir:AffineTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TosaDialect",
        "@llvm-project//mlir:Transforms",
    ],
)