# Source blob: 4a2be118005e13a60bcdc7dd8228658c95650d6e
# Description: SPMD partitioning pass, plus the all-gather canonicalization
# (for CSE) and schedule-aware collective-op CSE passes.
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
# Package-wide defaults.
# - default_visibility: targets that do not set `visibility` themselves are
#   visible only to the ":friends" package_group defined in this file.
# - licenses: declares the "notice" license type for this package's targets.
package(
default_visibility = [":friends"],
licenses = ["notice"],
)
# Visibility group used as this package's default_visibility.
# It lists no packages of its own. Its membership comes entirely from
# `includes`, i.e. whatever //tensorflow/compiler/xla:friends admits.
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
# The SPMD partitioner library.
# Contents: the pass itself (spmd_partitioner.*), shared helpers
# (spmd_partitioner_util.*), and handler sources named after the ops they
# cover: convolution, custom-call, dot, FFT and gather/scatter.
cc_library(
name = "spmd_partitioner",
srcs = [
"convolution_handler.cc",
"custom_call_handler.cc",
# NOTE(review): dot_handler.cc, fft_handler.cc and
# gather_scatter_handler.cc have no matching header in `hdrs`. Their
# entry points are presumably declared in spmd_partitioner.h or
# spmd_partitioner_util.h — confirm before adding new handlers.
"dot_handler.cc",
"fft_handler.cc",
"gather_scatter_handler.cc",
"spmd_partitioner.cc",
"spmd_partitioner_util.cc",
],
hdrs = [
"convolution_handler.h",
"custom_call_handler.h",
"spmd_partitioner.h",
"spmd_partitioner_util.h",
],
# Kept in buildifier-sorted order. Order also feeds C++ link order, so
# avoid manual reshuffling.
deps = [
# Core XLA utilities and protos.
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
# XLA client-side builder and comparator helpers.
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:comparators",
# HLO IR, pass infrastructure and other XLA service libraries.
"//tensorflow/compiler/xla/service:dot_as_convolution_util",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cse",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_lexer",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_query",
"//tensorflow/compiler/xla/service:hlo_sharding_util",
"//tensorflow/compiler/xla/service:pattern_matcher",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/service:sharding_propagation",
"//tensorflow/compiler/xla/service:tuple_simplifier",
# TensorFlow platform / StreamExecutor support.
"//tensorflow/core:lib",
"//tensorflow/core/platform:numbers",
"//tensorflow/stream_executor/lib",
# Abseil.
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
# Tests for :spmd_partitioner.
# Links XLA's hlo_test_base and xla_internal_test_main test support, plus the
# HLO parser, matchers and verifier.
tf_cc_test(
name = "spmd_partitioner_test",
srcs = ["spmd_partitioner_test.cc"],
deps = [
":spmd_partitioner",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
# Tests for :canonicalize_all_gather_for_cse.
# Uses the same HLO test-support dependencies as spmd_partitioner_test.
tf_cc_test(
name = "canonicalize_all_gather_for_cse_test",
srcs = ["canonicalize_all_gather_for_cse_test.cc"],
deps = [
":canonicalize_all_gather_for_cse",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
# HLO pass that, per its name, canonicalizes all-gather ops so CSE can merge
# them. See canonicalize_all_gather_for_cse.h for the exact contract.
# Standalone target: it is not a dependency of :spmd_partitioner in this file.
cc_library(
name = "canonicalize_all_gather_for_cse",
srcs = ["canonicalize_all_gather_for_cse.cc"],
hdrs = ["canonicalize_all_gather_for_cse.h"],
deps = [
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_query",
],
)
# Tests for :schedule_aware_collective_ops_cse.
# Uses the same HLO test-support dependencies as spmd_partitioner_test.
tf_cc_test(
name = "schedule_aware_collective_ops_cse_test",
srcs = ["schedule_aware_collective_ops_cse_test.cc"],
deps = [
":schedule_aware_collective_ops_cse",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
# HLO pass that, per its name, performs CSE of collective ops with awareness
# of the instruction schedule. See schedule_aware_collective_ops_cse.h for
# the exact contract.
# Standalone target: it is not a dependency of :spmd_partitioner in this file.
cc_library(
name = "schedule_aware_collective_ops_cse",
srcs = ["schedule_aware_collective_ops_cse.cc"],
hdrs = ["schedule_aware_collective_ops_cse.h"],
deps = [
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
],
)