[xla:runtime] NFC: Clean up jitrt <-> xla namespace re-export
Remove the xla.h header and rename all uses of ::tfrt::jitrt symbols to their ::xla::runtime equivalents.
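
As a minimal sketch of the mechanical pattern applied across the affected files (the concrete spellings are taken from the hunks below, not from new API surface):

    // Before: symbols picked up through the tfrt::jitrt re-export header.
    //   #include "tfrt/jitrt/jitrt.h"  // from @tf_runtime
    //   using ::tfrt::jitrt::JitExecutable;
    //   using ::tfrt::jitrt::MemrefDesc;

    // After: include the xla/runtime headers directly and spell the
    // symbols with their canonical ::xla::runtime names.
    #include "tensorflow/compiler/xla/runtime/arguments.h"
    #include "tensorflow/compiler/xla/runtime/jit_executable.h"

    using ::xla::runtime::JitExecutable;
    using ::xla::runtime::MemrefDesc;

Symbols that still live in tfrt::jitrt (compilation pipeline, result converters, etc.) keep their ::tfrt::jitrt spelling.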
PiperOrigin-RevId: 467853273
diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD
index f2336b0..2435108 100644
--- a/tensorflow/compiler/mlir/tfrt/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/BUILD
@@ -170,9 +170,12 @@
# srcs = ["jit/tf_jitrt_test.cc"],
# deps = [
# ":tf_jitrt",
+# "//tensorflow/compiler/xla/runtime:results",
+# "//tensorflow/compiler/xla/runtime:types",
# "@com_google_googletest//:gtest_main",
# "@llvm-project//mlir:mlir_c_runner_utils",
# "@tf_runtime//backends/jitrt",
+# "@tf_runtime//backends/jitrt:results",
# ],
# )
# copybara:uncomment_end
@@ -204,7 +207,11 @@
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tfrt/jit/transforms:tf_jitrt_passes",
"//tensorflow/compiler/xla/mlir/utils/runtime:async_runtime_api",
+ "//tensorflow/compiler/xla/runtime:arguments",
"//tensorflow/compiler/xla/runtime:async_runtime",
+ "//tensorflow/compiler/xla/runtime:executable",
+ "//tensorflow/compiler/xla/runtime:jit_executable",
+ "//tensorflow/compiler/xla/runtime:types",
"//tensorflow/core:framework",
"//tensorflow/core:platform_base",
"//tensorflow/core/platform:dynamic_annotations",
@@ -224,6 +231,7 @@
"@tf_runtime//:tracing",
"@tf_runtime//backends/jitrt",
"@tf_runtime//backends/jitrt:jitrt_compiler",
+ "@tf_runtime//backends/jitrt:results",
],
)
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD
index cff9311..72ba919 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/BUILD
@@ -17,6 +17,9 @@
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline",
+ "//tensorflow/compiler/xla/runtime:arguments",
+ "//tensorflow/compiler/xla/runtime:jit_executable",
+ "//tensorflow/compiler/xla/runtime:types",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:logging",
@@ -30,8 +33,8 @@
"@tf_runtime//:hostcontext",
"@tf_runtime//:support",
"@tf_runtime//:tensor",
- "@tf_runtime//backends/jitrt",
"@tf_runtime//backends/jitrt:jitrt_compiler",
+ "@tf_runtime//backends/jitrt:results",
],
)
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc
index 6b4ac21..889e172 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.cc
@@ -32,7 +32,7 @@
using ::tfrt::HostContext;
using ::tfrt::jitrt::CompilationPipelineOptions;
-using ::tfrt::jitrt::MemrefType;
+using ::xla::runtime::MemrefType;
const bool kStaticDim = false;
const bool kDynamicDim = true;
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h
index 42ba84d..7344a93 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark.h
@@ -27,8 +27,11 @@
#include "llvm/Support/MemoryBuffer.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
+#include "tensorflow/compiler/xla/runtime/arguments.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
#include "tensorflow/core/platform/test_benchmark.h"
-#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
+#include "tfrt/jitrt/results.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
#include "tfrt/host_context/host_context.h" // from @tf_runtime
#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime
@@ -42,9 +45,9 @@
using ::tfrt::HostContext;
using ::tfrt::RemainingResults;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
-using ::tfrt::jitrt::Type;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+using ::xla::runtime::Type;
// Constants to make shape specification more readable.
// kStaticDim refers to the static shape in IR taken from ARGS of the benchmark.
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc
index 4d1c250..5b01a98 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/benchmark_mlir_function.cc
@@ -39,12 +39,13 @@
using ::tfrt::RequestContext;
using ::tfrt::RequestContextBuilder;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
using ::tfrt::jitrt::RemainingResultsConverter;
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+
// Returns random tensors generated based on the input specs.
static llvm::SmallVector<Tensor> GetInputTensors(
llvm::ArrayRef<InputTensorSpec> input_specs) {
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h
index fc74378..e2576a7 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/cwise_op_unary_benchmark.h
@@ -35,12 +35,14 @@
using ::tfrt::RemainingResults;
using ::tfrt::RequestContext;
using ::tfrt::RequestContextBuilder;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
+
using ::tfrt::jitrt::RemainingResultsConverter;
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+
// -------------------------------------------------------------------------- //
// Run benchmark by compiling MLIR function using TFRT JitRt API.
// -------------------------------------------------------------------------- //
diff --git a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h
index 96da5f4..817e115 100644
--- a/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h
+++ b/tensorflow/compiler/mlir/tfrt/benchmarks/matmul_op_benchmark.h
@@ -34,11 +34,11 @@
using ::tfrt::RemainingResults;
using ::tfrt::RequestContext;
using ::tfrt::RequestContextBuilder;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
using ::tfrt::jitrt::RemainingResultsConverter;
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
// -------------------------------------------------------------------------- //
// Run benchmark by compiling MLIR function using TFRT JitRt API.
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD
index ec5143b..f80be60 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/BUILD
@@ -43,6 +43,8 @@
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline",
"//tensorflow/compiler/mlir/tfrt/python_tests:python_test_attrs_registration",
+ "//tensorflow/compiler/xla/runtime:executable",
+ "//tensorflow/compiler/xla/runtime:jit_executable",
"//tensorflow/core/platform:dynamic_annotations",
"//third_party/eigen3",
"//third_party/python_runtime:headers", # build_cleaner: keep
@@ -55,6 +57,7 @@
"@tf_runtime//:support",
"@tf_runtime//backends/jitrt",
"@tf_runtime//backends/jitrt:jitrt_compiler",
+ "@tf_runtime//backends/jitrt:results",
],
)
@@ -94,6 +97,5 @@
deps = [
"//tensorflow/compiler/xla/runtime:arguments",
"@tf_runtime//:dtype",
- "@tf_runtime//backends/jitrt:xla",
],
)
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc
index 76b40d0..029e217 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.cc
@@ -30,6 +30,7 @@
#include "tensorflow/core/platform/dynamic_annotations.h"
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt_compiler.h" // from @tf_runtime
+#include "tfrt/jitrt/results.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
#include "tfrt/host_context/async_value.h" // from @tf_runtime
#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime
@@ -55,14 +56,15 @@
using ::tfrt::jitrt::CompilationPipelineOptions;
using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::HostContextAsyncTaskRunner;
-using ::tfrt::jitrt::JitExecutable;
-using ::tfrt::jitrt::MemrefDesc;
using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
using ::tfrt::jitrt::RemainingResultsConverter;
using ::tfrt::jitrt::ReturnStridedMemref;
+using ::xla::runtime::Executable;
+using ::xla::runtime::HostContextAsyncTaskRunner;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+
namespace tensorflow {
TfJitRtExecutor::TfJitRtExecutor()
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h
index 314e6f3..0288fbd 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tf_jitrt_executor.h
@@ -23,7 +23,8 @@
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
-#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tfrt/host_context/host_context.h" // from @tf_runtime
namespace tensorflow {
@@ -34,7 +35,7 @@
class TfJitRtExecutor {
public:
using Handle = int64_t;
- using Specialization = tfrt::jitrt::JitExecutable::Specialization;
+ using Specialization = xla::runtime::JitExecutable::Specialization;
TfJitRtExecutor();
@@ -57,7 +58,7 @@
private:
tfrt::HostContext host_context_;
- llvm::DenseMap<Handle, tfrt::jitrt::JitExecutable> jit_executables_;
+ llvm::DenseMap<Handle, xla::runtime::JitExecutable> jit_executables_;
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc
index 5e0c0ad..6db9221 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/python_binding/tfrt_fallback.cc
@@ -35,7 +35,7 @@
namespace py = pybind11;
-using ::tfrt::jitrt::MemrefDesc;
+using ::xla::runtime::MemrefDesc;
static py::array ConvertTensorToPyArray(const Tensor& tensor) {
auto tensor_sizes = tensor.shape().dim_sizes();
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc
index b1feca3..a1d490f 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_kernels.cc
@@ -32,7 +32,11 @@
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_request_context.h"
#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_passes.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
+#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/async_runtime.h"
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
@@ -42,6 +46,7 @@
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt_compiler.h" // from @tf_runtime
+#include "tfrt/jitrt/results.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
@@ -100,22 +105,23 @@
using ::tfrt::StringAttribute;
using ::tfrt::TaskFunction;
-using ::tfrt::jitrt::ArgumentConstraint;
-using ::tfrt::jitrt::ArgumentsRef;
using ::tfrt::jitrt::CompilationPipelineOptions;
using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
-using ::tfrt::jitrt::EigenThreadPoolAsyncTaskRunner;
-using ::tfrt::jitrt::Executable;
-using ::tfrt::jitrt::JitExecutable;
using ::tfrt::jitrt::JitExecutableCache;
-using ::tfrt::jitrt::MemrefDesc;
using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
using ::tfrt::jitrt::ReturnErrors;
using ::tfrt::jitrt::ReturnStridedMemref;
using ::tfrt::jitrt::ReturnValueConversion;
-using ::tfrt::jitrt::SpecializationListener;
using ::tfrt::jitrt::StaticRemainingResultsConverter;
+using ::xla::runtime::ArgumentConstraint;
+using ::xla::runtime::ArgumentsRef;
+using ::xla::runtime::EigenThreadPoolAsyncTaskRunner;
+using ::xla::runtime::Executable;
+using ::xla::runtime::JitExecutable;
+using ::xla::runtime::MemrefDesc;
+using ::xla::runtime::SpecializationListener;
+
using ::tensorflow::profiler::TraceMe;
using ::tensorflow::profiler::TraceMeEncode;
using ::tensorflow::tfd::KernelFallbackCompatRequestState;
diff --git a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc
index e294e88..c9981e0 100644
--- a/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_test.cc
@@ -22,6 +22,9 @@
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "testing/base/public/benchmark.h"
#include <gtest/gtest.h>
+#include "tensorflow/compiler/xla/runtime/results.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
+#include "tfrt/jitrt/results.h" // from @tf_runtime
namespace tensorflow {
@@ -30,11 +33,12 @@
using ::tfrt::RCReference;
using ::tfrt::RemainingResults;
-using ::tfrt::jitrt::MemrefType;
using ::tfrt::jitrt::ReturnStridedMemref;
using ::tfrt::jitrt::ReturnValueConversion;
using ::tfrt::jitrt::StaticRemainingResultsConverter;
+using ::xla::runtime::MemrefType;
+
using ReturnTensorflowTensor =
ReturnValueConversion<TensorflowConversionContext,
ReturnStridedMemref<ConvertTensor>>;
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD
index ad84270..8310c25 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/BUILD
@@ -29,6 +29,8 @@
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tfrt:tf_jitrt_pipeline",
+ "//tensorflow/compiler/xla/runtime:executable",
+ "//tensorflow/compiler/xla/runtime:jit_executable",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:Support",
diff --git a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc
index 5a8997f..14941d0 100644
--- a/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc
+++ b/tensorflow/compiler/mlir/tfrt/tests/jit/tf_jitrt_benchmark_test.cc
@@ -19,6 +19,8 @@
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tfrt/jit/tf_jitrt_pipeline.h"
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/jitrt/jitrt_compiler.h" // from @tf_runtime
@@ -59,10 +61,11 @@
using ::tfrt::jitrt::CompilationPipelineOptions;
using ::tfrt::jitrt::CreateDefaultJitRtCompilationPipeline;
-using ::tfrt::jitrt::JitExecutable;
using ::tfrt::jitrt::JitExecutableCache;
using ::tfrt::jitrt::RegisterDefaultJitRtDialects;
+using ::xla::runtime::JitExecutable;
+
static void BM_InstantiateExecutable(::testing::benchmark::State& state) {
// Options for the default JitRt compilation pipeline (lowering to LLVM).
CompilationPipelineOptions copts;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 2ed28a1..79f111d 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -591,6 +591,10 @@
# "@llvm-project//llvm:OrcJIT",
# "@llvm-project//mlir:Support",
# "//tensorflow/compiler/xla:tfrt_utils",
+# "//tensorflow/compiler/xla/runtime:arguments",
+# "//tensorflow/compiler/xla/runtime:types",
+# "//tensorflow/compiler/xla/runtime:executable",
+# "//tensorflow/compiler/xla/runtime:jit_executable",
# "//tensorflow/compiler/xla:shape_util",
# "//tensorflow/compiler/xla/service:custom_call_status_internal",
# "//tensorflow/compiler/xla/service:custom_call_target_registry",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 6a8f714..2872649 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -1435,27 +1435,27 @@
copts.num_worker_threads = 1;
// Options for constructing JitRt JitExecutable.
- jitrt::JitExecutable::Options opts;
- opts.specialization = jitrt::JitExecutable::Specialization::kDisabled;
+ runtime::JitExecutable::Options opts;
+ opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
opts.compiler.register_dialects = jitrt::RegisterDefaultJitRtDialects;
// Register JitRt Gpu runtime custom calls with the linker.
opts.compiler.runtime_symbol_map =
- jitrt::GetSymbolsBinding(JitRtGpuCustomCalls());
+ runtime::GetSymbolsBinding(JitRtGpuCustomCalls());
opts.compiler.create_compilation_pipeline = [copts](mlir::PassManager& pm) {
jitrt::CreateDefaultJitRtCompilationPipeline(pm, copts);
};
// Instantiate new JitExecutable from the MLIR source.
- auto jit_executable = jitrt::JitExecutable::Instantiate(
+ auto jit_executable = runtime::JitExecutable::Instantiate(
program->module, program->entry_point, opts);
if (auto err = jit_executable.takeError())
return InternalError("Failed to compile JitRt program: %s",
tfrt::StrCat(err));
// For static shapes we can always serialize only the default executable.
- jitrt::Executable& executable = jit_executable->DefaultExecutable().get();
+ runtime::Executable& executable = jit_executable->DefaultExecutable().get();
// Check if JitRt executable saved the compilation result.
std::unique_ptr<llvm::MemoryBuffer> obj_file = executable.obj_file();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index f4ceee0..7e9e107 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -103,8 +103,8 @@
copts.populate_attr_encodings = PopulateLmhloToXlaAttrEncoding;
// Options for constructing JitRt JitExecutable.
- jitrt::JitExecutable::Options opts;
- opts.specialization = jitrt::JitExecutable::Specialization::kDisabled;
+ runtime::JitExecutable::Options opts;
+ opts.specialization = runtime::JitExecutable::Specialization::kDisabled;
opts.compiler.register_dialects = [](mlir::DialectRegistry& registry) {
jitrt::RegisterDefaultJitRtDialects(registry);
// For the encoding of attributes to custom calls.
@@ -113,7 +113,7 @@
// Register JitRt Gpu runtime custom calls with the linker.
opts.compiler.runtime_symbol_map =
- jitrt::GetSymbolsBinding(JitRtGpuCustomCalls());
+ runtime::GetSymbolsBinding(JitRtGpuCustomCalls());
// We just use the default compilation pipeline provided by the JitRt.
// Alternatively instead of having a separate JitRtProgram (LMHLO lowered to
@@ -131,7 +131,7 @@
opts.compiler.jit_code_opt_level = llvm::CodeGenOpt::None;
// Instantiate new JitExecutable from the MLIR source.
- auto jit_executable = jitrt::JitExecutable::Instantiate(
+ auto jit_executable = runtime::JitExecutable::Instantiate(
program->module, program->entry_point, opts);
if (auto err = jit_executable.takeError())
return InternalError("Failed to compile JitRt program: %s",
@@ -140,22 +140,22 @@
// Pass ownership to the GpuExecutable.
return new JitRtExecutable(
std::move(program->buffer_sizes),
- std::make_unique<jitrt::JitExecutable>(std::move(*jit_executable)),
+ std::make_unique<runtime::JitExecutable>(std::move(*jit_executable)),
std::move(program->debug_options));
}
// Create JitRtExecutable from the AOT compiled binary.
static StatusOr<JitRtExecutable*> Create(
- absl::Span<const int64_t> buffer_sizes, jitrt::Executable executable,
+ absl::Span<const int64_t> buffer_sizes, runtime::Executable executable,
DebugOptions debug_options) {
// Pass ownership to the GpuExecutable.
return new JitRtExecutable(
std::vector<int64_t>(buffer_sizes.begin(), buffer_sizes.end()),
- std::make_unique<jitrt::Executable>(std::move(executable)),
+ std::make_unique<runtime::Executable>(std::move(executable)),
std::move(debug_options));
}
- jitrt::Executable& default_executable() { return *default_executable_; }
+ runtime::Executable& default_executable() { return *default_executable_; }
JitRtKernelsCache& kernels_cache() { return kernels_cache_; }
JitRtGemmConfigCache& gemm_configs_cache() { return gemm_configs_cache_; }
JitRtCollectiveSupport& collectives() { return collectives_; }
@@ -170,7 +170,7 @@
private:
JitRtExecutable(std::vector<int64_t> buffer_sizes,
- std::unique_ptr<jitrt::JitExecutable> jit_executable,
+ std::unique_ptr<runtime::JitExecutable> jit_executable,
DebugOptions debug_options)
: buffer_sizes_(std::move(buffer_sizes)),
jit_executable_(std::move(jit_executable)),
@@ -178,7 +178,7 @@
debug_options_(std::move(debug_options)) {}
JitRtExecutable(std::vector<int64_t> buffer_sizes,
- std::unique_ptr<jitrt::Executable> aot_executable,
+ std::unique_ptr<runtime::Executable> aot_executable,
DebugOptions debug_options)
: buffer_sizes_(std::move(buffer_sizes)),
aot_executable_(std::move(aot_executable)),
@@ -187,11 +187,11 @@
std::vector<int64_t> buffer_sizes_;
- std::unique_ptr<jitrt::JitExecutable> jit_executable_;
- jitrt::Executable* default_executable_; // owned by `jit_executable`
+ std::unique_ptr<runtime::JitExecutable> jit_executable_;
+ runtime::Executable* default_executable_; // owned by `jit_executable`
- std::unique_ptr<jitrt::Executable> aot_executable_;
- jitrt::Executable* executable_;
+ std::unique_ptr<runtime::Executable> aot_executable_;
+ runtime::Executable* executable_;
DebugOptions debug_options_;
@@ -615,7 +615,7 @@
// compiled function will make a copy of all arguments and will write all
// results after the call to `Execute` completes, so it is safe to keep in on
// the stack.
- jitrt::Executable::CallFrame call_frame;
+ runtime::Executable::CallFrame call_frame;
// Each buffer allocation pased as 1d memref to the compiled kernel:
// {basePtr, dataPtr, offset, [sizes, ...], [strides, ...]}
@@ -648,14 +648,14 @@
}
// JitRt executables do not return any values.
- jitrt::NoResultConverter converter;
+ runtime::NoResultConverter converter;
// Prepare options for executing JitRt program.
- jitrt::Executable::ExecuteOpts opts;
+ runtime::Executable::ExecuteOpts opts;
// We don't expect to see any async tasks in the JitRt executable.
opts.async_task_runner =
- reinterpret_cast<jitrt::AsyncTaskRunner*>(0XDEADBEEF);
+ reinterpret_cast<runtime::AsyncTaskRunner*>(0XDEADBEEF);
// Get the async communications stream for async collectives.
int device_ordinal = run_options->stream()->parent()->device_ordinal();
@@ -669,7 +669,7 @@
async_comms_stream.ok() ? async_comms_stream->get() : nullptr);
// Pass auxiliary data to the custom call handlers.
- jitrt::CustomCall::UserData user_data;
+ runtime::CustomCall::UserData user_data;
user_data.insert_all(
run_options, &jitrt_executable->debug_options(),
&jitrt_executable->kernels_cache(),
@@ -678,9 +678,9 @@
opts.custom_call_data = &user_data;
// Collect all emitted diagnostic messages.
- jitrt::DiagnosticEngine diagnostic_engine;
+ runtime::DiagnosticEngine diagnostic_engine;
std::string diagnostic;
- diagnostic_engine.AddHandler([&](jitrt::Diagnostic& d) {
+ diagnostic_engine.AddHandler([&](runtime::Diagnostic& d) {
llvm::raw_string_ostream(diagnostic) << d.str();
return mlir::success();
});
@@ -689,7 +689,7 @@
// Get the default executable. We do not support specialization because
// all shapes are static. Default executable is guaranteed to be available.
- jitrt::Executable& executable = jitrt_executable->default_executable();
+ runtime::Executable& executable = jitrt_executable->default_executable();
// Execute with the prepared call frame.
executable.Execute(call_frame, opts);
@@ -1095,24 +1095,24 @@
auto buffer = llvm::MemoryBuffer::getMemBuffer(data, hlo_module->name());
// Create a JitRt function signature (all arguments passed as 1d memrefs).
- llvm::SmallVector<std::unique_ptr<jitrt::Type>> args;
- llvm::SmallVector<std::unique_ptr<jitrt::Type>> rt_args;
- rt_args.push_back(std::make_unique<jitrt::KernelContextOperandType>());
+ llvm::SmallVector<std::unique_ptr<runtime::Type>> args;
+ llvm::SmallVector<std::unique_ptr<runtime::Type>> rt_args;
+ rt_args.push_back(std::make_unique<runtime::KernelContextOperandType>());
for (int64_t size : buffer_sizes) {
auto i8 = tfrt::DType::I8;
- args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
- rt_args.push_back(std::make_unique<jitrt::MemrefType>(size, i8));
+ args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
+ rt_args.push_back(std::make_unique<runtime::MemrefType>(size, i8));
}
- jitrt::FunctionType signature(std::move(args), /*results=*/{});
- jitrt::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
+ runtime::FunctionType signature(std::move(args), /*results=*/{});
+ runtime::FunctionType rt_signature(std::move(rt_args), /*results=*/{});
- auto symbol_map = jitrt::GetSymbolsBinding(JitRtGpuCustomCalls());
+ auto symbol_map = runtime::GetSymbolsBinding(JitRtGpuCustomCalls());
// Load JitRt executable from an object file, and link it with Gpu runtime
// intrinsics implementing Gpu custom calls.
- auto executable = jitrt::Executable::LoadFromObjFile(
+ auto executable = runtime::Executable::LoadFromObjFile(
hlo_module->name(), std::move(buffer),
hlo_module->entry_computation()->name(), std::move(signature),
std::move(rt_signature), symbol_map);
diff --git a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
index 352fcc5..0a3c72f 100644
--- a/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
+++ b/tensorflow/compiler/xla/service/gpu/jitrt_custom_calls.cc
@@ -23,8 +23,12 @@
#include "llvm/ExecutionEngine/Orc/Mangling.h"
#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "tensorflow/compiler/xla/runtime/arguments.h"
#include "tensorflow/compiler/xla/runtime/custom_call.h"
+#include "tensorflow/compiler/xla/runtime/executable.h"
+#include "tensorflow/compiler/xla/runtime/jit_executable.h"
#include "tensorflow/compiler/xla/runtime/type_id.h"
+#include "tensorflow/compiler/xla/runtime/types.h"
#include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
@@ -45,7 +49,6 @@
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
-#include "tfrt/jitrt/jitrt.h" // from @tf_runtime
#include "tfrt/dtype/dtype.h" // from @tf_runtime
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -53,17 +56,17 @@
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
xla::gpu::JitRtKernelsCache);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
xla::gpu::JitRtGemmConfigCache);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
xla::gpu::JitRtCollectiveSupport);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
xla::gpu::JitRtAsyncCollectiveSupport);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
const xla::ServiceExecutableRunOptions);
-TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(tfrt::jitrt::CustomCall,
+TFRT_DEFINE_EXPLICIT_DENSE_TYPE_ID(xla::runtime::CustomCall,
const xla::DebugOptions);
namespace xla {
@@ -83,12 +86,12 @@
using mlir::success;
using tfrt::MakeStringError;
-using tfrt::jitrt::CustomCall;
-using tfrt::jitrt::DirectCustomCallLibrary;
-using tfrt::jitrt::Executable;
+
+using ::xla::runtime::CustomCall;
+using ::xla::runtime::DirectCustomCallLibrary;
+using ::xla::runtime::Executable;
namespace se = ::stream_executor;
-namespace jitrt = ::tfrt::jitrt;
namespace lmhlo_gpu = ::mlir::lmhlo_gpu;
namespace mhlo = ::mlir::mhlo;
@@ -122,24 +125,23 @@
// Add custom call arguments and attributes encoding for custom HLO enums and
// structs, so that we can pass them to custom calls.
void PopulateLmhloToXlaAttrEncoding(
- jitrt::CustomCallAttrEncodingSet& encoding) {
- encoding.Add<
- jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr, lmhlo_gpu::Activation,
- se::dnn::ActivationMode>>(
+ tfrt::jitrt::CustomCallAttrEncodingSet& encoding) {
+ encoding.Add<tfrt::jitrt::EnumAttrEncoding<lmhlo_gpu::ActivationAttr,
+ lmhlo_gpu::Activation,
+ se::dnn::ActivationMode>>(
[](lmhlo_gpu::Activation value) -> se::dnn::ActivationMode {
return ConvertConvActivationMode(value).value();
});
- encoding.Add<jitrt::EnumAttrEncoding<lmhlo_gpu::CublasLtMatmulEpilogueAttr,
- lmhlo_gpu::CublasLtMatmulEpilogue,
- se::cuda::BlasLt::Epilogue>>(
- [](lmhlo_gpu::CublasLtMatmulEpilogue value)
- -> se::cuda::BlasLt::Epilogue {
- return cublas_lt::AsBlasLtEpilogue(value).value();
- });
+ encoding.Add<tfrt::jitrt::EnumAttrEncoding<
+ lmhlo_gpu::CublasLtMatmulEpilogueAttr, lmhlo_gpu::CublasLtMatmulEpilogue,
+ se::cuda::BlasLt::Epilogue>>([](lmhlo_gpu::CublasLtMatmulEpilogue value)
+ -> se::cuda::BlasLt::Epilogue {
+ return cublas_lt::AsBlasLtEpilogue(value).value();
+ });
- encoding.Add<
- jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType, se::fft::Type>>(
+ encoding.Add<tfrt::jitrt::EnumAttrEncoding<mhlo::FftTypeAttr, mhlo::FftType,
+ se::fft::Type>>(
[](mhlo::FftType value) -> se::fft::Type {
switch (value) {
case mhlo::FftType::FFT:
@@ -156,9 +158,10 @@
});
using DotDimsAttr = mhlo::DotDimensionNumbersAttr;
- encoding.Add<jitrt::AggregateAttrEncoding<DotDimsAttr, DotDimensionNumbers>>(
+ encoding.Add<
+ tfrt::jitrt::AggregateAttrEncoding<DotDimsAttr, DotDimensionNumbers>>(
encoding,
- jitrt::AggregateAttrDef<DotDimsAttr>()
+ tfrt::jitrt::AggregateAttrDef<DotDimsAttr>()
.Add("lhs_batch", &DotDimsAttr::getLhsBatchingDimensions)
.Add("lhs_contract", &DotDimsAttr::getLhsContractingDimensions)
.Add("rhs_batch", &DotDimsAttr::getRhsBatchingDimensions)
@@ -166,9 +169,9 @@
using ConvDimsAttr = mhlo::ConvDimensionNumbersAttr;
encoding.Add<
- jitrt::AggregateAttrEncoding<ConvDimsAttr, ConvDimensionNumbers>>(
+ tfrt::jitrt::AggregateAttrEncoding<ConvDimsAttr, ConvDimensionNumbers>>(
encoding,
- jitrt::AggregateAttrDef<ConvDimsAttr>()
+ tfrt::jitrt::AggregateAttrDef<ConvDimsAttr>()
.Add("input_batch_dim", &ConvDimsAttr::getInputBatchDimension)
.Add("input_feature_dim", &ConvDimsAttr::getInputFeatureDimension)
.Add("input_spatial_dims", &ConvDimsAttr::getInputSpatialDimensions)
@@ -183,9 +186,10 @@
&ConvDimsAttr::getOutputSpatialDimensions));
using ConvConfigAttr = lmhlo_gpu::ConvolutionBackendConfigAttr;
- encoding.Add<jitrt::AggregateAttrEncoding<ConvConfigAttr, ConvBackendConfig>>(
+ encoding.Add<
+ tfrt::jitrt::AggregateAttrEncoding<ConvConfigAttr, ConvBackendConfig>>(
encoding,
- jitrt::AggregateAttrDef<ConvConfigAttr>()
+ tfrt::jitrt::AggregateAttrDef<ConvConfigAttr>()
.Add("algorithm", &ConvConfigAttr::getAlgorithm)
.Add("tensor_ops_enabled", &ConvConfigAttr::getTensorOpsEnabled)
.Add("is_cudnn_frontend", &ConvConfigAttr::getIsCudnnFrontend)
@@ -230,7 +234,7 @@
return se::DeviceMemoryBase(memref.data, size);
}
-static se::DeviceMemoryBase GetDeviceAddress(jitrt::FlatMemrefView& memref) {
+static se::DeviceMemoryBase GetDeviceAddress(runtime::FlatMemrefView& memref) {
return se::DeviceMemoryBase(memref.data, memref.size_in_bytes);
}
@@ -295,7 +299,7 @@
// -------------------------------------------------------------------------- //
-static Shape ToShape(const jitrt::StridedMemrefView& memref) {
+static Shape ToShape(const runtime::StridedMemrefView& memref) {
PrimitiveType type = TfrtToPrimitiveType(memref.dtype);
// Recover `minor_to_major` dimensions permutation from strides.
@@ -314,12 +318,15 @@
return ShapeUtil::MakeShapeWithLayout(type, memref.sizes, minor_to_major);
}
-static StatusOr<GemmConfig> GetGemmConfig(
- const jitrt::StridedMemrefView& lhs, const jitrt::StridedMemrefView& rhs,
- const jitrt::StridedMemrefView& out, int64_t algorithm, double alpha_real,
- double alpha_imag, double beta, ArrayRef<int64_t> lhs_batch,
- ArrayRef<int64_t> lhs_contract, ArrayRef<int64_t> rhs_batch,
- ArrayRef<int64_t> rhs_contract) {
+static StatusOr<GemmConfig> GetGemmConfig(const runtime::StridedMemrefView& lhs,
+ const runtime::StridedMemrefView& rhs,
+ const runtime::StridedMemrefView& out,
+ int64_t algorithm, double alpha_real,
+ double alpha_imag, double beta,
+ ArrayRef<int64_t> lhs_batch,
+ ArrayRef<int64_t> lhs_contract,
+ ArrayRef<int64_t> rhs_batch,
+ ArrayRef<int64_t> rhs_contract) {
return GemmConfig::For(ToShape(lhs), lhs_batch, lhs_contract, ToShape(rhs),
rhs_batch, rhs_contract, ToShape(out), alpha_real,
alpha_imag, beta, algorithm,
@@ -362,8 +369,8 @@
std::vector<DeviceBufferPair> device_buffers;
device_buffers.reserve(buffer_pairs);
for (int i = 0; i < buffer_pairs; ++i) {
- auto source = args.get<jitrt::StridedMemrefView>(i);
- auto destination = args.get<jitrt::StridedMemrefView>(i + buffer_pairs);
+ auto source = args.get<runtime::StridedMemrefView>(i);
+ auto destination = args.get<runtime::StridedMemrefView>(i + buffer_pairs);
if (failed(source) || failed(destination)) {
// Unsupported argument type.
return failure();
@@ -437,7 +444,7 @@
// Add MemRef arguments as buffer arguments.
for (unsigned i = 0; i < args.size(); ++i) {
// Simple row major memref passed as shapeless buffer.
- auto memref = args.get<jitrt::FlatMemrefView>(i);
+ auto memref = args.get<runtime::FlatMemrefView>(i);
if (succeeded(memref)) {
buffer_args.emplace_back(GetDeviceAddress(*memref));
continue;
@@ -445,7 +452,7 @@
// Memref layout must be encoded in the compiled device kernel, so we don't
// have to pass strides or minor to major dimensions order to the kernel.
- auto strided = args.get<jitrt::StridedMemrefView>(i);
+ auto strided = args.get<runtime::StridedMemrefView>(i);
if (succeeded(strided)) {
buffer_args.emplace_back(GetDeviceAddress(*strided));
continue;
@@ -488,11 +495,12 @@
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
- JitRtGemmConfigCache* configs, jitrt::StridedMemrefView lhs,
- jitrt::StridedMemrefView rhs, jitrt::StridedMemrefView out,
- int64_t algorithm, double alpha_real, double alpha_imag,
- double beta, DotDimensionNumbers dot_dims,
- int64_t uid) const;
+ JitRtGemmConfigCache* configs,
+ runtime::StridedMemrefView lhs,
+ runtime::StridedMemrefView rhs,
+ runtime::StridedMemrefView out, int64_t algorithm,
+ double alpha_real, double alpha_imag, double beta,
+ DotDimensionNumbers dot_dims, int64_t uid) const;
static Gemm Handler() { return Gemm(); }
};
@@ -501,9 +509,9 @@
Error Gemm::operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
JitRtGemmConfigCache* configs,
- jitrt::StridedMemrefView lhs,
- jitrt::StridedMemrefView rhs,
- jitrt::StridedMemrefView out, int64_t algorithm,
+ runtime::StridedMemrefView lhs,
+ runtime::StridedMemrefView rhs,
+ runtime::StridedMemrefView out, int64_t algorithm,
double alpha_real, double alpha_imag, double beta,
DotDimensionNumbers dot_dims, int64_t uid) const {
se::DeviceMemoryBase lhs_data = GetDeviceAddress(lhs);
@@ -537,9 +545,9 @@
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
.UserData<JitRtGemmConfigCache*>()
- .Arg<jitrt::StridedMemrefView>() // lhs
- .Arg<jitrt::StridedMemrefView>() // rhs
- .Arg<jitrt::StridedMemrefView>() // out
+ .Arg<runtime::StridedMemrefView>() // lhs
+ .Arg<runtime::StridedMemrefView>() // rhs
+ .Arg<runtime::StridedMemrefView>() // out
.Attr<int64_t>("algorithm")
.Attr<double>("alpha_real")
.Attr<double>("alpha_imag")
@@ -561,9 +569,9 @@
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
- jitrt::StridedMemrefView a, jitrt::StridedMemrefView b,
- jitrt::StridedMemrefView c, jitrt::StridedMemrefView d,
- Optional<jitrt::StridedMemrefView> bias, int64_t algorithm,
+ runtime::StridedMemrefView a, runtime::StridedMemrefView b,
+ runtime::StridedMemrefView c, runtime::StridedMemrefView d,
+ Optional<runtime::StridedMemrefView> bias, int64_t algorithm,
double alpha_real, double alpha_imag, double beta,
DotDimensionNumbers dot_dims,
se::cuda::BlasLt::Epilogue epilogue,
@@ -575,9 +583,9 @@
Error CublasLtMatmul::operator()(
const ServiceExecutableRunOptions* run_options,
- const DebugOptions* debug_options, jitrt::StridedMemrefView a,
- jitrt::StridedMemrefView b, jitrt::StridedMemrefView c,
- jitrt::StridedMemrefView d, Optional<jitrt::StridedMemrefView> bias,
+ const DebugOptions* debug_options, runtime::StridedMemrefView a,
+ runtime::StridedMemrefView b, runtime::StridedMemrefView c,
+ runtime::StridedMemrefView d, Optional<runtime::StridedMemrefView> bias,
int64_t algorithm, double alpha_real, double alpha_imag, double beta,
DotDimensionNumbers dot_dims, se::cuda::BlasLt::Epilogue epilogue,
ArrayRef<int32_t> precision, int64_t uid) const {
@@ -616,7 +624,7 @@
// Adds custom call bindings for matmul operations.
template <typename... Ts>
-static auto BindMatmulAttributes(jitrt::CustomCallBinding<Ts...> binding) {
+static auto BindMatmulAttributes(runtime::CustomCallBinding<Ts...> binding) {
return std::move(binding)
.template Attr<int64_t>("algorithm")
.template Attr<double>("alpha_real")
@@ -634,11 +642,11 @@
BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::StridedMemrefView>() // a
- .Arg<jitrt::StridedMemrefView>() // b
- .Arg<jitrt::StridedMemrefView>() // c
- .Arg<jitrt::StridedMemrefView>() // d
- .Value(CustomCall::None) // bias
+ .Arg<runtime::StridedMemrefView>() // a
+ .Arg<runtime::StridedMemrefView>() // b
+ .Arg<runtime::StridedMemrefView>() // c
+ .Arg<runtime::StridedMemrefView>() // d
+ .Value(CustomCall::None) // bias
)
.To<RuntimeChecks()>(CublasLtMatmul::Handler())
.release();
@@ -652,11 +660,11 @@
BindMatmulAttributes(CustomCall::Bind("xla.gpu.cublas.lt.matmul.bias")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::StridedMemrefView>() // a
- .Arg<jitrt::StridedMemrefView>() // b
- .Arg<jitrt::StridedMemrefView>() // c
- .Arg<jitrt::StridedMemrefView>() // d
- .Arg<jitrt::StridedMemrefView>() // bias
+ .Arg<runtime::StridedMemrefView>() // a
+ .Arg<runtime::StridedMemrefView>() // b
+ .Arg<runtime::StridedMemrefView>() // c
+ .Arg<runtime::StridedMemrefView>() // d
+ .Arg<runtime::StridedMemrefView>() // bias
)
.To<RuntimeChecks()>(CublasLtMatmul::Handler())
.release();
@@ -699,8 +707,8 @@
static GpuConvDescriptor GetConvDescriptor(
CudnnConvKind kind,
// Arguments
- jitrt::StridedMemrefView operand0, jitrt::StridedMemrefView operand1,
- jitrt::StridedMemrefView output, jitrt::FlatMemrefView scratch,
+ runtime::StridedMemrefView operand0, runtime::StridedMemrefView operand1,
+ runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
// Attributes
ConvDimensionNumbers dims, Window w, ConvBackendConfig b, ConvAttrs attrs,
// Conv-specific arguments and attributes
@@ -711,7 +719,7 @@
descriptor.kind = kind;
// Apply backend config layout to the shape.
- auto apply_layout = [](jitrt::StridedMemrefView& memref,
+ auto apply_layout = [](runtime::StridedMemrefView& memref,
ArrayRef<int64_t> minor_to_major) {
Shape shape = ToShape(memref);
return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
@@ -790,10 +798,11 @@
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(
const ServiceExecutableRunOptions* run_options,
- const DebugOptions* debug_options, jitrt::StridedMemrefView operand0,
- jitrt::StridedMemrefView operand1, Optional<jitrt::FlatMemrefView> bias,
- Optional<jitrt::StridedMemrefView> side_input,
- jitrt::StridedMemrefView output, jitrt::FlatMemrefView scratch,
+ const DebugOptions* debug_options, runtime::StridedMemrefView operand0,
+ runtime::StridedMemrefView operand1,
+ Optional<runtime::FlatMemrefView> bias,
+ Optional<runtime::StridedMemrefView> side_input,
+ runtime::StridedMemrefView output, runtime::FlatMemrefView scratch,
ConvDimensionNumbers conv_dims,
// Window config
ArrayRef<int64_t> window_strides, ArrayRef<int64_t> padding,
@@ -859,7 +868,7 @@
// Adds custom call bindings for convolution operations.
template <typename... Ts>
-static auto BindConvAttributes(jitrt::CustomCallBinding<Ts...> binding) {
+static auto BindConvAttributes(runtime::CustomCallBinding<Ts...> binding) {
return std::move(binding)
// Convolution dimensions numbers
.template Attr<ConvDimensionNumbers>("conv_dims")
@@ -882,12 +891,12 @@
BindConvAttributes(CustomCall::Bind("xla.gpu.conv")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::StridedMemrefView>() // operand0
- .Arg<jitrt::StridedMemrefView>() // operand1
- .Value(CustomCall::None) // bias
- .Value(CustomCall::None) // side_input
- .Arg<jitrt::StridedMemrefView>() // output
- .Arg<jitrt::FlatMemrefView>() // scratch
+ .Arg<runtime::StridedMemrefView>() // operand0
+ .Arg<runtime::StridedMemrefView>() // operand1
+ .Value(CustomCall::None) // bias
+ .Value(CustomCall::None) // side_input
+ .Arg<runtime::StridedMemrefView>() // output
+ .Arg<runtime::FlatMemrefView>() // scratch
)
.To(Conv::Handler(kind))
.release();
@@ -902,12 +911,12 @@
BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::StridedMemrefView>() // operand0
- .Arg<jitrt::StridedMemrefView>() // operand1
- .Arg<jitrt::FlatMemrefView>() // bias
- .Value(CustomCall::None) // side_input
- .Arg<jitrt::StridedMemrefView>() // output
- .Arg<jitrt::FlatMemrefView>() // scratch
+ .Arg<runtime::StridedMemrefView>() // operand0
+ .Arg<runtime::StridedMemrefView>() // operand1
+ .Arg<runtime::FlatMemrefView>() // bias
+ .Value(CustomCall::None) // side_input
+ .Arg<runtime::StridedMemrefView>() // output
+ .Arg<runtime::FlatMemrefView>() // scratch
)
.Attr<se::dnn::ActivationMode>("activation_mode")
.To(Conv::Handler(kind))
@@ -923,12 +932,12 @@
BindConvAttributes(CustomCall::Bind("xla.gpu.conv.fused.side_input")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::StridedMemrefView>() // operand0
- .Arg<jitrt::StridedMemrefView>() // operand1
- .Arg<jitrt::FlatMemrefView>() // bias
- .Arg<jitrt::StridedMemrefView>() // side_input
- .Arg<jitrt::StridedMemrefView>() // output
- .Arg<jitrt::FlatMemrefView>() // scratch
+ .Arg<runtime::StridedMemrefView>() // operand0
+ .Arg<runtime::StridedMemrefView>() // operand1
+ .Arg<runtime::FlatMemrefView>() // bias
+ .Arg<runtime::StridedMemrefView>() // side_input
+ .Arg<runtime::StridedMemrefView>() // output
+ .Arg<runtime::FlatMemrefView>() // scratch
)
.Attr<se::dnn::ActivationMode>("activation_mode")
.Attr<double>("side_input_scale")
@@ -964,7 +973,7 @@
size_t index = 0;
for (auto& source : source_buffers.leaves()) {
// Get the destination buffer.
- auto dest = args.get<jitrt::StridedMemrefView>(index);
+ auto dest = args.get<runtime::StridedMemrefView>(index);
if (failed(dest))
return MakeStringError("Failed to get the destination buffer");
@@ -1038,7 +1047,7 @@
size_t index = 0;
for (auto& dest : dest_buffers->leaves()) {
// Get the source buffer.
- auto source = args.get<jitrt::StridedMemrefView>(index);
+ auto source = args.get<runtime::StridedMemrefView>(index);
if (failed(source))
return MakeStringError("Failed to get the source buffer");
@@ -1092,15 +1101,16 @@
template <MemcpyDirection direction>
struct Memcpy {
Error operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView dst, jitrt::FlatMemrefView src) const;
+ runtime::FlatMemrefView dst,
+ runtime::FlatMemrefView src) const;
static Memcpy Handler() { return Memcpy(); }
};
} // namespace
template <MemcpyDirection direction>
Error Memcpy<direction>::operator()(
- const ServiceExecutableRunOptions* run_options, jitrt::FlatMemrefView dst,
- jitrt::FlatMemrefView src) const {
+ const ServiceExecutableRunOptions* run_options, runtime::FlatMemrefView dst,
+ runtime::FlatMemrefView src) const {
se::Stream* stream = run_options->stream();
if (dst.size_in_bytes != src.size_in_bytes) {
@@ -1139,8 +1149,8 @@
static bool MemcpyFn(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
.UserData<const ServiceExecutableRunOptions*>()
- .Arg<jitrt::FlatMemrefView>() // dst
- .Arg<jitrt::FlatMemrefView>() // src
+ .Arg<runtime::FlatMemrefView>() // dst
+ .Arg<runtime::FlatMemrefView>() // src
.To<RuntimeChecks()>(Memcpy<direction>::Handler())
.release();
@@ -1153,7 +1163,7 @@
struct Memset {
Error operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView dst,
+ runtime::FlatMemrefView dst,
CustomCall::VariantArg constant) const;
static Memset Handler() { return Memset(); }
};
@@ -1161,7 +1171,7 @@
} // namespace
Error Memset::operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView dst,
+ runtime::FlatMemrefView dst,
CustomCall::VariantArg constant) const {
se::Stream* stream = run_options->stream();
se::DeviceMemoryBase dst_data = GetDeviceAddress(dst);
@@ -1205,8 +1215,8 @@
static bool MemsetFn(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.memset")
.UserData<const ServiceExecutableRunOptions*>()
- .Arg<jitrt::FlatMemrefView>() // dst
- .Arg<CustomCall::VariantArg>() // constant
+ .Arg<runtime::FlatMemrefView>() // dst
+ .Arg<CustomCall::VariantArg>() // constant
.To<RuntimeChecks()>(Memset::Handler())
.release();
@@ -1219,16 +1229,16 @@
struct Fft {
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::StridedMemrefView input,
- jitrt::StridedMemrefView output,
+ runtime::StridedMemrefView input,
+ runtime::StridedMemrefView output,
ArrayRef<int64_t> fft_length, se::fft::Type fft_type) const;
static Fft Handler() { return Fft(); }
};
} // namespace
Error Fft::operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::StridedMemrefView input,
- jitrt::StridedMemrefView output,
+ runtime::StridedMemrefView input,
+ runtime::StridedMemrefView output,
ArrayRef<int64_t> fft_length,
se::fft::Type fft_type) const {
// TODO(ezhulenev): Cache FFT plans in the GpuExecutable.
@@ -1270,13 +1280,12 @@
static bool Fft(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.fft")
.UserData<const ServiceExecutableRunOptions*>()
- .Arg<jitrt::StridedMemrefView>() // input
- .Arg<jitrt::StridedMemrefView>() // output
+ .Arg<runtime::StridedMemrefView>() // input
+ .Arg<runtime::StridedMemrefView>() // output
.Attr<ArrayRef<int64_t>>("fft_length")
.Attr<se::fft::Type>("fft_type")
.To<RuntimeChecks()>(Fft::Handler())
.release();
-
return succeeded(Executable::Call(ctx, *handler, args, attrs));
}
@@ -1286,19 +1295,20 @@
struct Cholesky {
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(const ServiceExecutableRunOptions* run_options,
- const DebugOptions* debug_options, jitrt::MemrefView operand,
- jitrt::MemrefView a, jitrt::MemrefView workspace,
- jitrt::MemrefView info, int64_t batch_size, bool is_lower,
- int64_t n) const;
+ const DebugOptions* debug_options,
+ runtime::MemrefView operand, runtime::MemrefView a,
+ runtime::MemrefView workspace, runtime::MemrefView info,
+ int64_t batch_size, bool is_lower, int64_t n) const;
static Cholesky Handler() { return Cholesky(); }
};
} // namespace
Error Cholesky::operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
- jitrt::MemrefView operand, jitrt::MemrefView a,
- jitrt::MemrefView workspace, jitrt::MemrefView info,
- int64_t batch_size, bool is_lower, int64_t n) const {
+ runtime::MemrefView operand, runtime::MemrefView a,
+ runtime::MemrefView workspace,
+ runtime::MemrefView info, int64_t batch_size,
+ bool is_lower, int64_t n) const {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
se::DeviceMemoryBase operand_buffer = GetDeviceAddress(operand);
se::DeviceMemoryBase a_buffer = GetDeviceAddress(a);
@@ -1332,10 +1342,10 @@
static auto* handler = CustomCall::Bind("xla.gpu.cholesky")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::MemrefView>() // operand
- .Arg<jitrt::MemrefView>() // a
- .Arg<jitrt::MemrefView>() // workspace
- .Arg<jitrt::MemrefView>() // info
+ .Arg<runtime::MemrefView>() // operand
+ .Arg<runtime::MemrefView>() // a
+ .Arg<runtime::MemrefView>() // workspace
+ .Arg<runtime::MemrefView>() // info
.Attr<int64_t>("batch_size")
.Attr<bool>("is_lower")
.Attr<int64_t>("n")
@@ -1362,9 +1372,10 @@
Error operator()(const ServiceExecutableRunOptions* run_options,
const DebugOptions* debug_options,
- jitrt::StridedMemrefView a, jitrt::StridedMemrefView b,
- jitrt::StridedMemrefView result, jitrt::FlatMemrefView temp,
- bool left_side, bool lower, bool unit_diagonal,
+ runtime::StridedMemrefView a, runtime::StridedMemrefView b,
+ runtime::StridedMemrefView result,
+ runtime::FlatMemrefView temp, bool left_side, bool lower,
+ bool unit_diagonal,
TriangularSolveOptions::Transpose transpose_a) const;
static TriangularSolve Handler() { return TriangularSolve(); }
};
@@ -1381,10 +1392,10 @@
return MakeStringError("Expected 4 arguments, got %n", args.size());
// Check if all arguments have the correct type.
- auto a = args.get<jitrt::StridedMemrefView>(0);
- auto b = args.get<jitrt::StridedMemrefView>(1);
- auto result = args.get<jitrt::StridedMemrefView>(2);
- auto temp = args.get<jitrt::FlatMemrefView>(3);
+ auto a = args.get<runtime::StridedMemrefView>(0);
+ auto b = args.get<runtime::StridedMemrefView>(1);
+ auto result = args.get<runtime::StridedMemrefView>(2);
+ auto temp = args.get<runtime::FlatMemrefView>(3);
if (failed(a) || failed(b) || failed(result) || failed(temp))
return MakeStringError("Incorrect argument types");
@@ -1400,10 +1411,10 @@
Error TriangularSolve::operator()(
const ServiceExecutableRunOptions* run_options,
- const DebugOptions* debug_options, jitrt::StridedMemrefView a,
- jitrt::StridedMemrefView b, jitrt::StridedMemrefView result,
- jitrt::FlatMemrefView temp, bool left_side, bool lower, bool unit_diagonal,
- TriangularSolveOptions::Transpose transpose_a) const {
+ const DebugOptions* debug_options, runtime::StridedMemrefView a,
+ runtime::StridedMemrefView b, runtime::StridedMemrefView result,
+ runtime::FlatMemrefView temp, bool left_side, bool lower,
+ bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a) const {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
se::Stream* stream = run_options->stream();
@@ -1510,11 +1521,11 @@
for (unsigned i = 0; i < args.size(); ++i) {
// We use zero-sized memrefs to represent holes in custom calls with target
// arguments mapping (see `CustomCallTargetArgMapping`).
- if (auto memref = args.get<jitrt::FlatMemrefView>(i); succeeded(memref)) {
+ if (auto memref = args.get<runtime::FlatMemrefView>(i); succeeded(memref)) {
buffers.push_back(memref->size_in_bytes == 0 ? nullptr : memref->data);
continue;
}
- if (auto strided = args.get<jitrt::StridedMemrefView>(i);
+ if (auto strided = args.get<runtime::StridedMemrefView>(i);
succeeded(strided)) {
int64_t size_in_bytes = GetHostSize(strided->dtype);
for (int64_t size : strided->sizes) size_in_bytes *= size;
@@ -1561,7 +1572,7 @@
static auto* handler = CustomCall::Bind("xla.gpu.memcpy")
.UserData<const ServiceExecutableRunOptions*>()
.UserData<const DebugOptions*>()
- .Arg<jitrt::CustomCall::RemainingArgs>() // args
+ .Arg<CustomCall::RemainingArgs>() // args
.Attr<StringRef>("call_target_name")
.Attr<int32_t>("api_version")
.Attr<StringRef>("backend_config")
@@ -2072,13 +2083,13 @@
struct ReplicaId {
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView result) const;
+ runtime::FlatMemrefView result) const;
static ReplicaId Handler() { return ReplicaId(); }
};
} // namespace
Error ReplicaId::operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView result) const {
+ runtime::FlatMemrefView result) const {
VLOG(3) << "Running ReplicaId";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
@@ -2100,7 +2111,7 @@
static bool ReplicaId(runtime::KernelContext* ctx, void** args, void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.replica_id")
.UserData<const ServiceExecutableRunOptions*>()
- .Arg<jitrt::FlatMemrefView>() // result
+ .Arg<runtime::FlatMemrefView>() // result
.To<RuntimeChecks()>(ReplicaId::Handler())
.release();
@@ -2113,13 +2124,13 @@
struct PartitionId {
LLVM_ATTRIBUTE_ALWAYS_INLINE
Error operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView result) const;
+ runtime::FlatMemrefView result) const;
static PartitionId Handler() { return PartitionId(); }
};
} // namespace
Error PartitionId::operator()(const ServiceExecutableRunOptions* run_options,
- jitrt::FlatMemrefView result) const {
+ runtime::FlatMemrefView result) const {
VLOG(3) << "Running PartitionId";
se::Stream* stream = run_options->stream();
NcclExecuteParams params(*run_options, stream);
@@ -2142,7 +2153,7 @@
void** attrs) {
static auto* handler = CustomCall::Bind("xla.gpu.partition_id")
.UserData<const ServiceExecutableRunOptions*>()
- .Arg<jitrt::FlatMemrefView>() // result
+ .Arg<runtime::FlatMemrefView>() // result
.To<RuntimeChecks()>(PartitionId::Handler())
.release();