Treat "test_xla_gpu" as GPU_TEST in "NamedGPUCombination".
PiperOrigin-RevId: 309276564
Change-Id: Iea914898c6a090c8397d0ad0890e239948eec943
diff --git a/tensorflow/python/distribute/combinations.py b/tensorflow/python/distribute/combinations.py
index ffa03ee..9a479a3 100644
--- a/tensorflow/python/distribute/combinations.py
+++ b/tensorflow/python/distribute/combinations.py
@@ -22,6 +22,7 @@
from __future__ import division
from __future__ import print_function
+import re
import sys
import types
import unittest
@@ -94,10 +95,10 @@
Attributes:
GPU_TEST: The environment is considered to have GPU hardware available if
- the name of the program contains "test_gpu".
+ the name of the program contains "test_gpu" or "test_xla_gpu".
"""
- GPU_TEST = "test_gpu" in sys.argv[0]
+ GPU_TEST = re.search(r"(test_gpu|test_xla_gpu)$", sys.argv[0])
def should_execute_combination(self, kwargs):
distributions = [