Improve pytype checking of XLA types inside JAX.

Add an explicit `.pyi` file for jax/_src/lib/__init__.pyi, which works around a bug in pytype where the types of modules that are re-exported becomes `Any`.

[XLA:Python] Fix type declaration for sharding specs.

PiperOrigin-RevId: 404313123
Change-Id: Iadeef1f4a996566f4b6e02f850f4932d4f4265d9
diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index e1e38ff..a03ea11 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -29,6 +29,8 @@
     ],
 )
 
+exports_files(["xla_client.pyi"])
+
 pyx_library(
     name = "custom_call_for_test",
     testonly = True,
diff --git a/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi b/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi
index cf8541b..e0180d2 100644
--- a/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi
+++ b/tensorflow/compiler/xla/python/xla_extension/pmap_lib.pyi
@@ -58,7 +58,7 @@
                sharding: Iterable[_AvalDimSharding],
                mesh_mapping: Iterable[_MeshDimAssignment]) -> None: ...
   @property
-  def sharding(self) -> Tuple[_AvalDimSharding]: ...
+  def sharding(self) -> Tuple[_AvalDimSharding, ...]: ...
   @property
   def mesh_mapping(self) -> Tuple[_MeshDimAssignment]: ...
   def __eq__(self, __other: ShardingSpec) -> bool: ...