| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ |
| #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ |
| |
| #include "absl/container/inlined_vector.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/str_join.h" |
| #include "pybind11/pybind11.h" |
| #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" |
| #include "tensorflow/compiler/xla/python/py_client.h" |
| #include "tensorflow/compiler/xla/python/py_values.h" |
| #include "tensorflow/compiler/xla/python/pytree.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| |
| namespace jax { |
| |
| // Returns the effective value of jax_enable_x64: the thread-local value if |
| // one is defined, falling back to the value of the flag otherwise. |
| // Only declared in this header; the definition is provided elsewhere. |
| bool GetEnableX64(); |
| |
| // The signature of Python jitted function call, partitioned into: |
| // - dynamic positional arguments (i.e. positional args which are not static) |
| // - static positional arguments (i.e. the args associated to static_argnums) |
| // - keyword arguments |
| // The CallSignature should unambiguously identify a function call, thus, |
| // equality is based on: |
| // (a) Same PyTree for all dynamic positional arguments and keyword arguments |
| // (b) equality of the arguments and keyword arguments ArgSignature |
| // (c) equality (delegated to Python) of the static arguments. |
| struct CallSignature { |
| // A PyTreeDef for each dynamic argument, positional arguments first |
| // followed by keyword arguments. Keyword arguments are in the order given |
| // by dynamic_arg_names. |
| absl::InlinedVector<xla::PyTreeDef, 2> dynamic_arg_treedefs; |
| // Dynamic keyword argument names. Interned, and sorted by the keyword |
| // name. |
| std::vector<pybind11::object> dynamic_arg_names; |
| // Shape and dtype for both the dynamic positional arguments and the keyword |
| // arguments (sorted by keyword name). |
| absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures; |
| |
| // Static arguments. Contains the positional arguments sorted in argument |
| // order, followed by static keyword arguments in the order given by |
| // `static_arg_names`. |
| std::vector<pybind11::object> static_args; |
| // Static keyword argument names. Interned, and sorted by keyword name. |
| std::vector<pybind11::object> static_arg_names; |
| |
| // Device recorded for this call. Initialized to nullptr so that a |
| // default-constructed CallSignature never holds an indeterminate pointer. |
| xla::PjRtDevice* device = nullptr; |
| // Presumably the jax_enable_x64 setting in effect for this call (confirm |
| // against the code that fills it). Initialized to false so that a |
| // default-constructed CallSignature never holds an indeterminate bool. |
| bool jax_enable_x64 = false; |
| |
| // Opaque additional context that should be included as part of the cache key. |
| pybind11::object global_extra_jit_context; |
| absl::optional<pybind11::object> thread_local_extra_jit_context; |
| |
| // Equality as described in the struct comment; defined out of line. |
| bool operator==(const CallSignature& other) const; |
| // Exact negation of operator==. |
| bool operator!=(const CallSignature& other) const { |
| return !(*this == other); |
| } |
| |
| // Returns a string describing this signature; defined out of line. |
| std::string DebugString() const; |
| }; |
| |
| // Hashing hook for CallSignature, following the `AbslHashValue(H, const T&)` |
| // form used by Abseil's hashing framework: takes the hash state `h` and |
| // returns a hash state of the same type. |
| // NOTE(review): only the declaration is visible here; the template definition |
| // must be reachable from wherever it is instantiated -- confirm where it lives. |
| template <typename H> |
| H AbslHashValue(H h, const CallSignature& s); |
| |
| // The resulting information of the parsing and conversion of the arguments. |
| // Bundles the call signature with the flattened dynamic arguments and the |
| // buffers (plus the objects keeping them alive) derived from them. |
| struct ParsedArgumentsAsBuffers { |
| // The call signature will be filled during 2 steps: |
| // - `ParseArguments` will fill the static arguments and the pytree |
| // structures |
| // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. |
| CallSignature signature; |
| // The concatenation of the dynamic positional arguments and the sorted |
| // keyword arguments. |
| absl::InlinedVector<pybind11::object, 2> flat_dynamic_args; |
| // Presumably Python objects that must stay referenced while the parsed |
| // arguments are in use -- TODO confirm against the code that fills it. |
| std::vector<pybind11::object> keep_alive_objects; |
| |
| // The following is only valid if the parsing succeeds. |
| // Non-owning pointers; see `keep_alive` for the owning counterparts. |
| std::vector<xla::PjRtBuffer*> arg_buffers; |
| // We may need to keep these objects around, because: |
| // (a) we need to extend the lifetime of objects created within |
| // `CopyBuffersToDevice` |
| // (b) `arg_buffers` do not maintain ownership |
| std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive; |
| }; |
| |
| // Filter out static arguments, flatten and concatenate other arguments (i.e. |
| // dynamic positional and keyword arguments), filling `arguments` in place. |
| // |
| // Parameters: |
| // - args: handle to the positional arguments of the call (presumably a |
| // Python tuple -- confirm against the caller). |
| // - py_kwargs: the keyword arguments of the call, if any. |
| // - static_argnums: presumably the positional indices treated as static. |
| // - static_argnames: presumably the keyword names treated as static. |
| // - arguments: output, modified in place. |
| // Returns an xla::Status; the failure conditions are not visible from this |
| // declaration. |
| xla::Status ParseArguments(pybind11::handle args, |
| const absl::optional<pybind11::kwargs>& py_kwargs, |
| absl::Span<int const> static_argnums, |
| absl::Span<pybind11::str const> static_argnames, |
| ParsedArgumentsAsBuffers& arguments); |
| |
| // The function to call in `xla.cc` to add the bindings for this module. |
| // `m` is the pybind11 module that receives the bindings. |
| void BuildJaxjitSubmodule(pybind11::module& m); |
| |
| } // namespace jax |
| |
| #endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ |