| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| |
| import enum |
| import inspect |
| import types |
| import typing |
| from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union, overload |
| |
| import numpy as np |
| |
| from . import ops |
| from . import jax_jit |
| from . import mlir |
| from . import outfeed_receiver |
| from . import pmap_lib |
| from . import profiler |
| from . import pytree |
| from . import transfer_guard_lib |
| |
| _LiteralSlice = Any |
| _Status = Any |
| _Dtype = Any |
| _XlaOpMetadata = Any |
| |
| _T = TypeVar("_T") |
| |
| class XlaRuntimeError(RuntimeError): |
| pass |
| |
| class PrimitiveType(enum.IntEnum): |
| PRIMITIVE_TYPE_INVALID: PrimitiveType |
| PRED: PrimitiveType |
| S8: PrimitiveType |
| S16: PrimitiveType |
| S32: PrimitiveType |
| S64: PrimitiveType |
| U8: PrimitiveType |
| U16: PrimitiveType |
| U32: PrimitiveType |
| U64: PrimitiveType |
| BF16: PrimitiveType |
| F16: PrimitiveType |
| F32: PrimitiveType |
| F64: PrimitiveType |
| C64: PrimitiveType |
| C128: PrimitiveType |
| TUPLE: PrimitiveType |
| OPAQUE_TYPE: PrimitiveType |
| TOKEN: PrimitiveType |
| |
| def bfloat16_dtype() -> Type[Any]: ... |
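
# bfloat16_dtype() returns a NumPy-compatible extension dtype. A minimal
# usage sketch (an assumption about typical use, mirroring how JAX consumes
# this dtype; not part of this API's contract):
#
#   import numpy as np
#   bfloat16 = bfloat16_dtype()
#   x = np.array([1.0, 2.0], dtype=bfloat16)  # array with bfloat16 elements
#   y = x.astype(np.float32)                  # widen back to float32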
| |
| # === BEGIN xla_compiler.cc |
| |
| class Shape: |
  def __init__(self, s: str) -> None: ...
| @staticmethod |
| def tuple_shape(shapes: Sequence[Shape]) -> Shape: ... |
| @staticmethod |
| def array_shape( |
| type: Union[np.dtype, PrimitiveType], |
| dims_seq: Any = ..., |
| layout_seq: Any = ..., |
| dynamic_dimensions: Optional[List[bool]] = ...) -> Shape: ... |
| @staticmethod |
| def token_shape() -> Shape: ... |
| @staticmethod |
| def scalar_shape(type: Union[np.dtype, PrimitiveType]) -> Shape: ... |
| def dimensions(self) -> Tuple[int, ...]: ... |
| def xla_element_type(self) -> PrimitiveType: ... |
| def element_type(self) -> np.dtype: ... |
| def numpy_dtype(self) -> np.dtype: ... |
| def is_tuple(self) -> bool: ... |
| def is_array(self) -> bool: ... |
| def is_token(self) -> bool: ... |
| def is_static(self) -> bool: ... |
| def is_dynamic(self) -> bool: ... |
| def is_dynamic_dimension(self, dimension: int) -> bool: ... |
| def set_dynamic_dimension(self, dimension: int, is_dynamic: bool) -> None: ... |
| def rank(self) -> int: ... |
| def to_serialized_proto(self) -> bytes: ... |
| def tuple_shapes(self) -> List[Shape]: ... |
| def leaf_count(self) -> int: ... |
| def with_major_to_minor_layout_if_absent(self) -> Shape: ... |
| def __eq__(self, other: Shape) -> bool: ... |
| def __ne__(self, other: Shape) -> bool: ... |
| def __hash__(self) -> int: ... |
| def __repr__(self) -> str: ... |
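
# Shapes are typically built via the static factories rather than the string
# constructor. A minimal sketch (assumes NumPy dtypes are accepted, per the
# Union[np.dtype, PrimitiveType] annotations above):
#
#   f32_2x3 = Shape.array_shape(np.dtype('float32'), (2, 3))
#   pair = Shape.tuple_shape([f32_2x3, Shape.token_shape()])
#   assert pair.is_tuple() and not pair.is_array()
#   assert f32_2x3.rank() == 2 and f32_2x3.dimensions() == (2, 3)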
| |
| class Layout: |
| def minor_to_major(self) -> Tuple[int, ...]: ... |
| def to_string(self) -> str: ... |
| def __eq__(self, other: Layout) -> bool: ... |
| def __ne__(self, other: Layout) -> bool: ... |
| def __hash__(self) -> int: ... |
| |
| class ProgramShape: |
| def __init__(self, params: Sequence[Shape], result: Shape) -> None: ... |
| def parameter_shapes(self) -> List[Shape]: ... |
| def result_shape(self) -> Shape: ... |
| def __repr__(self) -> str: ... |
| |
| class ShapeIndex: |
  def __init__(self, indices: List[int]) -> None: ...
  def __eq__(self, other: ShapeIndex) -> bool: ...
  def __ne__(self, other: ShapeIndex) -> bool: ...
| def __hash__(self) -> int: ... |
| def __repr__(self) -> str: ... |
| |
| class Literal: |
| def __repr__(self) -> str: ... |
| |
| class XlaComputation: |
| def __init__(self, serialized_hlo_module_proto: bytes) -> None: ... |
| def get_hlo_module(self) -> HloModule: ... |
| def program_shape(self) -> ProgramShape: ... |
| def as_serialized_hlo_module_proto(self) -> bytes: ... |
  def as_hlo_text(self, print_large_constants: bool = False) -> str: ...
| def as_hlo_dot_graph(self) -> str: ... |
| def hash(self) -> int: ... |
  def as_hlo_module(self) -> HloModule: ...
| |
| class HloPrintOptions: |
| def __init__(self) -> None: ... |
| @staticmethod |
| def short_parsable() -> HloPrintOptions: ... |
| @staticmethod |
| def canonical() -> HloPrintOptions: ... |
| @staticmethod |
| def fingerprint() -> HloPrintOptions: ... |
| print_large_constants: bool |
| print_metadata: bool |
| print_backend_config: bool |
| print_result_shape: bool |
| print_operand_shape: bool |
| print_operand_names: bool |
| print_ids: bool |
| print_extra_attributes: bool |
| print_program_shape: bool |
| print_percent: bool |
| print_control_dependencies: bool |
| compact_operands: bool |
| include_layout_in_shapes: bool |
| canonicalize_instruction_names: bool |
| canonicalize_computations: bool |
| indent_amount: int |
| is_in_nested_computation: bool |
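
# A sketch of rendering an HloModule with non-default print options (assumes
# `hlo_module` is an HloModule obtained elsewhere, e.g. via
# XlaComputation.get_hlo_module()):
#
#   opts = HloPrintOptions.short_parsable()
#   opts.print_large_constants = False
#   text = hlo_module.to_string(opts)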
| |
| class HloModule: |
| spmd_output_sharding: Optional[OpSharding] |
| spmd_parameters_shardings: Optional[List[OpSharding]] |
| def to_string(self, options: HloPrintOptions = ...) -> str: ... |
  def as_serialized_hlo_module_proto(self) -> bytes: ...
| @staticmethod |
| def from_serialized_hlo_module_proto( |
| serialized_hlo_module_proto: bytes) -> HloModule: ... |
| |
| def hlo_module_to_dot_graph(hlo_module: HloModule) -> str: ... |
| |
| def hlo_module_cost_analysis( |
| client: Client, |
| module: HloModule) -> Dict[str, float]: ... |
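
# A sketch of querying cost analysis (the keys of the returned dict, such as
# 'flops', are backend-dependent and an assumption here):
#
#   costs = hlo_module_cost_analysis(client, module)
#   flops = costs.get('flops')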
| |
| class XlaOp: ... |
| |
| class XlaBuilder: |
| def __init__(self, name: str) -> None: ... |
| def Build(self, root: Optional[XlaOp] = ...) -> XlaComputation: ... |
| def GetShape(self, __op: XlaOp) -> Shape: ... |
| build = Build |
| def clear_op_metadata(self) -> None: ... |
| get_shape = GetShape |
| def get_program_shape(self, root: Optional[XlaOp] = ...) -> ProgramShape: ... |
| def is_constant(self, __op: XlaOp) -> bool: ... |
| def set_op_metadata(self, metadata: _XlaOpMetadata) -> None: ... |
| def set_sharding(self, sharding: OpSharding_Type) -> None: ... |
| def clear_sharding(self) -> None: ... |
| def setup_alias( |
| self, |
| __output_index: Sequence[int], |
| __param_number: int, |
| __param_index: Sequence[int]) -> None: ... |
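
# A minimal builder sketch using the `ops` module imported above (hedged: op
# names and argument order follow the xla_client ops bindings as commonly
# used; this is illustrative, not an API guarantee):
#
#   builder = XlaBuilder("add_one")
#   p = ops.Parameter(builder, 0, Shape.array_shape(np.dtype('float32'), ()))
#   ops.Add(p, ops.Constant(builder, np.float32(1)))
#   computation = builder.build()  # an XlaComputation rooted at the last op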
| |
| class DeviceAssignment: |
| @staticmethod |
| def create(array: np.ndarray) -> DeviceAssignment: ... |
| def replica_count(self) -> int: ... |
| def computation_count(self) -> int: ... |
| def __repr__(self) -> str: ... |
| def serialize(self) -> bytes: ... |
| |
| class CompileOptions: |
| def __init__(self) -> None: ... |
| argument_layouts: Optional[List[Shape]] |
| parameter_is_tupled_arguments: bool |
| executable_build_options: ExecutableBuildOptions |
| tuple_arguments: bool |
| num_replicas: int |
| num_partitions: int |
| device_assignment: Optional[DeviceAssignment] |
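
# A sketch of configuring compilation (the field values are illustrative,
# not recommended defaults):
#
#   options = CompileOptions()
#   options.num_replicas = 1
#   options.executable_build_options.num_partitions = 1
#   options.executable_build_options.use_spmd_partitioning = False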
| |
| def register_custom_call_target(fn_name: str, capsule: Any, platform: str) -> _Status: ... |
| |
| class DebugOptions: |
| def __repr__(self) -> str: ... |
| xla_cpu_enable_fast_math: bool |
| xla_cpu_fast_math_honor_infs: bool |
| xla_cpu_fast_math_honor_nans: bool |
| xla_cpu_fast_math_honor_division: bool |
| xla_cpu_fast_math_honor_functions: bool |
| xla_gpu_enable_fast_min_max: bool |
| xla_backend_optimization_level: int |
| xla_cpu_enable_xprof_traceme: bool |
| xla_llvm_disable_expensive_passes: bool |
| xla_test_all_input_layouts: bool |
| |
| class CompiledMemoryStats: |
| generated_code_size_in_bytes: int |
| argument_size_in_bytes: int |
| output_size_in_bytes: int |
| alias_size_in_bytes: int |
| temp_size_in_bytes: int |
| def __str__(self) -> str: ... |
| |
| |
| class ExecutableBuildOptions: |
| def __init__(self) -> None: ... |
| def __repr__(self) -> str: ... |
| result_layout: Optional[Shape] |
| num_replicas: int |
| num_partitions: int |
| debug_options: DebugOptions |
| device_assignment: Optional[DeviceAssignment] |
| use_spmd_partitioning: bool |
| use_auto_spmd_partitioning: bool |
| auto_spmd_partitioning_mesh_shape: List[int] |
| auto_spmd_partitioning_mesh_ids: List[int] |
| |
| class PrecisionConfig_Precision(enum.IntEnum): |
| DEFAULT: int |
| HIGH: int |
| HIGHEST: int |
| |
| class OpSharding_Type(enum.IntEnum): |
| REPLICATED: int |
| MAXIMAL: int |
| TUPLE: int |
| OTHER: int |
| MANUAL: int |
| |
| class OpSharding: |
| Type: typing.Type[OpSharding_Type] |
| type: OpSharding_Type |
| replicate_on_last_tile_dim: bool |
  last_tile_dims: Sequence[OpSharding_Type]
| tile_assignment_dimensions: Sequence[int] |
| tile_assignment_devices: Sequence[int] |
| tuple_shardings: Sequence[OpSharding] |
| def SerializeToString(self) -> bytes: ... |
| |
| class ChannelHandle_ChannelType(enum.IntEnum): |
| CHANNEL_TYPE_INVALID: int |
| DEVICE_TO_DEVICE: int |
| DEVICE_TO_HOST: int |
| HOST_TO_DEVICE: int |
| |
| class ChannelHandle: |
| type: ChannelHandle_ChannelType |
| handle: int |
| def __repr__(self) -> str: ... |
| |
| class FftType(enum.IntEnum): |
| FFT: int |
| IFFT: int |
| RFFT: int |
| IRFFT: int |
| |
| # === END xla_compiler.cc |
| |
| class Device: |
| id: int |
| host_id: int |
| process_index: int |
| platform: str |
| device_kind: str |
| client: Client |
| def __repr__(self) -> str: ... |
| def __str__(self) -> str: ... |
  def transfer_to_infeed(self, literal: _LiteralSlice) -> None: ...
  def transfer_from_outfeed(self, shape: Shape) -> Any: ...
| def live_buffers(self) -> List[Buffer]: ... |
| |
| class GpuDevice(Device): |
| device_vendor: str |
| |
| class TpuDevice(Device): |
| coords: Tuple[int, ...] |
| core_on_chip: int |
| |
| class _GpuAllocatorKind(enum.IntEnum): |
| DEFAULT: int |
| PLATFORM: int |
| BFC: int |
| CUDA_ASYNC: int |
| |
| class GpuAllocatorConfig: |
| # TODO(b/194673104): Remove once pytype correctly resolves a nested enum. |
| Kind = _GpuAllocatorKind |
| |
| def __init__( |
| self, |
| kind: _GpuAllocatorKind = ..., |
| memory_fraction: float = ..., |
| preallocate: bool = ...) -> None: ... |
| |
| class HostBufferSemantics(enum.IntEnum): |
| IMMUTABLE_ONLY_DURING_CALL: HostBufferSemantics |
| IMMUTABLE_UNTIL_TRANSFER_COMPLETES: HostBufferSemantics |
| ZERO_COPY: HostBufferSemantics |
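
# The semantics control how long the caller's host buffer must remain valid.
# A sketch (assumes `client` is a Client, declared below, and `arr` a NumPy
# array; ZERO_COPY is only honored where the backend can alias host memory,
# e.g. on CPU):
#
#   buf = client.buffer_from_pyval(
#       arr, host_buffer_semantics=HostBufferSemantics.ZERO_COPY)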
| |
| class Client: |
| platform: str |
| platform_version: str |
| runtime_type: str |
| def device_count(self) -> int: ... |
| def local_device_count(self) -> int: ... |
| def devices(self) -> List[Device]: ... |
| def local_devices(self) -> List[Device]: ... |
| def live_buffers(self) -> List[Buffer]: ... |
| def live_executables(self) -> List[Executable]: ... |
| def host_id(self) -> int: ... |
| def process_index(self) -> int: ... |
| @overload |
| def get_default_device_assignment( |
| self, |
| num_replicas: int, |
| num_partitions: int) -> List[List[Device]]: ... |
| @overload |
| def get_default_device_assignment( |
| self, |
| num_replicas: int) -> List[Device]: ... |
| def create_channel_handle(self) -> ChannelHandle: ... |
| def create_device_to_host_channel_handle(self) -> ChannelHandle: ... |
| def create_host_to_device_channel_handle(self) -> ChannelHandle: ... |
| def buffer_from_pyval( |
| self, |
| argument: Any, |
| device: Device = ..., |
| force_copy: bool = ..., |
| host_buffer_semantics: HostBufferSemantics = ...) -> Buffer: ... |
| def make_cross_host_receive_buffers( |
| self, |
| shapes: Sequence[Shape], |
| device: Device) -> List[Tuple[Buffer, bytes]]: ... |
| def compile( |
| self, |
| computation: XlaComputation, |
| compile_options: CompileOptions = ...) -> Executable: ... |
| def serialize_executable(self, executable: Executable) -> bytes: ... |
  @overload
  def deserialize_executable(
      self, serialized: bytes,
      options: CompileOptions) -> Executable: ...
  # TODO(skyewm): remove when jax stops providing hlo_module
  @overload
  def deserialize_executable(
      self, serialized: bytes,
      hlo_module: HloModule,
      options: CompileOptions) -> Executable: ...
| def heap_profile(self) -> bytes: ... |
| def defragment(self) -> _Status: ... |
| def emit_python_callback( |
| self, callable: Callable, builder: XlaBuilder, operands: Sequence[XlaOp], |
| results_shapes: Sequence[Shape], |
| operand_layouts: Optional[Sequence[Shape]] = ..., |
| has_side_effects: bool = ...) -> Tuple[XlaOp, Any]: ... |
| |
| |
| def get_cpu_client(asynchronous: bool = ...) -> Client: ... |
| def get_tfrt_cpu_client(asynchronous: bool = ...) -> Client: ... |
| def get_interpreter_client() -> Client: ... |
| def get_gpu_client( |
| asynchronous: bool = ..., |
| allocator_config: GpuAllocatorConfig = ..., |
| distributed_client: Optional[DistributedRuntimeClient] = ..., |
    node_id: int = ...) -> Client: ...
| def get_tpu_client(max_inflight_computations: int = ...) -> Client: ... |
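
# An end-to-end sketch tying the pieces together (assumes `computation` is
# an XlaComputation such as the builder example above produces):
#
#   client = get_cpu_client(asynchronous=True)
#   buf = client.buffer_from_pyval(np.float32(41))
#   executable = client.compile(computation)
#   out, = executable.execute([buf])
#   print(out.to_py())  # e.g. 42.0 for the add-one computation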
| |
| class DeviceArrayBase: ... |
| |
| class DeviceArray(DeviceArrayBase): |
| __array_priority__: int |
| _device: Optional[Device] |
| aval: Any |
| weak_type: Optional[bool] |
| @property |
| def device_buffer(self: _T) -> _T: ... |
| shape: Tuple[int, ...] |
| dtype: np.dtype |
| size: int |
| ndim: int |
| _value: np.ndarray |
| def copy_to_device(self, dst_device: Device) -> DeviceArray: ... |
| def copy_to_remote_device(self, |
| descriptor: bytes) -> Tuple[_Status, bool]: ... |
| def on_device_size_in_bytes(self) -> int: ... |
| def delete(self) -> None: ... |
| def is_ready(self) -> bool: ... |
| def is_known_ready(self) -> bool: ... |
| def block_until_ready(self) -> DeviceArray: ... |
| def copy_to_host_async(self) -> _Status: ... |
| def to_py(self) -> np.ndarray: ... |
| def xla_shape(self) -> Shape: ... |
| def xla_dynamic_shape(self) -> Shape: ... |
| client: Client |
| def device(self) -> Device: ... |
| def platform(self) -> str: ... |
| def is_deleted(self) -> bool: ... |
| def unsafe_buffer_pointer(self) -> Any: ... |
| __cuda_array_interface__: Dict[str, Any] |
| traceback: Traceback |
| def clone(self) -> DeviceArray: ... |
| |
| PyLocalBuffer = DeviceArray |
| Buffer = DeviceArray |
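
# A sketch of the buffer lifecycle (Buffer and PyLocalBuffer are aliases of
# DeviceArray, so the same methods apply; `client` is assumed from above):
#
#   buf = client.buffer_from_pyval(np.arange(4, dtype=np.float32))
#   buf.block_until_ready()   # wait for the device transfer to complete
#   host = buf.to_py()        # copy back to a NumPy array
#   buf.delete()
#   assert buf.is_deleted()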
| |
| class Executable: |
| client: Client |
| def local_logical_device_ids(self) -> List[Tuple[int, int]]: ... |
| def local_devices(self) -> List[Device]: ... |
| def size_of_generated_code_in_bytes(self) -> int: ... |
| def delete(self) -> None: ... |
| def execute(self, arguments: Sequence[DeviceArray]) -> List[DeviceArray]: ... |
| def execute_sharded_on_local_devices( |
| self, |
| arguments: Sequence[List[DeviceArray]]) -> List[List[DeviceArray]]: ... |
| def hlo_modules(self) -> List[HloModule]: ... |
| def keep_alive(self) -> None: ... |
| traceback: Traceback |
| fingerprint: Optional[bytes] |
| |
| def buffer_to_dlpack_managed_tensor( |
| buffer: Buffer, |
| take_ownership: bool = ...) -> Any: ... |
| def dlpack_managed_tensor_to_buffer( |
| tensor: Any, cpu_backend: Optional[Client] = ..., |
| gpu_backend: Optional[Client] = ...) -> Buffer: ... |
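
# A DLPack round-trip sketch (assumes `buf` comes from a CPU client; with
# take_ownership=True the source buffer would be invalidated):
#
#   capsule = buffer_to_dlpack_managed_tensor(buf, take_ownership=False)
#   buf2 = dlpack_managed_tensor_to_buffer(capsule, cpu_backend=client)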
| |
| # === BEGIN py_traceback.cc |
| |
| class Frame: |
| file_name: str |
| function_name: str |
| function_line_start: int |
| line_num: int |
| def __repr__(self) -> str: ... |
| |
| class Traceback: |
| enabled: ClassVar[bool] |
| @staticmethod |
| def get_traceback() -> Traceback: ... |
| frames: Sequence[Frame] |
| def __str__(self) -> str: ... |
| def as_python_traceback(self) -> Any: ... |
| def raw_frames(self) -> Tuple[List[types.CodeType], List[int]]: ... |
| |
| @staticmethod |
| def code_addr2line(code: types.CodeType, lasti: int) -> int: ... |
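
# A sketch of capturing and printing a traceback (collection is gated on the
# class-level `enabled` flag):
#
#   Traceback.enabled = True
#   tb = Traceback.get_traceback()
#   for frame in tb.frames:
#       print(f"{frame.file_name}:{frame.line_num} in {frame.function_name}")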
| |
def replace_thread_exc_traceback(traceback: Any) -> None: ...
| |
| # === END py_traceback.cc |
| |
| class DistributedRuntimeService: |
  def shutdown(self) -> None: ...

class DistributedRuntimeClient:
| def connect(self) -> _Status: ... |
| def shutdown(self) -> _Status: ... |
| |
| def get_distributed_runtime_service( |
| address: str, |
| num_nodes: int, |
| heartbeat_interval: Optional[int] = ..., |
| max_missing_heartbeats: Optional[int] = ..., |
| enumerate_devices_timeout: Optional[int] = ..., |
| shutdown_timeout: Optional[int] = ...) -> DistributedRuntimeService: ... |
| def get_distributed_runtime_client( |
| address: str, |
| node_id: int, |
| rpc_timeout: Optional[int] = ..., |
| init_timeout: Optional[int] = ..., |
| shutdown_timeout: Optional[int] = ..., |
| heartbeat_interval: Optional[int] = ..., |
| max_missing_heartbeats: Optional[int] = ..., |
| missed_heartbeat_callback: Optional[Any] = ..., |
| shutdown_on_destruction: Optional[bool] = ...) -> DistributedRuntimeClient: ... |
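
# A two-process coordination sketch (the address format and port are
# illustrative assumptions):
#
#   # On the coordinator process:
#   service = get_distributed_runtime_service("[::]:8476", num_nodes=2)
#   # On each participating process, with its own node_id:
#   dist_client = get_distributed_runtime_client("coordinator:8476", node_id=0)
#   dist_client.connect()
#   # ... run the program, then tear down:
#   dist_client.shutdown()
#   service.shutdown()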
| |
| def collect_garbage() -> None: ... |
| |
| def is_optimized_build() -> bool: ... |
| |
| def json_to_pprof_profile(json: str) -> bytes: ... |
| def pprof_profile_to_json(proto: bytes) -> str: ... |
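
# Client.heap_profile() returns a serialized pprof profile, which these
# helpers convert to and from JSON (a sketch under that assumption):
#
#   profile_json = pprof_profile_to_json(client.heap_profile())
#   round_tripped = json_to_pprof_profile(profile_json)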
| |
| |
| class CompiledFunction: |
| def __call__(self, *args, **kwargs) -> Any: ... |
| def __getstate__(self) -> Any: ... |
  def __setstate__(self, state: Any) -> None: ...
| __signature__: inspect.Signature |
| def _cache_size(self) -> int: ... |
| def _clear_cache(self) -> None: ... |
| |
| class PmapFunction: |
| def __call__(self, *args, **kwargs) -> Any: ... |
| def __getstate__(self) -> Any: ... |
  def __setstate__(self, state: Any) -> None: ...
| __signature__: inspect.Signature |
| def _cache_size(self) -> int: ... |
| def _clear_cache(self) -> None: ... |
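
# These wrappers back jax.jit and jax.pmap in the matching jaxlib version. A
# cache-introspection sketch (hedged: assumes jax.jit dispatches to
# CompiledFunction here, which is version-dependent):
#
#   import jax
#   f = jax.jit(lambda x: x + 1)
#   f(1.0)                  # first call compiles and populates the cache
#   print(f._cache_size())  # -> 1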