Improve pytype checking of XLA types inside JAX.
Add an explicit `.pyi` file for jax/_src/lib/__init__.pyi, which works around a bug in pytype where the types of modules that are re-exported becomes `Any`.
[XLA:Python] Fix type declaration for sharding specs.
PiperOrigin-RevId: 404313123
Change-Id: Iadeef1f4a996566f4b6e02f850f4932d4f4265d9
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index e1e38ff..a03ea11 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -29,6 +29,8 @@
],
)
+exports_files(["xla_client.pyi"])
+
pyx_library(
name = "custom_call_for_test",
testonly = True,
diff --git a/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi b/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi
index cf8541b..e0180d2 100644
--- a/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi
+++ b/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi
@@ -58,7 +58,7 @@
sharding: Iterable[_AvalDimSharding],
mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ...
@property
- def sharding(self) -> Tuple[_AvalDimSharding]: ...
+ def sharding(self) -> Tuple[_AvalDimSharding, ...]: ...
@property
def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ...
def __eq__(self, __other: ShardingSpec) -> bool: ...