| # Description: SPMD partitioning pass and related collective-op CSE passes. |
| |
| load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") |
| load("//tensorflow:tensorflow.bzl", "tf_cc_test") |
| |
| # Package defaults: targets are visible only to the :friends group below |
| # unless a target overrides `visibility`. |
| package( |
| default_visibility = [":friends"], |
| licenses = ["notice"], |
| ) |
| |
| # Visibility group referenced by default_visibility above. It contains no |
| # packages of its own; it re-exports the members of the XLA friends group. |
| package_group( |
| name = "friends", |
| includes = [ |
| "//tensorflow/compiler/xla:friends", |
| ], |
| ) |
| |
| # SPMD partitioner library: the partitioner (spmd_partitioner.cc), its shared |
| # utilities (spmd_partitioner_util.cc), and op-specific handler sources |
| # (convolution, custom call, dot, FFT, gather/scatter, judging by file names). |
| cc_library( |
| name = "spmd_partitioner", |
| srcs = [ |
| "convolution_handler.cc", |
| "custom_call_handler.cc", |
| "dot_handler.cc", |
| "fft_handler.cc", |
| "gather_scatter_handler.cc", |
| "spmd_partitioner.cc", |
| "spmd_partitioner_util.cc", |
| ], |
| # NOTE(review): dot_handler.cc, fft_handler.cc and gather_scatter_handler.cc |
| # have no header listed here; presumably their entry points are declared in |
| # spmd_partitioner.h. Confirm before adding a new handler header. |
| hdrs = [ |
| "convolution_handler.h", |
| "custom_call_handler.h", |
| "spmd_partitioner.h", |
| "spmd_partitioner_util.h", |
| ], |
| deps = [ |
| "//tensorflow/compiler/xla:comparison_util", |
| "//tensorflow/compiler/xla:literal_util", |
| "//tensorflow/compiler/xla:protobuf_util", |
| "//tensorflow/compiler/xla:shape_util", |
| "//tensorflow/compiler/xla:status", |
| "//tensorflow/compiler/xla:util", |
| "//tensorflow/compiler/xla:window_util", |
| "//tensorflow/compiler/xla:xla_data_proto_cc", |
| "//tensorflow/compiler/xla/client:xla_builder", |
| "//tensorflow/compiler/xla/client/lib:comparators", |
| "//tensorflow/compiler/xla/service:dot_as_convolution_util", |
| "//tensorflow/compiler/xla/service:flatten_call_graph", |
| "//tensorflow/compiler/xla/service:hlo", |
| "//tensorflow/compiler/xla/service:hlo_cse", |
| "//tensorflow/compiler/xla/service:hlo_dce", |
| "//tensorflow/compiler/xla/service:hlo_lexer", |
| "//tensorflow/compiler/xla/service:hlo_pass", |
| "//tensorflow/compiler/xla/service:hlo_pass_pipeline", |
| "//tensorflow/compiler/xla/service:hlo_query", |
| "//tensorflow/compiler/xla/service:hlo_sharding_util", |
| "//tensorflow/compiler/xla/service:pattern_matcher", |
| "//tensorflow/compiler/xla/service:shape_inference", |
| "//tensorflow/compiler/xla/service:sharding_propagation", |
| "//tensorflow/compiler/xla/service:tuple_simplifier", |
| "//tensorflow/core:lib", |
| "//tensorflow/core/platform:numbers", |
| "//tensorflow/stream_executor/lib", |
| "@com_google_absl//absl/algorithm:container", |
| "@com_google_absl//absl/cleanup", |
| "@com_google_absl//absl/container:flat_hash_map", |
| "@com_google_absl//absl/container:flat_hash_set", |
| "@com_google_absl//absl/container:inlined_vector", |
| "@com_google_absl//absl/memory", |
| "@com_google_absl//absl/strings", |
| "@com_google_absl//absl/types:optional", |
| "@com_google_absl//absl/types:span", |
| ], |
| ) |
| |
| # Unit test for :spmd_partitioner (sources in spmd_partitioner_test.cc). |
| # Links xla_internal_test_main, which presumably supplies the test main(). |
| tf_cc_test( |
| name = "spmd_partitioner_test", |
| srcs = ["spmd_partitioner_test.cc"], |
| deps = [ |
| ":spmd_partitioner", |
| "//tensorflow/compiler/xla:util", |
| "//tensorflow/compiler/xla:xla_data_proto_cc", |
| "//tensorflow/compiler/xla/service:hlo", |
| "//tensorflow/compiler/xla/service:hlo_matchers", |
| "//tensorflow/compiler/xla/service:hlo_parser", |
| "//tensorflow/compiler/xla/service:hlo_pass_pipeline", |
| "//tensorflow/compiler/xla/service:hlo_verifier", |
| "//tensorflow/compiler/xla/tests:hlo_test_base", |
| "//tensorflow/compiler/xla/tests:xla_internal_test_main", |
| "//tensorflow/core:test", |
| ], |
| ) |
| |
| # Unit test for the :canonicalize_all_gather_for_cse library defined below. |
| tf_cc_test( |
| name = "canonicalize_all_gather_for_cse_test", |
| srcs = ["canonicalize_all_gather_for_cse_test.cc"], |
| deps = [ |
| ":canonicalize_all_gather_for_cse", |
| "//tensorflow/compiler/xla:xla_data_proto_cc", |
| "//tensorflow/compiler/xla/service:hlo", |
| "//tensorflow/compiler/xla/service:hlo_matchers", |
| "//tensorflow/compiler/xla/service:hlo_parser", |
| "//tensorflow/compiler/xla/service:hlo_pass_pipeline", |
| "//tensorflow/compiler/xla/service:hlo_verifier", |
| "//tensorflow/compiler/xla/tests:hlo_test_base", |
| "//tensorflow/compiler/xla/tests:xla_internal_test_main", |
| "//tensorflow/core:test", |
| ], |
| ) |
| |
| # HLO pass library (built on hlo_pass) that canonicalizes all-gather ops for |
| # CSE, per its name. It has no dependency on :spmd_partitioner. |
| cc_library( |
| name = "canonicalize_all_gather_for_cse", |
| srcs = ["canonicalize_all_gather_for_cse.cc"], |
| hdrs = ["canonicalize_all_gather_for_cse.h"], |
| deps = [ |
| "//tensorflow/compiler/xla/service:hlo", |
| "//tensorflow/compiler/xla/service:hlo_pass", |
| "//tensorflow/compiler/xla/service:hlo_query", |
| ], |
| ) |
| |
| # Unit test for the :schedule_aware_collective_ops_cse library defined below. |
| tf_cc_test( |
| name = "schedule_aware_collective_ops_cse_test", |
| srcs = ["schedule_aware_collective_ops_cse_test.cc"], |
| deps = [ |
| ":schedule_aware_collective_ops_cse", |
| "//tensorflow/compiler/xla:xla_data_proto_cc", |
| "//tensorflow/compiler/xla/service:hlo", |
| "//tensorflow/compiler/xla/service:hlo_matchers", |
| "//tensorflow/compiler/xla/service:hlo_parser", |
| "//tensorflow/compiler/xla/service:hlo_pass_pipeline", |
| "//tensorflow/compiler/xla/service:hlo_verifier", |
| "//tensorflow/compiler/xla/tests:hlo_test_base", |
| "//tensorflow/compiler/xla/tests:xla_internal_test_main", |
| "//tensorflow/core:test", |
| ], |
| ) |
| |
| # HLO pass library (built on hlo_pass) performing schedule-aware CSE of |
| # collective ops, per its name. Like the all-gather CSE library above, it |
| # does not depend on :spmd_partitioner. |
| cc_library( |
| name = "schedule_aware_collective_ops_cse", |
| srcs = ["schedule_aware_collective_ops_cse.cc"], |
| hdrs = ["schedule_aware_collective_ops_cse.h"], |
| deps = [ |
| "//tensorflow/compiler/xla/service:hlo", |
| "//tensorflow/compiler/xla/service:hlo_pass", |
| "//tensorflow/stream_executor/lib", |
| "@com_google_absl//absl/container:flat_hash_map", |
| ], |
| ) |