| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <cstdint> |
| #include <string> |
| #include <vector> |
| |
| #include "absl/base/casts.h" |
| #include "absl/hash/hash.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/synchronization/mutex.h" |
| #include "absl/types/optional.h" |
| #include "absl/types/span.h" |
| #include "include/pybind11/numpy.h" |
| #include "include/pybind11/pybind11.h" |
| #include "tensorflow/compiler/xla/client/client_library.h" |
| #include "tensorflow/compiler/xla/client/lib/comparators.h" |
| #include "tensorflow/compiler/xla/client/lib/math.h" |
| #include "tensorflow/compiler/xla/client/lib/qr.h" |
| #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" |
| #include "tensorflow/compiler/xla/client/lib/svd.h" |
| #include "tensorflow/compiler/xla/client/local_client.h" |
| #include "tensorflow/compiler/xla/client/xla_builder.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/python/local_client.h" |
| #include "tensorflow/compiler/xla/python/python_ref_manager.h" |
| #include "tensorflow/compiler/xla/python/types.h" |
| #include "tensorflow/compiler/xla/python/xrt.h" |
| #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" |
| #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" |
| #include "tensorflow/compiler/xla/service/hlo_instruction.h" |
| #include "tensorflow/compiler/xla/service/hlo_module.h" |
| #include "tensorflow/compiler/xla/service/hlo_parser.h" |
| #include "tensorflow/compiler/xla/service/name_uniquer.h" |
| #include "tensorflow/compiler/xla/service/platform_util.h" |
| #include "tensorflow/compiler/xla/shape.h" |
| #include "tensorflow/compiler/xla/shape_util.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| |
| namespace xla { |
| |
| namespace py = pybind11; |
| |
| namespace { |
| |
| struct Uniquer { |
| absl::Mutex mu; |
| NameUniquer name_uniquer GUARDED_BY(mu); |
| }; |
| |
| Uniquer* GetUniquer() { |
| static Uniquer* uniquer = new Uniquer; |
| return uniquer; |
| } |
| |
| static std::string UniquifyName(const std::string& name) { |
| Uniquer* uniquer = GetUniquer(); |
| absl::MutexLock lock(&uniquer->mu); |
| return uniquer->name_uniquer.GetUniqueName(name); |
| } |
| |
| // Converts a computation to a serialized HloModuleProto. |
| StatusOr<py::bytes> GetComputationSerializedProto( |
| const XlaComputation& computation) { |
| std::string result; |
| if (!computation.proto().SerializeToString(&result)) { |
| return Unknown("Failed to serialize the HloModuleProto."); |
| } |
| return py::bytes(result); |
| } |
| |
| // Converts a computation to textual HLO form. |
| StatusOr<std::string> GetComputationHloText(const XlaComputation& computation) { |
| TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, |
| HloModule::CreateModuleConfigFromProto( |
| computation.proto(), GetDebugOptionsFromFlags())); |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<HloModule> hlo_module, |
| HloModule::CreateFromProto(computation.proto(), module_config)); |
| HloPrintOptions options; |
| options = HloPrintOptions::ShortParsable(); |
| options.set_print_large_constants(false); |
| return hlo_module->ToString(options); |
| } |
| |
| // Converts a computation to HLO dot graph form. |
| StatusOr<std::string> GetComputationHloDotGraph( |
| const XlaComputation& computation) { |
| TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, |
| HloModule::CreateModuleConfigFromProto( |
| computation.proto(), GetDebugOptionsFromFlags())); |
| TF_ASSIGN_OR_RETURN( |
| std::unique_ptr<HloModule> hlo_module, |
| HloModule::CreateFromProto(computation.proto(), module_config)); |
| return RenderGraph(*hlo_module->entry_computation(), /*label=*/"", |
| hlo_module->config().debug_options(), |
| RenderedGraphFormat::kDot); |
| } |
| |
| // Registers a 'fn_capsule' as a CPU custom call target. |
| // 'fn_capsule' must be a void* pointer encapsulated in a PyCapsule object, |
| // with name "xla._CUSTOM_CALL_TARGET". |
| // 'platform' is an XLA platform name, e.g., "Host" or "CUDA". |
| Status PyRegisterCustomCallTarget(const std::string& fn_name, |
| py::capsule capsule, |
| const std::string& platform) { |
| static const char* const kName = "xla._CUSTOM_CALL_TARGET"; |
| // TODO(phawkins): remove old name after fixing users. |
| static const char* const kOldCpuName = "xla._CPU_CUSTOM_CALL_TARGET"; |
| if (absl::string_view(capsule.name()) != kName && |
| absl::string_view(capsule.name()) != kOldCpuName) { |
| return InvalidArgument( |
| "Argument to RegisterCustomCallTargetRegistry was not a " |
| "xla._CUSTOM_CALL_TARGET capsule."); |
| } |
| CustomCallTargetRegistry::Global()->Register( |
| fn_name, static_cast<void*>(capsule), platform); |
| return Status::OK(); |
| } |
| |
| } // namespace |
| |
| PYBIND11_MODULE(xla_extension, m) { |
| // Types |
| py::enum_<PrimitiveType>(m, "PrimitiveType") |
| .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID) |
| .value("PRED", PRED) |
| .value("S8", S8) |
| .value("S16", S16) |
| .value("S32", S32) |
| .value("S64", S64) |
| .value("U8", U8) |
| .value("U16", U16) |
| .value("U32", U32) |
| .value("U64", U64) |
| .value("F16", F16) |
| .value("BF16", BF16) |
| .value("F32", F32) |
| .value("F64", F64) |
| .value("C64", C64) |
| .value("C128", C128) |
| .value("TUPLE", TUPLE) |
| .value("OPAQUE_TYPE", OPAQUE_TYPE) |
| .value("TOKEN", TOKEN); |
| |
| // Shapes |
| py::class_<Shape> shape_class(m, "Shape"); |
| shape_class |
| .def(py::init([](const string& s) { |
| return absl::make_unique<Shape>(ValueOrThrow(ParseShape(s))); |
| })) |
| .def_static( |
| "tuple_shape", |
| [](std::vector<Shape> shapes) -> Shape { |
| return ShapeUtil::MakeTupleShape(shapes); |
| }, |
| "Constructs a tuple shape.") |
| .def_static( |
| "array_shape", |
| [](PrimitiveType type, py::object dims_seq, |
| absl::optional<py::object> layout_seq) -> Shape { |
| std::vector<int64> dims = IntSequenceToVector(dims_seq); |
| if (layout_seq) { |
| std::vector<int64> layout = IntSequenceToVector(*layout_seq); |
| return ShapeUtil::MakeShapeWithLayout(type, dims, layout); |
| } else { |
| Shape shape = ShapeUtil::MakeShape(type, dims); |
| shape.clear_layout(); |
| return shape; |
| } |
| }, |
| "Constructs an array shape.", py::arg("type"), py::arg("dims"), |
| py::arg("layout") = absl::nullopt) |
| .def_static( |
| "array_shape", |
| [](py::dtype dtype, py::object dims_seq, |
| absl::optional<py::object> layout_seq) -> Shape { |
| PrimitiveType type = ValueOrThrow(DtypeToPrimitiveType(dtype)); |
| std::vector<int64> dims = IntSequenceToVector(dims_seq); |
| if (layout_seq) { |
| std::vector<int64> layout = IntSequenceToVector(*layout_seq); |
| return ShapeUtil::MakeShapeWithLayout(type, dims, layout); |
| } else { |
| Shape shape = ShapeUtil::MakeShape(type, dims); |
| shape.clear_layout(); |
| return shape; |
| } |
| }, |
| "Constructs an array shape.", py::arg("type"), py::arg("dims"), |
| py::arg("layout") = absl::nullopt) |
| .def("dimensions", |
| [](const Shape& shape) -> py::tuple { |
| return IntSpanToTuple(shape.dimensions()); |
| }) |
| .def("xla_element_type", &Shape::element_type) |
| .def("element_type", |
| [](const Shape& shape) { |
| return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type())); |
| }) |
| .def("numpy_dtype", |
| [](const Shape& shape) { |
| if (shape.IsTuple()) { |
| return py::dtype("O"); |
| } |
| return ValueOrThrow(PrimitiveTypeToDtype(shape.element_type())); |
| }) |
| .def("is_tuple", &Shape::IsTuple) |
| .def("is_array", &Shape::IsArray) |
| .def("rank", &Shape::rank) |
| .def("to_serialized_proto", |
| [](const Shape& shape) { |
| ShapeProto proto = shape.ToProto(); |
| return py::bytes(proto.SerializeAsString()); |
| }) |
| .def("tuple_shapes", |
| [](const Shape& shape) { |
| return std::vector<Shape>(shape.tuple_shapes()); |
| }) |
| .def( |
| "with_major_to_minor_layout_if_absent", |
| [](const Shape& shape) { |
| Shape out = shape; |
| ShapeUtil::ForEachMutableSubshape( |
| &out, [](Shape* subshape, const ShapeIndex&) { |
| if (!subshape->has_layout()) { |
| LayoutUtil::SetToDefaultLayout(subshape); |
| } |
| }); |
| return out; |
| }, |
| "Returns a copy of a shape with missing layouts set to " |
| "major-to-minor.") |
| .def("__eq__", [](const Shape& shape, |
| const Shape& other) { return shape == other; }) |
| .def("__ne__", [](const Shape& shape, |
| const Shape& other) { return shape != other; }) |
| .def("__hash__", |
| [](const Shape& shape) { return absl::Hash<Shape>()(shape); }) |
| .def("__repr__", [](const Shape& shape) { |
| return shape.ToString(/*print_layouts=*/true); |
| }); |
| |
| py::class_<ProgramShape>(m, "ProgramShape") |
| .def(py::init( |
| [](absl::Span<const Shape> params, Shape result) -> ProgramShape { |
| ProgramShape program_shape; |
| for (const Shape& param : params) { |
| *program_shape.add_parameters() = param; |
| } |
| *program_shape.mutable_result() = result; |
| return program_shape; |
| })) |
| .def("parameter_shapes", |
| static_cast<const std::vector<Shape>& (ProgramShape::*)() const>( |
| &ProgramShape::parameters)) |
| .def("result_shape", &ProgramShape::result) |
| .def("__repr__", &ProgramShape::ToString); |
| |
| // Literals |
| py::class_<Literal, std::shared_ptr<Literal>>(m, "Literal") |
| .def("__repr__", &Literal::ToString); |
| py::class_<LiteralSlice>(m, "LiteralSlice"); |
| py::implicitly_convertible<Literal, LiteralSlice>(); |
| py::implicitly_convertible<BorrowingLiteral, LiteralSlice>(); |
| |
| // Device assignments |
| py::class_<DeviceAssignment>(m, "DeviceAssignment") |
| .def_static("create", |
| [](py::array_t<int> array) -> StatusOr<DeviceAssignment> { |
| if (array.ndim() != 2) { |
| return InvalidArgument( |
| "Argument to DeviceAssignment constructor must be a " |
| "2D array, " |
| "received an %dD array.", |
| array.ndim()); |
| } |
| DeviceAssignment result(array.shape(0), array.shape(1)); |
| for (int i = 0; i < array.shape(0); ++i) { |
| for (int j = 0; j < array.shape(1); ++j) { |
| result(i, j) = array.at(i, j); |
| } |
| } |
| return result; |
| }) |
| .def("replica_count", &DeviceAssignment::replica_count) |
| .def("computation_count", &DeviceAssignment::computation_count) |
| .def("__repr__", &DeviceAssignment::ToString); |
| |
| // Local XLA client methods. |
| |
| // Custom-call targets. |
| m.def("RegisterCustomCallTarget", &PyRegisterCustomCallTarget); |
| |
| py::class_<AllocatorConfig> alloc_config(m, "AllocatorConfig"); |
| alloc_config.def(py::init<>()) |
| .def_readwrite("kind", &AllocatorConfig::kind) |
| .def_readwrite("memory_fraction", &AllocatorConfig::memory_fraction) |
| .def_readwrite("preallocate", &AllocatorConfig::preallocate); |
| py::enum_<AllocatorConfig::Kind>(alloc_config, "Kind") |
| .value("DEFAULT", AllocatorConfig::Kind::kDefault) |
| .value("PLATFORM", AllocatorConfig::Kind::kPlatform) |
| .value("BFC", AllocatorConfig::Kind::kBFC); |
| |
| py::class_<PyLocalClient, std::shared_ptr<PyLocalClient>>(m, "LocalClient") |
| .def_static("Get", &PyLocalClient::Get, py::arg("platform"), |
| py::arg("xla_platform_id"), py::arg("asynchronous"), |
| py::arg("allocator_config") = AllocatorConfig()) |
| .def("DeviceCount", &PyLocalClient::device_count) |
| .def("TransferToInfeed", |
| [](PyLocalClient* client, const LiteralSlice& literal, |
| int device_ordinal) { |
| GlobalPyRefManager()->CollectGarbage(); |
| py::gil_scoped_release gil_release; |
| return client->TransferToInfeed(literal, device_ordinal); |
| }) |
| .def("TransferFromOutfeed", |
| [](PyLocalClient* client, const Shape& shape, |
| int device_ordinal) -> StatusOr<py::object> { |
| GlobalPyRefManager()->CollectGarbage(); |
| std::shared_ptr<Literal> literal_shared; |
| { |
| py::gil_scoped_release gil_release; |
| TF_ASSIGN_OR_RETURN(Literal literal, client->TransferFromOutfeed( |
| shape, device_ordinal)); |
| literal_shared = std::make_shared<Literal>(std::move(literal)); |
| } |
| return LiteralToPython(std::move(literal_shared)); |
| }); |
| |
| py::class_<PyLocalBuffer>(m, "PyLocalBuffer") |
| .def_static( |
| "from_python", |
| [](const pybind11::object& argument, |
| std::shared_ptr<PyLocalClient> client, |
| int device_ordinal) -> StatusOr<std::unique_ptr<PyLocalBuffer>> { |
| GlobalPyRefManager()->CollectGarbage(); |
| TF_ASSIGN_OR_RETURN(PythonBufferTree tree, |
| GetPythonBufferTree(argument)); |
| std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref = |
| GlobalPyRefManager()->ManageReferences( |
| absl::MakeSpan(tree.arrays)); |
| tree.arrays.clear(); |
| |
| std::vector<BorrowingLiteral> leaves; |
| leaves.insert(leaves.end(), |
| std::make_move_iterator(tree.leaves.begin()), |
| std::make_move_iterator(tree.leaves.end())); |
| |
| py::gil_scoped_release gil_release; |
| return PyLocalBuffer::FromLiterals( |
| std::move(leaves), tree.shape, std::move(py_buffer_ref), |
| std::move(client), device_ordinal); |
| }) |
| .def_static("make_tuple", &PyLocalBuffer::MakeTuple) |
| .def("copy_to_device", |
| [](PyLocalBuffer* buffer, int dst_device_ordinal) { |
| GlobalPyRefManager()->CollectGarbage(); |
| py::gil_scoped_release gil_release; |
| return buffer->CopyToDevice(dst_device_ordinal); |
| }) |
| .def("delete", &PyLocalBuffer::Delete) |
| .def("destructure", &PyLocalBuffer::DestructureTuple) |
| .def("block_host_until_ready", |
| [](PyLocalBuffer* buffer) { |
| GlobalPyRefManager()->CollectGarbage(); |
| py::gil_scoped_release gil_release; |
| return buffer->BlockHostUntilReady(); |
| }) |
| .def("copy_to_host_async", &PyLocalBuffer::CopyToHostAsync) |
| .def("to_py", |
| [](PyLocalBuffer* buffer) -> StatusOr<py::object> { |
| GlobalPyRefManager()->CollectGarbage(); |
| std::shared_ptr<Literal> literal; |
| { |
| py::gil_scoped_release gil_release; |
| TF_ASSIGN_OR_RETURN(literal, buffer->ToLiteral()); |
| } |
| return LiteralToPython(std::move(literal)); |
| }) |
| .def("shape", &PyLocalBuffer::on_host_shape) |
| .def("device", &PyLocalBuffer::device_ordinal) |
| .def("is_deleted", |
| [](const PyLocalBuffer& buffer) { |
| return buffer.DeviceBuffer() == nullptr; |
| }) |
| .def("unsafe_buffer_pointer", |
| [](const PyLocalBuffer& buffer) -> StatusOr<std::uintptr_t> { |
| TF_ASSIGN_OR_RETURN(ShapedBuffer shaped_buffer, |
| buffer.AsShapedBuffer()); |
| if (shaped_buffer.on_device_shape().IsTuple()) { |
| return Unimplemented( |
| "unsafe_buffer_pointer is not implemented for tuple " |
| "buffers."); |
| } |
| return absl::bit_cast<std::uintptr_t>( |
| shaped_buffer.root_buffer().opaque()); |
| }); |
| |
| py::class_<PyLocalExecutable>(m, "LocalExecutable") |
| .def_static("Compile", &PyLocalExecutable::Compile, |
| py::call_guard<py::gil_scoped_release>()) |
| .def("DeviceOrdinals", &PyLocalExecutable::DeviceOrdinals) |
| .def("SizeOfGeneratedCodeInBytes", |
| &PyLocalExecutable::SizeOfGeneratedCodeInBytes) |
| .def("Delete", &PyLocalExecutable::Delete) |
| .def("Execute", &PyLocalExecutable::Execute, |
| py::call_guard<py::gil_scoped_release>(), py::arg("arguments")) |
| .def("ExecutePerReplica", &PyLocalExecutable::ExecutePerReplica, |
| py::call_guard<py::gil_scoped_release>(), py::arg("arguments")); |
| |
| py::class_<DebugOptions>(m, "DebugOptions") |
| .def_property("xla_cpu_enable_fast_math", |
| &DebugOptions::xla_cpu_enable_fast_math, |
| &DebugOptions::set_xla_cpu_enable_fast_math) |
| .def_property("xla_cpu_fast_math_honor_infs", |
| &DebugOptions::xla_cpu_fast_math_honor_infs, |
| &DebugOptions::set_xla_cpu_fast_math_honor_infs) |
| .def_property("xla_cpu_fast_math_honor_nans", |
| &DebugOptions::xla_cpu_fast_math_honor_nans, |
| &DebugOptions::set_xla_cpu_fast_math_honor_nans) |
| .def_property("xla_cpu_fast_math_honor_division", |
| &DebugOptions::xla_cpu_fast_math_honor_division, |
| &DebugOptions::set_xla_cpu_fast_math_honor_division) |
| .def_property("xla_gpu_enable_fast_min_max", |
| &DebugOptions::xla_gpu_enable_fast_min_max, |
| &DebugOptions::set_xla_gpu_enable_fast_min_max); |
| |
| py::class_<ExecutableBuildOptions>(m, "ExecutableBuildOptions") |
| .def(py::init<>()) |
| .def_property( |
| "result_layout", |
| [](const ExecutableBuildOptions& options) -> absl::optional<Shape> { |
| return options.result_layout() |
| ? absl::optional<Shape>(*options.result_layout()) |
| : absl::nullopt; |
| }, |
| &ExecutableBuildOptions::set_result_layout) |
| .def_property("num_replicas", &ExecutableBuildOptions::num_replicas, |
| &ExecutableBuildOptions::set_num_replicas) |
| .def_property_readonly( |
| "debug_options", &ExecutableBuildOptions::mutable_debug_options, |
| py::return_value_policy::reference, py::keep_alive<1, 0>()); |
| |
| py::class_<XlaComputation>(m, "XlaComputation") |
| .def("GetProgramShape", &XlaComputation::GetProgramShape) |
| .def("GetSerializedProto", &GetComputationSerializedProto) |
| .def("GetHloText", &GetComputationHloText) |
| .def("GetHloDotGraph", &GetComputationHloDotGraph); |
| |
| py::class_<XlaOp>(m, "XlaOp"); |
| |
| py::class_<XlaBuilder>(m, "XlaBuilder") |
| .def(py::init([](const std::string& name) -> std::unique_ptr<XlaBuilder> { |
| return absl::make_unique<XlaBuilder>(UniquifyName(name)); |
| })) |
| .def( |
| "Build", |
| [](XlaBuilder& builder, absl::optional<XlaOp> root) { |
| return root ? builder.Build(*root) : builder.Build(); |
| }, |
| "Builds a computation from the contents of the builder.", |
| py::arg("root") = absl::nullopt) |
| .def("ClearOpMetadata", &XlaBuilder::ClearOpMetadata) |
| .def("GetShape", &XlaBuilder::GetShape) |
| .def( |
| "GetProgramShape", |
| [](const XlaBuilder& builder, |
| absl::optional<XlaOp> root) -> StatusOr<ProgramShape> { |
| return root ? builder.GetProgramShape(*root) |
| : builder.GetProgramShape(); |
| }, |
| py::arg("root") = absl::nullopt) |
| .def("IsConstant", &XlaBuilder::IsConstant) |
| .def("SetOpMetadata", &XlaBuilder::SetOpMetadata); |
| |
| // ops submodule, containing free functions that add operators to an |
| // XlaBuilder. |
| py::module ops = m.def_submodule("ops", "XLA operations"); |
| |
| ops.def("AllReduce", |
| static_cast<XlaOp (*)( |
| XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>, |
| const absl::optional<ChannelHandle>&)>(&CrossReplicaSum)); |
| ops.def("AllToAll", &AllToAll); |
| ops.def("CollectivePermute", &CollectivePermute); |
| ops.def("CrossReplicaSum", |
| static_cast<XlaOp (*)(XlaOp, absl::Span<const ReplicaGroup>)>( |
| &CrossReplicaSum)); |
| ops.def("BitcastConvertType", &BitcastConvertType, py::arg("operand"), |
| py::arg("new_element_type")); |
| ops.def("Broadcast", &Broadcast, py::arg("operand"), py::arg("sizes")); |
| ops.def("BroadcastInDim", &BroadcastInDim, py::arg("operand"), |
| py::arg("shape"), py::arg("broadcast_dimensions")); |
| ops.def("Call", &Call); |
| ops.def("Cholesky", &Cholesky, py::arg("a"), py::arg("lower") = true); |
| ops.def("Clamp", &Clamp); |
| ops.def("Collapse", &Collapse, py::arg("operand"), py::arg("dimensions")); |
| ops.def("ConcatInDim", &ConcatInDim); |
| ops.def("Conditional", |
| static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaComputation* const>, |
| absl::Span<const XlaOp>)>(&Conditional)); |
| ops.def("Conditional", |
| static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&, XlaOp, |
| const XlaComputation&)>(&Conditional)); |
| ops.def("ConstantLiteral", &ConstantLiteral); |
| ops.def("ConvGeneralDilated", &ConvGeneralDilated, py::arg("lhs"), |
| py::arg("rhs"), py::arg("window_strides"), py::arg("padding"), |
| py::arg("lhs_dilation"), py::arg("rhs_dilation"), |
| py::arg("dimension_numbers"), py::arg("feature_group_count") = 1, |
| py::arg("batch_group_count") = 1, |
| py::arg("precision_config") = nullptr); |
| ops.def("ConvertElementType", &ConvertElementType, py::arg("operand"), |
| py::arg("new_element_type")); |
| ops.def("CustomCall", &CustomCallWithLayout); |
| ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), |
| py::arg("precision_config") = nullptr); |
| ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), |
| py::arg("dimension_numbers"), py::arg("precision_config") = nullptr); |
| ops.def("DynamicSlice", |
| static_cast<XlaOp (*)(XlaOp, absl::Span<const XlaOp>, |
| absl::Span<const int64>)>(&DynamicSlice)); |
| ops.def("DynamicUpdateSlice", |
| static_cast<XlaOp (*)(XlaOp, XlaOp, absl::Span<const XlaOp>)>( |
| &DynamicUpdateSlice)); |
| |
| ops.def("Fft", &Fft); |
| py::enum_<FftType>(m, "FftType") |
| .value("FFT", FftType::FFT) |
| .value("IFFT", FftType::IFFT) |
| .value("RFFT", FftType::RFFT) |
| .value("IRFFT", FftType::IRFFT); |
| |
| ops.def("Gather", &Gather, py::arg("a"), py::arg("start_indices"), |
| py::arg("dimension_numbers"), py::arg("slice_sizes")); |
| ops.def("GetTupleElement", &GetTupleElement); |
| ops.def("Infeed", &Infeed, py::arg("builder"), py::arg("shape"), |
| py::arg("config") = ""); |
| ops.def("Iota", |
| static_cast<XlaOp (*)(XlaBuilder*, const Shape&, int64)>(&Iota)); |
| ops.def("Iota", |
| static_cast<XlaOp (*)(XlaBuilder*, PrimitiveType, int64)>(&Iota)); |
| ops.def("Map", &Map); |
| ops.def("Outfeed", &Outfeed, py::arg("operand"), py::arg("shape_with_layout"), |
| py::arg("outfeed_config") = ""); |
| ops.def("Pad", &Pad); |
| ops.def("Parameter", static_cast<XlaOp (*)(XlaBuilder*, int64, const Shape&, |
| const std::string&)>(&Parameter)); |
| ops.def("QR", |
| [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> { |
| TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices)); |
| return std::make_pair(qr.q, qr.r); |
| }); |
| ops.def( |
| "Eigh", |
| [](XlaOp a, bool lower, int64 max_iter, |
| float epsilon) -> std::pair<XlaOp, XlaOp> { |
| auto eigh = SelfAdjointEig(a, lower, max_iter, epsilon); |
| return std::make_pair(eigh.v, eigh.w); |
| }, |
| py::arg("a"), py::arg("lower") = true, py::arg("max_iter") = 100, |
| py::arg("epsilon") = 1e-6); |
| ops.def( |
| "SVD", |
| [](XlaOp a, int64 max_iter, |
| float epsilon) -> std::tuple<XlaOp, XlaOp, XlaOp> { |
| auto svd = SVD(a, max_iter, epsilon); |
| return std::make_tuple(svd.u, svd.d, svd.v); |
| }, |
| py::arg("a"), py::arg("max_iter") = 100, py::arg("epsilon") = 1e-6); |
| ops.def("Reduce", |
| static_cast<XlaOp (*)(XlaBuilder*, absl::Span<const XlaOp>, |
| absl::Span<const XlaOp>, const XlaComputation&, |
| absl::Span<const int64>)>(&Reduce)); |
| ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"), |
| py::arg("exponent_bits"), py::arg("mantissa_bits")); |
| ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding); |
| ops.def("ReplicaId", &ReplicaId); |
| ops.def("Reshape", static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>, |
| absl::Span<const int64>)>(&Reshape)); |
| ops.def("Reshape", |
| static_cast<XlaOp (*)(XlaOp, absl::Span<const int64>)>(&Reshape)); |
| ops.def("Rev", &Rev, py::arg("operand"), py::arg("dimensions")); |
| ops.def("RngNormal", &RngNormal); |
| ops.def("RngUniform", &RngUniform); |
| ops.def("Scatter", &Scatter); |
| ops.def("Select", &Select); |
| ops.def("SelectAndScatterWithGeneralPadding", |
| &SelectAndScatterWithGeneralPadding); |
| ops.def("Slice", &Slice); |
| ops.def("SliceInDim", &SliceInDim, py::arg("operand"), py::arg("start_index"), |
| py::arg("limit_index"), py::arg("stride"), py::arg("dimno")); |
| ops.def( |
| "Sort", |
| [](XlaBuilder* builder, absl::Span<const XlaOp> operands, |
| int64 dimension) -> XlaOp { |
| return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { |
| std::vector<PrimitiveType> operand_types; |
| for (const auto& operand : operands) { |
| TF_ASSIGN_OR_RETURN(auto operand_shape, builder->GetShape(operand)); |
| operand_types.push_back(operand_shape.element_type()); |
| } |
| return Sort(operands, |
| CreateScalarLtComputation(operand_types, builder), |
| dimension); |
| }); |
| }, |
| py::arg("builder"), py::arg("operands"), py::arg("dimension") = -1); |
| ops.def("Transpose", &Transpose); |
| ops.def("TriangularSolve", &TriangularSolve); |
| ops.def("Tuple", &Tuple); |
| ops.def("While", &While); |
| |
| #define BINARY_OP(op) \ |
| ops.def( \ |
| #op, \ |
| [](XlaOp a, XlaOp b, absl::optional<std::vector<int64>> dims) { \ |
| return dims ? op(a, b, *dims) : op(a, b); \ |
| }, \ |
| py::arg("lhs"), py::arg("rhs"), \ |
| py::arg("broadcast_dimensions") = absl::nullopt) |
| BINARY_OP(Eq); |
| BINARY_OP(Ne); |
| BINARY_OP(Ge); |
| BINARY_OP(Gt); |
| BINARY_OP(Lt); |
| BINARY_OP(Le); |
| BINARY_OP(Add); |
| BINARY_OP(Sub); |
| BINARY_OP(Mul); |
| BINARY_OP(Div); |
| BINARY_OP(Rem); |
| BINARY_OP(Max); |
| BINARY_OP(Min); |
| BINARY_OP(And); |
| BINARY_OP(Or); |
| BINARY_OP(Xor); |
| BINARY_OP(ShiftLeft); |
| BINARY_OP(ShiftRightArithmetic); |
| BINARY_OP(ShiftRightLogical); |
| BINARY_OP(Atan2); |
| BINARY_OP(Pow); |
| BINARY_OP(Complex); |
| #undef BINARY_OP |
| |
| #define UNARY_OP(op) ops.def(#op, &op) |
| UNARY_OP(Not); |
| UNARY_OP(Clz); |
| UNARY_OP(Abs); |
| UNARY_OP(Exp); |
| UNARY_OP(Expm1); |
| UNARY_OP(Floor); |
| UNARY_OP(Ceil); |
| UNARY_OP(Round); |
| UNARY_OP(Log); |
| UNARY_OP(Log1p); |
| UNARY_OP(Sign); |
| UNARY_OP(Cos); |
| UNARY_OP(Sin); |
| UNARY_OP(Tanh); |
| UNARY_OP(IsFinite); |
| UNARY_OP(Neg); |
| UNARY_OP(Sqrt); |
| UNARY_OP(Rsqrt); |
| UNARY_OP(Square); |
| UNARY_OP(Reciprocal); |
| UNARY_OP(Erfc); |
| UNARY_OP(Erf); |
| UNARY_OP(ErfInv); |
| UNARY_OP(Lgamma); |
| UNARY_OP(Digamma); |
| UNARY_OP(Acos); |
| UNARY_OP(Asin); |
| UNARY_OP(Atan); |
| UNARY_OP(Tan); |
| UNARY_OP(Acosh); |
| UNARY_OP(Asinh); |
| UNARY_OP(Atanh); |
| UNARY_OP(Cosh); |
| UNARY_OP(Sinh); |
| UNARY_OP(Real); |
| UNARY_OP(Imag); |
| UNARY_OP(Conj); |
| #undef UNARY_OP |
| |
| py::enum_<TriangularSolveOptions::Transpose>( |
| m, "TriangularSolveOptions_Transpose") |
| .value("TRANSPOSE_INVALID", TriangularSolveOptions::TRANSPOSE_INVALID) |
| .value("NO_TRANSPOSE", TriangularSolveOptions::NO_TRANSPOSE) |
| .value("TRANSPOSE", TriangularSolveOptions::TRANSPOSE) |
| .value("ADJOINT", TriangularSolveOptions::ADJOINT); |
| |
| py::enum_<PrecisionConfig::Precision>(m, "PrecisionConfig_Precision") |
| .value("DEFAULT", PrecisionConfig::DEFAULT) |
| .value("HIGH", PrecisionConfig::HIGH) |
| .value("HIGHEST", PrecisionConfig::HIGHEST); |
| |
| // TODO(phawkins): improve bindings for these types. |
| py::class_<ChannelHandle>(m, "ChannelHandle"); |
| |
| tensorflow::AddXrtSubmodule(&m); |
| } // NOLINT(readability/fn_size) |
| |
| } // namespace xla |