| # MLIR passes for DTensor support. |
| load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") |
| load("@bazel_skylib//rules:build_test.bzl", "build_test") |
| load("//tensorflow:tensorflow.bzl", "tf_cc_test") |
| load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") |
| load("//tensorflow:tensorflow.bzl", "get_compatible_with_cloud") |
| |
| # Package-wide defaults: unless a target overrides visibility, it is visible |
| # only to the two labels listed below. |
| # NOTE(review): ":dtensor-internal" is presumably a package_group - confirm. |
| package( |
| default_visibility = [ |
| "//tensorflow/dtensor:dtensor-internal", |
| # Allow visibility from the mlir language server. |
| "//learning/brain/mlir/mlir_lsp_server:__pkg__", |
| ], |
| licenses = ["notice"], |
| ) |
| |
| # Runs mlir-tblgen on ir/tf_dtensor.td and emits two include files: |
| # op declarations (ir/tf_dtensor.h.inc) and op definitions |
| # (ir/tf_dtensor.cc.inc). |
| gentbl_cc_library( |
| name = "tensorflow_dtensor_ops_inc_gen", |
| compatible_with = get_compatible_with_cloud(), |
| tbl_outs = [ |
| ( |
| ["-gen-op-decls"], |
| "ir/tf_dtensor.h.inc", |
| ), |
| ( |
| ["-gen-op-defs"], |
| "ir/tf_dtensor.cc.inc", |
| ), |
| ], |
| tblgen = "@llvm-project//mlir:mlir-tblgen", |
| td_file = "ir/tf_dtensor.td", |
| # TableGen sources from the TensorFlow dialect made available to the .td file. |
| td_srcs = [ |
| "//tensorflow/compiler/mlir/tensorflow:ir/tf_op_base.td", |
| "//tensorflow/compiler/mlir/tensorflow:ir/tf_op_interfaces.td", |
| ], |
| deps = [ |
| "@llvm-project//mlir:CallInterfacesTdFiles", |
| "@llvm-project//mlir:FuncTdFiles", |
| "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", |
| "@llvm-project//mlir:OpBaseTdFiles", |
| "@llvm-project//mlir:SideEffectInterfacesTdFiles", |
| ], |
| ) |
| |
| # Runs mlir-tblgen on Passes.td with -gen-pass-decls and -name=DTensor, |
| # producing dtensor_passes.h.inc. |
| gentbl_cc_library( |
| name = "dtensor_passes_inc_gen", |
| compatible_with = get_compatible_with_cloud(), |
| tbl_outs = [( |
| [ |
| "-gen-pass-decls", |
| "-name=DTensor", |
| ], |
| "dtensor_passes.h.inc", |
| )], |
| tblgen = "@llvm-project//mlir:mlir-tblgen", |
| td_file = "Passes.td", |
| deps = ["@llvm-project//mlir:PassBaseTdFiles"], |
| ) |
| |
| # C++ library built from ir/tf_dtensor.{h,cc}; consumes the generated op |
| # includes from :tensorflow_dtensor_ops_inc_gen. |
| cc_library( |
| name = "tf_dtensor_dialect", |
| srcs = ["ir/tf_dtensor.cc"], |
| hdrs = ["ir/tf_dtensor.h"], |
| includes = ["include"], |
| deps = [ |
| ":tensorflow_dtensor_ops_inc_gen", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", |
| "//tensorflow/dtensor/mlir/dtensor_dialect:ir/dtensor_attributes", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:DerivedAttributeOpInterface", |
| "@llvm-project//mlir:Dialect", |
| "@llvm-project//mlir:IR", |
| "@llvm-project//mlir:InferTypeOpInterface", |
| "@llvm-project//mlir:SideEffectInterfaces", |
| "@llvm-project//mlir:Support", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| # NOTE(review): presumably needed for static registration - confirm. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from collectives.{h,cc}. |
| cc_library( |
| name = "collectives", |
| srcs = ["collectives.cc"], |
| hdrs = ["collectives.h"], |
| deps = [ |
| ":collectives_common", |
| ":dtensor_location", |
| ":layout_parsing", |
| ":shape_utils", |
| ":sparse_expander_common", |
| ":spmd_expander_common", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:convert_tensor", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "@com_google_absl//absl/container:flat_hash_set", |
| "@com_google_absl//absl/strings", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:IR", |
| ], |
| ) |
| |
| # C++ library built from collectives_common.{h,cc}; it has no dependencies on |
| # other targets in this package. |
| cc_library( |
| name = "collectives_common", |
| srcs = ["collectives_common.cc"], |
| hdrs = ["collectives_common.h"], |
| deps = [ |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "@com_google_absl//absl/container:flat_hash_map", |
| ], |
| ) |
| |
| # C++ library built from device_utils.{h,cc}. |
| cc_library( |
| name = "device_utils", |
| srcs = ["device_utils.cc"], |
| hdrs = ["device_utils.h"], |
| deps = [ |
| "//tensorflow/core/platform:errors", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from dtensor_location.{h,cc}; exercised by |
| # :dtensor_location_test. |
| cc_library( |
| name = "dtensor_location", |
| srcs = ["dtensor_location.cc"], |
| hdrs = ["dtensor_location.h"], |
| deps = [ |
| "//tensorflow/compiler/mlir:name_utils", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:IR", |
| "@llvm-project//mlir:Support", |
| ], |
| ) |
| |
| # Header-only target (no srcs) exposing create_dtensor_mlir_passes.h and |
| # dtensor_mlir_passes_classes.h, plus the generated pass declarations from |
| # :dtensor_passes_inc_gen. |
| cc_library( |
| name = "create_dtensor_mlir_passes", |
| hdrs = [ |
| "create_dtensor_mlir_passes.h", |
| "dtensor_mlir_passes_classes.h", |
| ], |
| deps = [ |
| ":device_utils", |
| ":dtensor_passes_inc_gen", |
| ":dtensor_send_recv", |
| ":layout_parsing", |
| ":op_utils", |
| ":shape_utils", |
| ":sparse_expander", |
| ":spmd_expander", |
| ":spmd_expander_common", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir/hlo", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core:core_cpu_base", |
| "//tensorflow/core:lib", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| "@llvm-project//mlir:MemRefDialect", |
| "@llvm-project//mlir:Pass", |
| "@llvm-project//mlir:SCFDialect", |
| "@llvm-project//mlir:ShapeDialect", |
| "@llvm-project//mlir:TensorDialect", |
| ], |
| # NOTE(review): this target lists no srcs, so alwayslink may have no effect |
| # here - confirm before removing. |
| alwayslink = True, |
| ) |
| |
| # C++ library that compiles the individual pass sources listed in srcs and |
| # exposes dtensor_mlir_passes.h. |
| cc_library( |
| name = "dtensor_mlir_passes", |
| srcs = [ |
| "annotate_global_shape.cc", |
| "cluster_function_conversion.cc", |
| "constant_folding.cc", |
| "dce.cc", |
| "designate_resource_handle_mesh.cc", |
| "device_mesh_cluster_coarsening.cc", |
| "dtensor_allreduce_combine_optimization.cc", |
| "dtensor_allreduce_scatter_optimization.cc", |
| "dtensor_allreduce_sum_optimization.cc", |
| "dtensor_mixed_precision_reduce.cc", |
| "dtensor_mlir_passes.cc", |
| "function_renaming.cc", |
| "handle_cross_cluster_dependencies.cc", |
| "handle_sparsetensors.cc", |
| "layout_propagation_v2.cc", |
| "lower_send_recv.cc", |
| "merge_clusters.cc", |
| "mesh_propagation.cc", |
| "move_compilation_to_host.cc", |
| "op_to_device_cluster.cc", |
| "propagate_default_layout.cc", |
| "propagate_device_id_to_function_args.cc", |
| "set_default_sharding.cc", |
| "sparse_expansion.cc", |
| "spmd_expansion.cc", |
| "tpu_add_resource_device_attribute.cc", |
| "tpu_integration.cc", |
| "undo_merge_const_across_mesh.cc", |
| ], |
| hdrs = ["dtensor_mlir_passes.h"], |
| deps = [ |
| ":collectives_common", |
| ":create_dtensor_mlir_passes", |
| ":device_utils", |
| ":dtensor_passes_inc_gen", |
| ":dtensor_send_recv", |
| ":group_assignment", |
| ":layout_parsing", |
| ":op_utils", |
| ":shape_utils", |
| ":sparse_expander", |
| ":spmd_expander", |
| ":spmd_expander_common", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir:name_utils", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:attribute_utils", |
| "//tensorflow/compiler/mlir/tensorflow:bridge_logger", |
| "//tensorflow/compiler/mlir/tensorflow:convert_tensor", |
| "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", |
| "//tensorflow/compiler/mlir/tensorflow:error_util", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", |
| "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", |
| "//tensorflow/compiler/xla/client:sharding_builder", |
| "//tensorflow/core:core_cpu_base", |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:constants", |
| "//tensorflow/dtensor/cc:dtensor_utils", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "//tensorflow/dtensor/mlir/dtensor_dialect:ir/dtensor_attributes", |
| "//tensorflow/dtensor/mlir/utils:dtensor_mlir_passes_internal", |
| "@com_google_absl//absl/container:flat_hash_set", |
| "@com_google_absl//absl/strings", |
| "@com_google_absl//absl/types:optional", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| "@llvm-project//mlir:Pass", |
| "@llvm-project//mlir:Support", |
| "@llvm-project//mlir:Transforms", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| # NOTE(review): presumably needed for static pass registration - confirm. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from dtensor_send_recv.{h,cc}. |
| cc_library( |
| name = "dtensor_send_recv", |
| srcs = ["dtensor_send_recv.cc"], |
| hdrs = ["dtensor_send_recv.h"], |
| deps = [ |
| ":device_utils", |
| ":layout_parsing", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:convert_tensor", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core/platform:errors", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from group_assignment.{h,cc}; exercised by |
| # :group_assignment_test. |
| cc_library( |
| name = "group_assignment", |
| srcs = ["group_assignment.cc"], |
| hdrs = ["group_assignment.h"], |
| deps = [ |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:dstatus", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:IR", |
| ], |
| ) |
| |
| # Test target built from group_assignment_test.cc against :group_assignment. |
| tf_cc_test( |
| name = "group_assignment_test", |
| srcs = ["group_assignment_test.cc"], |
| deps = [ |
| ":group_assignment", |
| "//tensorflow/core:lib", |
| "//tensorflow/core:test", |
| "//tensorflow/core:test_main", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/stream_executor/lib", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:IR", |
| ], |
| ) |
| |
| # C++ library built from layout_parsing.{h,cc}. |
| cc_library( |
| name = "layout_parsing", |
| srcs = [ |
| "layout_parsing.cc", |
| ], |
| hdrs = ["layout_parsing.h"], |
| deps = [ |
| ":tf_dtensor_dialect", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:constants", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "//tensorflow/dtensor/proto:layout_proto_cc", |
| "//tensorflow/stream_executor/lib", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@com_google_absl//absl/container:flat_hash_set", |
| "@com_google_absl//absl/strings", |
| "@com_google_absl//absl/types:optional", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from op_utils.{h,cc}. |
| cc_library( |
| name = "op_utils", |
| srcs = ["op_utils.cc"], |
| hdrs = ["op_utils.h"], |
| deps = [ |
| ":tf_dtensor_dialect", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from shape_utils.{h,cc}. |
| cc_library( |
| name = "shape_utils", |
| srcs = ["shape_utils.cc"], |
| hdrs = ["shape_utils.h"], |
| deps = [ |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir/tensorflow:shape_inference_utils", |
| "//tensorflow/core:framework", |
| "//tensorflow/dtensor/cc:constants", |
| "//tensorflow/dtensor/cc:dstatus", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from sparse_expanders.cc plus every file matching |
| # *sparse_expander.{h,cc} in this directory and in sparse_expansions/. |
| # sparse_expander_common.{h,cc} do not match these patterns; they are built |
| # by :sparse_expander_common. |
| cc_library( |
| name = "sparse_expander", |
| srcs = [ |
| "sparse_expanders.cc", |
| ] + glob([ |
| "*sparse_expander.cc", |
| "sparse_expansions/*sparse_expander.cc", |
| ]), |
| hdrs = glob([ |
| "*sparse_expander.h", |
| "sparse_expansions/*sparse_expander.h", |
| ]), |
| deps = [ |
| ":op_utils", |
| ":sparse_expander_common", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core:framework", |
| "//tensorflow/core/platform:errors", |
| "//tensorflow/core/platform:statusor", |
| "//tensorflow/dtensor/cc:dstatus", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| # NOTE(review): presumably needed for static expander registration - confirm. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from sparse_expander_common.{h,cc}. |
| cc_library( |
| name = "sparse_expander_common", |
| srcs = ["sparse_expander_common.cc"], |
| hdrs = ["sparse_expander_common.h"], |
| deps = [ |
| ":tf_dtensor_dialect", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", |
| "//tensorflow/dtensor/cc:dstatus", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@com_google_absl//absl/types:optional", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from spmd_expander_common.{h,cc}. |
| cc_library( |
| name = "spmd_expander_common", |
| srcs = ["spmd_expander_common.cc"], |
| hdrs = ["spmd_expander_common.h"], |
| deps = [ |
| ":device_utils", |
| ":layout_parsing", |
| ":op_utils", |
| ":shape_utils", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir/hlo:convert_op_folder", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:convert_tensor", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:constants", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@com_google_absl//absl/strings", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from spmd_expanders.cc plus every file matching |
| # *spmd_expander.{h,cc} in this directory and in expansions/. |
| # spmd_expander_common.{h,cc} do not match these patterns; they are built by |
| # :spmd_expander_common. |
| cc_library( |
| name = "spmd_expander", |
| srcs = [ |
| "spmd_expanders.cc", |
| ] + glob([ |
| "*spmd_expander.cc", |
| "expansions/*spmd_expander.cc", |
| ]), |
| hdrs = glob([ |
| "*spmd_expander.h", |
| "expansions/*spmd_expander.h", |
| ]), |
| deps = [ |
| ":collectives", |
| ":device_utils", |
| ":dtensor_location", |
| ":dtensor_send_recv", |
| ":layout_parsing", |
| ":op_utils", |
| ":shape_utils", |
| ":spmd_expander_common", |
| ":tf_dtensor_dialect", |
| ":value_utils", |
| "//tensorflow/compiler/mlir:array_container_utils", |
| "//tensorflow/compiler/mlir:name_utils", |
| "//tensorflow/compiler/mlir/hlo:convert_op_folder", |
| "//tensorflow/compiler/mlir/tensorflow", |
| "//tensorflow/compiler/mlir/tensorflow:convert_tensor", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core:framework", |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:constants", |
| "//tensorflow/dtensor/cc:dstatus", |
| "//tensorflow/dtensor/cc:save_restore_util", |
| "//tensorflow/dtensor/cc:tensor_layout", |
| "//tensorflow/dtensor/proto:layout_proto_cc", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@com_google_absl//absl/container:flat_hash_set", |
| "@com_google_absl//absl/container:inlined_vector", |
| "@com_google_absl//absl/strings", |
| "@com_google_absl//absl/types:optional", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:FuncDialect", |
| "@llvm-project//mlir:IR", |
| "@llvm-project//mlir:InferTypeOpInterface", |
| "@llvm-project//mlir:Support", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| # NOTE(review): presumably needed for static expander registration - confirm. |
| alwayslink = True, |
| ) |
| |
| # C++ library built from value_utils.{h,cc}. |
| cc_library( |
| name = "value_utils", |
| srcs = ["value_utils.cc"], |
| hdrs = ["value_utils.h"], |
| deps = [ |
| ":op_utils", |
| ":tf_dtensor_dialect", |
| "//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", |
| "//tensorflow/core:lib", |
| "//tensorflow/dtensor/cc:dstatus", |
| "@llvm-project//llvm:Support", |
| "@llvm-project//mlir:IR", |
| ], |
| # Keep every object file in the final link even if no symbol is referenced. |
| alwayslink = True, |
| ) |
| |
| # Test target built from dtensor_location_test.cc against :dtensor_location. |
| tf_cc_test( |
| name = "dtensor_location_test", |
| srcs = ["dtensor_location_test.cc"], |
| deps = [ |
| ":dtensor_location", |
| "//tensorflow/compiler/mlir:name_utils", |
| "//tensorflow/core:test", |
| "//tensorflow/core:test_main", |
| "@llvm-project//mlir:IR", |
| "@llvm-project//mlir:Support", |
| ], |
| ) |
| |
| # Build-only check (bazel_skylib build_test) over the dialect library and |
| # its generated op includes. |
| build_test( |
| name = "mlir_build_test", |
| targets = [ |
| ":tf_dtensor_dialect", |
| ":tensorflow_dtensor_ops_inc_gen", |
| ], |
| ) |